diff --git a/server/Justfile b/server/Justfile index cf80594..78ac928 100644 --- a/server/Justfile +++ b/server/Justfile @@ -106,7 +106,8 @@ test-integration: @TEST=1 uv run pytest tests/integration -v --tb=short -x # Run integration tests with PG: ensure DB running → wipe → create → migrate → test → wipe -test-integration-pg: +# Optionally narrow to a pytest target, e.g. `just test-integration-pg tests/integration/test_foo.py` +test-integration-pg target="tests/integration": POSTGRES_PORT={{TEST_PG_PORT}} just --justfile ../Justfile db-up @just test-db-drop @just test-db-create @@ -116,7 +117,7 @@ test-integration-pg: OSA_DATABASE__URL="{{TEST_DB_URL}}" \ OSA_AUTH__JWT__SECRET="test-secret-for-integration-tests-minimum-32-chars" \ TEST=1 \ - uv run pytest tests/integration -v --tb=short -x; \ + uv run pytest "{{target}}" -v --tb=short -x; \ just test-db-drop # === Dependencies === diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index 6d175ae..63925cf 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -8,6 +8,7 @@ from dishka import Provider as DishkaProvider from fastapi import FastAPI, Request from fastapi.responses import JSONResponse +from slowapi.errors import RateLimitExceeded from sqlalchemy.ext.asyncio import AsyncEngine from osa.application.api.v1.errors import map_osa_error @@ -16,17 +17,16 @@ auth, conventions, depositions, - discovery, events, ingestions, health, ontologies, - records, schemas, - search, stats, validation, ) +from osa.application.api.v1.routes import data as data_routes +from osa.application.api.v1.routes.data._limiter import limiter from osa.application.di import create_container from osa.config import DEV_JWT_SECRET, Config from osa.domain.shared.authorization.startup import validate_all_handlers @@ -194,8 +194,6 @@ def create_app( app_instance.include_router(admin.router, prefix="/api/v1") app_instance.include_router(auth.router, 
prefix="/api/v1") app_instance.include_router(events.router, prefix="/api/v1") - app_instance.include_router(records.router, prefix="/api/v1") - app_instance.include_router(search.router, prefix="/api/v1") app_instance.include_router(stats.router, prefix="/api/v1") app_instance.include_router(ontologies.router, prefix="/api/v1") app_instance.include_router(schemas.router, prefix="/api/v1") @@ -203,7 +201,18 @@ def create_app( app_instance.include_router(depositions.router, prefix="/api/v1") app_instance.include_router(ingestions.router, prefix="/api/v1") app_instance.include_router(validation.router, prefix="/api/v1") - app_instance.include_router(discovery.router, prefix="/api/v1") + app_instance.include_router(data_routes.router, prefix="/api/v1") + + # POST /data/* rate limiting (slowapi). The shared limiter is attached to + # app state and its 429 handler registered; GET routes are not limited. + app_instance.state.limiter = limiter + + @app_instance.exception_handler(RateLimitExceeded) + async def rate_limit_handler(request: Request, exc: RateLimitExceeded): + return JSONResponse( + status_code=429, + content={"code": "rate_limited", "message": "POST rate limit exceeded (10/min/IP)."}, + ) # Global OSA error handler - maps domain and infrastructure errors to HTTP responses @app_instance.exception_handler(OSAError) diff --git a/server/osa/application/api/v1/errors.py b/server/osa/application/api/v1/errors.py index 8a9d4cf..dd4c790 100644 --- a/server/osa/application/api/v1/errors.py +++ b/server/osa/application/api/v1/errors.py @@ -15,6 +15,7 @@ InvalidStateError, NotFoundError, OSAError, + ReservedNameError, ValidationError, ) @@ -24,6 +25,7 @@ InvalidStateError: 409, ConflictError: 409, AuthorizationError: 403, + ReservedNameError: 400, } diff --git a/server/osa/application/api/v1/routes/data/__init__.py b/server/osa/application/api/v1/routes/data/__init__.py new file mode 100644 index 0000000..4234ae8 --- /dev/null +++ 
b/server/osa/application/api/v1/routes/data/__init__.py @@ -0,0 +1,47 @@ +"""Unified ``/data/`` read surface router. + +Subroutes are registered by the user-story phases: +- catalog + manifest (``GET /data``, ``GET /data/{schema}``) — US3 +- single record by ID (``GET /data/records/{id}``) — US4 +- records table matrix (``/data/{schema}/records*``) — US1/US2 via the factory +- feature table matrix (``/data/{schema}/{feature}*``) — US5 via the factory +""" + +from __future__ import annotations + +from dishka.integrations.fastapi import DishkaRoute +from fastapi import APIRouter + +from osa.application.api.v1.routes.data import ( + catalog, + features_table, + records, + records_table, +) +from osa.domain.data.model.catalog import NodeCatalog + +router = APIRouter(prefix="/data", tags=["data"], route_class=DishkaRoute) + +# ``GET /data`` (empty sub-path) is registered directly on the prefixed router. +router.add_api_route( + "", + catalog.get_node_catalog, + methods=["GET"], + operation_id="data_get_node_catalog", + response_model=NodeCatalog, +) + +# Table-shaped routes (records + feature tables) are registered via the factory +# onto a DishkaRoute-enabled subrouter. Records is registered before the feature +# table's ``/{schema}/{feature}`` catch-all so ``/{schema}/records`` matches the +# records routes rather than being captured as a feature named "records". +tables_router = APIRouter(route_class=DishkaRoute) +records_table.register(tables_router) +features_table.register(tables_router) + +# Order matters: literal ``/records/{id}`` and the ``/{schema}/records*`` table +# routes must precede the manifest's ``/{schema}`` catch-all so a record fetch +# or table read isn't captured as a schema manifest lookup. 
+router.include_router(records.router) +router.include_router(tables_router) +router.include_router(catalog.manifest_router) diff --git a/server/osa/application/api/v1/routes/data/_limiter.py b/server/osa/application/api/v1/routes/data/_limiter.py new file mode 100644 index 0000000..fe924d8 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/_limiter.py @@ -0,0 +1,20 @@ +"""Shared slowapi limiter for ``/data/`` POST routes (research §5). + +POST routes accept structured ``FilterExpr`` bodies that can express expensive +queries, so they are rate-limited per source IP. GET routes are intentionally +unlimited — stable GET URLs are exactly what we want CDNs to cache and +consumers to curl on a cron. + +Default: 10 requests/minute/IP on POST routes. The limiter instance is shared +(module singleton) and registered on the FastAPI app at startup; route handlers +apply it via ``@limiter.limit(POST_RATE_LIMIT)``. +""" + +from __future__ import annotations + +from slowapi import Limiter +from slowapi.util import get_remote_address + +POST_RATE_LIMIT = "10/minute" + +limiter = Limiter(key_func=get_remote_address) diff --git a/server/osa/application/api/v1/routes/data/_params.py b/server/osa/application/api/v1/routes/data/_params.py new file mode 100644 index 0000000..55ddbf0 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/_params.py @@ -0,0 +1,44 @@ +"""Shared request parsing for table routes — sort spec + filter body.""" + +from __future__ import annotations + +from pydantic import BaseModel + +from osa.domain.data.model.filter import FilterExpr +from osa.domain.data.model.query_plan import SortDirection, SortSpec +from osa.domain.shared.error import ValidationError + + +class FilterRequestBody(BaseModel): + """POST body shared by every table format (records + feature).""" + + filter: FilterExpr | None = None + cursor: str | None = None + # Unbounded at the edge: the table-read handler clamps to + # [1, DataConfig.max_page_limit] — over-large requests 
get the max page, not a 422. + limit: int = 50 + sort: str | None = None + + +def parse_sort(raw: str | None) -> list[SortSpec]: + """Parse ``col[:asc|:desc],col2[:asc|:desc]`` → SortSpec list (empty if None).""" + if not raw: + return [] + specs: list[SortSpec] = [] + for part in raw.split(","): + token = part.strip() + if not token: + continue + if ":" in token: + column, direction = token.split(":", 1) + try: + dir_enum = SortDirection(direction.strip().lower()) + except ValueError as exc: + raise ValidationError( + f"Invalid sort direction in {token!r}; expected 'asc' or 'desc'.", + field="sort", + ) from exc + else: + column, dir_enum = token, SortDirection.ASC + specs.append(SortSpec(column=column.strip(), direction=dir_enum)) + return specs diff --git a/server/osa/application/api/v1/routes/data/_streaming.py b/server/osa/application/api/v1/routes/data/_streaming.py new file mode 100644 index 0000000..09bfca7 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/_streaming.py @@ -0,0 +1,97 @@ +"""Response assembly for table reads — streaming and paginated paths. + +Two shapes share one entry point :func:`build_table_response`: + +* **Streaming** (``.csv`` / ``.csv.gz``): pre-flight pulls the first row inside + a try (research §4). A parse/validation/planner error raised before the first + row propagates to the route → mapped to a 4xx/404 *before any bytes*. On + success the first row is chained back in and the whole iterator is streamed. + +* **Paginated** (JSON): consume up to ``limit + 1`` rows; if a ``limit+1``-th + row exists there's a next page, so derive ``next_cursor`` from the last + returned row's ``(sort_value, srn)`` pair. The bounded page (``limit`` ≤ 1000) + is then rendered by the JSON serializer. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any + +from fastapi.responses import StreamingResponse + +from osa.application.api.v1.routes.data.formats import DataResponseFormat +from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.model.query_plan import QueryPlan + + +async def build_table_response( + rows: AsyncIterator[Mapping[str, Any]], + fmt: DataResponseFormat, + columns: Sequence[ColumnSpec], + plan: QueryPlan, +) -> StreamingResponse: + if fmt.paginated: + return await _paginated_response(rows, fmt, columns, plan) + return await _streaming_response(rows, fmt, columns) + + +async def _streaming_response( + rows: AsyncIterator[Mapping[str, Any]], + fmt: DataResponseFormat, + columns: Sequence[ColumnSpec], +) -> StreamingResponse: + iterator = rows.__aiter__() + # Pre-flight: surface setup/validation errors before the first byte. + try: + first = await iterator.__anext__() + empty = False + except StopAsyncIteration: + empty = True + + async def chained() -> AsyncIterator[Mapping[str, Any]]: + if not empty: + yield first + async for row in iterator: + yield row + + serializer = fmt.make_serializer() + return StreamingResponse( + serializer.stream(chained(), columns), + media_type=fmt.media_type, + ) + + +async def _paginated_response( + rows: AsyncIterator[Mapping[str, Any]], + fmt: DataResponseFormat, + columns: Sequence[ColumnSpec], + plan: QueryPlan, +) -> StreamingResponse: + limit = plan.pagination.limit + page: list[Mapping[str, Any]] = [] + has_more = False + async for row in rows: + if len(page) == limit: + has_more = True + break + page.append(row) + + next_cursor = _next_cursor(page, plan) if has_more else None + + async def page_iter() -> AsyncIterator[Mapping[str, Any]]: + for row in page: + yield row + + serializer = fmt.make_serializer() + return StreamingResponse( + serializer.stream(page_iter(), columns, next_cursor=next_cursor), + 
media_type=fmt.media_type, + ) + + +def _next_cursor(page: list[Mapping[str, Any]], plan: QueryPlan) -> str | None: + # Tiebreak selection and sort=id aliasing live on plan.keyset — the same + # object the store builds its ORDER BY / after-condition from, so the + # encode side cannot drift from the decode side. + return plan.keyset.cursor_from_row(page[-1]) if page else None diff --git a/server/osa/application/api/v1/routes/data/catalog.py b/server/osa/application/api/v1/routes/data/catalog.py new file mode 100644 index 0000000..b8007f9 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/catalog.py @@ -0,0 +1,41 @@ +"""Catalog & manifest routes — ``GET /data`` and ``GET /data/{schema}``. + +JSON-only (no format suffix). Reserved schema names (``records``, ``datasets``) +and unknown schemas surface as 404 via the handler's ``NotFoundError``. + +The node-catalog handler is registered directly on the prefixed ``/data`` +router (an empty sub-path can't be ``include_router``-ed); the manifest handler +lives on ``manifest_router`` so its ``/{schema}`` catch-all can be ordered +after the literal ``/records/{id}`` route. 
+""" + +from __future__ import annotations + +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter + +from osa.domain.data.model.catalog import NodeCatalog +from osa.domain.data.model.manifest import SchemaManifest +from osa.domain.data.query.catalog import ( + GetNodeCatalog, + GetNodeCatalogHandler, + GetSchemaManifest, + GetSchemaManifestHandler, +) + +manifest_router = APIRouter(route_class=DishkaRoute) + + +async def get_node_catalog(handler: FromDishka[GetNodeCatalogHandler]) -> NodeCatalog: + """List schemas hosted at this node.""" + return await handler.run(GetNodeCatalog()) + + +@manifest_router.get( + "/{schema}", operation_id="data_get_schema_manifest", response_model=SchemaManifest +) +async def get_schema_manifest( + schema: str, handler: FromDishka[GetSchemaManifestHandler] +) -> SchemaManifest: + """Machine-readable manifest for a schema (`{name}` or `{name}@{version}`).""" + return await handler.run(GetSchemaManifest(schema=schema)) diff --git a/server/osa/application/api/v1/routes/data/features_table.py b/server/osa/application/api/v1/routes/data/features_table.py new file mode 100644 index 0000000..a23bc57 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/features_table.py @@ -0,0 +1,80 @@ +"""Feature-table routes — ``/data/{schema}/{feature}[.csv|.csv.gz]`` (US5). + +Identical shape to the records table — same factory, same streaming/pagination +engine — but the path carries a ``{feature}`` segment and the handler builds a +``TableKind.FEATURE`` plan. An unknown feature 404s before any bytes (T090). 
+""" + +from __future__ import annotations + +from dishka.integrations.fastapi import FromDishka +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse + +from osa.application.api.v1.routes.data._limiter import POST_RATE_LIMIT, limiter +from osa.application.api.v1.routes.data._params import FilterRequestBody, parse_sort +from osa.application.api.v1.routes.data._streaming import build_table_response +from osa.application.api.v1.routes.data.formats import DataResponseFormat +from osa.application.api.v1.routes.data.tables import format_key, register_table_routes +from osa.domain.data.query.read_table import ReadFeatureTable, ReadFeatureTableHandler + + +def _make_get_endpoint(fmt: DataResponseFormat): + async def endpoint( + schema: str, + feature: str, + handler: FromDishka[ReadFeatureTableHandler], + cursor: str | None = None, + limit: int = 50, + sort: str | None = None, + ) -> StreamingResponse: + result = await handler.run( + ReadFeatureTable( + schema=schema, + feature=feature, + cursor=cursor, + limit=limit, + sort=parse_sort(sort), + timeout=fmt.timeout, + ) + ) + return await build_table_response(result.rows, fmt, result.columns, result.plan) + + return endpoint + + +def _make_post_endpoint(fmt: DataResponseFormat): + async def endpoint( + request: Request, + schema: str, + feature: str, + body: FilterRequestBody, + handler: FromDishka[ReadFeatureTableHandler], + ) -> StreamingResponse: + result = await handler.run( + ReadFeatureTable( + schema=schema, + feature=feature, + filter=body.filter, + cursor=body.cursor, + limit=body.limit, + sort=parse_sort(body.sort), + timeout=fmt.timeout, + ) + ) + return await build_table_response(result.rows, fmt, result.columns, result.plan) + + # Unique name before the limiter — see records_table._make_post_endpoint. 
+ endpoint.__name__ = f"feature_post_{format_key(fmt)}" + endpoint.__qualname__ = endpoint.__name__ + return limiter.limit(POST_RATE_LIMIT)(endpoint) + + +def register(router: APIRouter) -> None: + register_table_routes( + router, + "/{schema}/{feature}", + _make_get_endpoint, + _make_post_endpoint, + "feature", + ) diff --git a/server/osa/application/api/v1/routes/data/formats.py b/server/osa/application/api/v1/routes/data/formats.py new file mode 100644 index 0000000..311d6ba --- /dev/null +++ b/server/osa/application/api/v1/routes/data/formats.py @@ -0,0 +1,60 @@ +"""Response-format registry for the ``/data/`` surface. + +Each :class:`DataResponseFormat` ties a URL suffix to its serializer, +pagination semantics, and per-route statement-timeout budget (research §7). +Adding a format (e.g. ``ndjson``, ``parquet``) is a one-line append plus one +serializer class — the route factory picks it up automatically. + +This is presentation configuration, so it lives with the routes: media types +and Postgres execution budgets are HTTP/adapter concerns, not domain model. +The domain's :class:`~osa.domain.data.model.query_plan.QueryPlan` stays a pure +query description. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta + +from osa.application.api.v1.routes.data.serializers.csv import CsvSerializer +from osa.application.api.v1.routes.data.serializers.csv_gzip import CsvGzipSerializer +from osa.application.api.v1.routes.data.serializers.json import JsonSerializer +from osa.application.api.v1.routes.data.serializers.protocol import Serializer + + +@dataclass(frozen=True) +class DataResponseFormat: + serializer_cls: type[Serializer] + paginated: bool # True → JSON envelope + cursor; False → unbounded stream + suffix: str # "" (json), "csv", "csv.gz" + timeout: timedelta # statement budget: short for paginated, long for dumps + + @property + def media_type(self) -> str: + # The serializer is the single owner of its media type. + return self.serializer_cls.media_type + + def make_serializer(self) -> Serializer: + return self.serializer_cls() + + +FORMATS: tuple[DataResponseFormat, ...] = ( + DataResponseFormat( + serializer_cls=JsonSerializer, + paginated=True, + suffix="", + timeout=timedelta(seconds=30), + ), + DataResponseFormat( + serializer_cls=CsvSerializer, + paginated=False, + suffix="csv", + timeout=timedelta(minutes=30), + ), + DataResponseFormat( + serializer_cls=CsvGzipSerializer, + paginated=False, + suffix="csv.gz", + timeout=timedelta(minutes=30), + ), +) diff --git a/server/osa/application/api/v1/routes/data/models.py b/server/osa/application/api/v1/routes/data/models.py new file mode 100644 index 0000000..7b7aaea --- /dev/null +++ b/server/osa/application/api/v1/routes/data/models.py @@ -0,0 +1,32 @@ +"""Shared Pydantic response models for the ``/data/`` routes.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel + +from osa.domain.data.model.record_summary import RecordSummary + + +class RecordResponse(BaseModel): + """Single-record response — carries BOTH the bare ``id`` and 
full ``srn``.""" + + id: str + srn: str + schema_id: str + version: int + metadata: dict[str, Any] + created_at: datetime + + @classmethod + def from_summary(cls, summary: RecordSummary) -> "RecordResponse": + return cls( + id=str(summary.id), + srn=str(summary.srn), + schema_id=summary.schema_id.render(), + version=summary.version, + metadata=summary.metadata, + created_at=summary.created_at, + ) diff --git a/server/osa/application/api/v1/routes/data/records.py b/server/osa/application/api/v1/routes/data/records.py new file mode 100644 index 0000000..8790647 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/records.py @@ -0,0 +1,28 @@ +"""Single-record-by-ID route — ``GET /data/records/{id}[@{version}]`` (US4). + +Resolves a published record by its bare internal ID; the server finds the +schema via primary-key lookup. The response carries both ``id`` and ``srn``. +A bare ``GET /data/records`` (no id) is not defined — that slot is reserved for +the deferred cross-schema bulk read and 404s naturally. 
+""" + +from __future__ import annotations + +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter + +from osa.application.api.v1.routes.data.models import RecordResponse +from osa.domain.data.query.catalog import GetDataRecord, GetDataRecordHandler +from osa.domain.shared.model.ids import RecordRef + +router = APIRouter(route_class=DishkaRoute) + + +@router.get( + "/records/{record_id}", operation_id="data_get_record_by_id", response_model=RecordResponse +) +async def get_record_by_id( + record_id: str, handler: FromDishka[GetDataRecordHandler] +) -> RecordResponse: + summary = await handler.run(GetDataRecord(ref=RecordRef.parse(record_id))) + return RecordResponse.from_summary(summary) diff --git a/server/osa/application/api/v1/routes/data/records_table.py b/server/osa/application/api/v1/routes/data/records_table.py new file mode 100644 index 0000000..73318e7 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/records_table.py @@ -0,0 +1,82 @@ +"""Records-table routes — ``/data/{schema}/records[.csv|.csv.gz]`` (US1 + US2). + +Endpoint *builders* capture a :class:`DataResponseFormat` by closure; the +generic :func:`register_table_routes` factory registers the GET/POST × format +matrix from them. GET streams (or paginates JSON) with no body; POST takes a +``FilterExpr`` body and is rate-limited per IP. Everything between HTTP and +the row stream — table resolution, plan construction, limit clamping — lives +in :class:`ReadRecordsTableHandler`. 
+""" + +from __future__ import annotations + +from dishka.integrations.fastapi import FromDishka +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse + +from osa.application.api.v1.routes.data._limiter import POST_RATE_LIMIT, limiter +from osa.application.api.v1.routes.data._params import FilterRequestBody, parse_sort +from osa.application.api.v1.routes.data._streaming import build_table_response +from osa.application.api.v1.routes.data.formats import DataResponseFormat +from osa.application.api.v1.routes.data.tables import format_key, register_table_routes +from osa.domain.data.query.read_table import ReadRecordsTable, ReadRecordsTableHandler + + +def _make_get_endpoint(fmt: DataResponseFormat): + async def endpoint( + schema: str, + handler: FromDishka[ReadRecordsTableHandler], + cursor: str | None = None, + limit: int = 50, + sort: str | None = None, + ) -> StreamingResponse: + result = await handler.run( + ReadRecordsTable( + schema=schema, + cursor=cursor, + limit=limit, + sort=parse_sort(sort), + timeout=fmt.timeout, + ) + ) + return await build_table_response(result.rows, fmt, result.columns, result.plan) + + return endpoint + + +def _make_post_endpoint(fmt: DataResponseFormat): + async def endpoint( + request: Request, + schema: str, + body: FilterRequestBody, + handler: FromDishka[ReadRecordsTableHandler], + ) -> StreamingResponse: + result = await handler.run( + ReadRecordsTable( + schema=schema, + filter=body.filter, + cursor=body.cursor, + limit=body.limit, + sort=parse_sort(body.sort), + timeout=fmt.timeout, + ) + ) + return await build_table_response(result.rows, fmt, result.columns, result.plan) + + # slowapi scopes a rate limit by the decorated function's ``module.__name__`` + # captured at decoration time. These builder closures are all named + # ``endpoint``, so name each uniquely BEFORE applying the limiter, or every + # POST route would collapse into a single shared limit bucket. 
+ endpoint.__name__ = f"records_post_{format_key(fmt)}" + endpoint.__qualname__ = endpoint.__name__ + return limiter.limit(POST_RATE_LIMIT)(endpoint) + + +def register(router: APIRouter) -> None: + register_table_routes( + router, + "/{schema}/records", + _make_get_endpoint, + _make_post_endpoint, + "records", + ) diff --git a/server/osa/domain/discovery/model/__init__.py b/server/osa/application/api/v1/routes/data/serializers/__init__.py similarity index 100% rename from server/osa/domain/discovery/model/__init__.py rename to server/osa/application/api/v1/routes/data/serializers/__init__.py diff --git a/server/osa/application/api/v1/routes/data/serializers/csv.py b/server/osa/application/api/v1/routes/data/serializers/csv.py new file mode 100644 index 0000000..710fa23 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/serializers/csv.py @@ -0,0 +1,61 @@ +"""CSV serializer — header row from columns, then one row per record. + +Streams incrementally via a reusable :class:`_RowEncoder` that holds a single +``io.StringIO`` + ``csv.writer``, so only one row is ever materialised at a +time. An empty result still yields the header row followed by EOF. Quoting uses +``csv.QUOTE_MINIMAL``. +""" + +from __future__ import annotations + +import csv +import io +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any, ClassVar + + +from osa.domain.data.model.manifest import ColumnSpec + + +class _RowEncoder: + """Encodes one CSV row at a time, reusing a single buffer + writer. + + The writer's type is inferred from the ``csv.writer(...)`` assignment, so + no untyped parameter is threaded through the serializer. 
+ """ + + def __init__(self) -> None: + self._buf = io.StringIO() + self._writer = csv.writer(self._buf, quoting=csv.QUOTE_MINIMAL) + + def encode(self, values: Sequence[Any]) -> bytes: + self._buf.seek(0) + self._buf.truncate(0) + self._writer.writerow(values) + return self._buf.getvalue().encode() + + +class CsvSerializer: + media_type: ClassVar[str] = "text/csv" + + async def stream( + self, + rows: AsyncIterator[Mapping[str, Any]], + columns: Sequence[ColumnSpec], + *, + next_cursor: str | None = None, + ) -> AsyncIterator[bytes]: + names = [col.name for col in columns] + encoder = _RowEncoder() + + yield encoder.encode(names) + + async for row in rows: + yield encoder.encode([_stringify(row.get(name)) for name in names]) + + +def _stringify(value: Any) -> Any: + """Render non-scalar values deterministically; let csv handle scalars/None.""" + if value is None or isinstance(value, (str, int, float, bool)): + return value + return str(value) diff --git a/server/osa/application/api/v1/routes/data/serializers/csv_gzip.py b/server/osa/application/api/v1/routes/data/serializers/csv_gzip.py new file mode 100644 index 0000000..4f159fe --- /dev/null +++ b/server/osa/application/api/v1/routes/data/serializers/csv_gzip.py @@ -0,0 +1,41 @@ +"""Gzip-while-streaming CSV serializer (research §1). + +Wraps :class:`CsvSerializer` through ``zlib.compressobj(level=6, +wbits=MAX_WBITS|16)``. The ``MAX_WBITS|16`` flag selects gzip-stream mode +(proper gzip header + trailer) so the byte stream is exactly what ``gunzip`` +and HTTP gzip clients expect. Memory footprint is the 32KB DEFLATE window plus +one CSV row, regardless of result size — the basis of the SC-001 bounded-memory +target. 
+""" + +from __future__ import annotations + +import zlib +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any, ClassVar + +from osa.domain.data.model.manifest import ColumnSpec +from osa.application.api.v1.routes.data.serializers.csv import CsvSerializer + + +class CsvGzipSerializer: + media_type: ClassVar[str] = "application/gzip" + + def __init__(self) -> None: + self._csv = CsvSerializer() + + async def stream( + self, + rows: AsyncIterator[Mapping[str, Any]], + columns: Sequence[ColumnSpec], + *, + next_cursor: str | None = None, + ) -> AsyncIterator[bytes]: + compressor = zlib.compressobj(level=6, wbits=zlib.MAX_WBITS | 16) + async for chunk in self._csv.stream(rows, columns): + compressed = compressor.compress(chunk) + if compressed: + yield compressed + tail = compressor.flush(zlib.Z_FINISH) + if tail: + yield tail diff --git a/server/osa/application/api/v1/routes/data/serializers/json.py b/server/osa/application/api/v1/routes/data/serializers/json.py new file mode 100644 index 0000000..43ff0fb --- /dev/null +++ b/server/osa/application/api/v1/routes/data/serializers/json.py @@ -0,0 +1,38 @@ +"""JSON serializer — paginated envelope ``{"rows": [...], "next_cursor": ...}``. + +Built row-by-row so the whole page is never held twice in memory. The +``next_cursor`` is precomputed by the query service and passed in; an empty +result yields ``{"rows": [], "next_cursor": null}`` with HTTP 200. 
+""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any, ClassVar + +from osa.domain.data.model.manifest import ColumnSpec + + +class JsonSerializer: + media_type: ClassVar[str] = "application/json" + + async def stream( + self, + rows: AsyncIterator[Mapping[str, Any]], + columns: Sequence[ColumnSpec], + *, + next_cursor: str | None = None, + ) -> AsyncIterator[bytes]: + yield b'{"rows":[' + first = True + async for row in rows: + projected = {col.name: row.get(col.name) for col in columns} + chunk = json.dumps(projected, default=str).encode() + if first: + first = False + else: + chunk = b"," + chunk + yield chunk + cursor_json = json.dumps(next_cursor).encode() + yield b'],"next_cursor":' + cursor_json + b"}" diff --git a/server/osa/application/api/v1/routes/data/serializers/protocol.py b/server/osa/application/api/v1/routes/data/serializers/protocol.py new file mode 100644 index 0000000..bfcffcf --- /dev/null +++ b/server/osa/application/api/v1/routes/data/serializers/protocol.py @@ -0,0 +1,34 @@ +"""Serializer protocol — rows in, bytes out. + +Serializers are stateless and have no I/O dependencies beyond stdlib (``json``, +``csv``, ``zlib``). They consume an async iterator of already-projected rows +(column→value mappings) and the column schema that fixes wire order, and yield +response bytes incrementally so streaming formats stay memory-bounded. + +Note on ``next_cursor``: cursor derivation lives in the query service (matching +the proven discovery engine), not in the serializer. Paginated serializers +(JSON) receive the precomputed ``next_cursor`` to embed in the envelope; +streaming serializers (CSV, CSV.gz) ignore it. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any, ClassVar, Protocol, runtime_checkable + +from osa.domain.data.model.manifest import ColumnSpec + + +@runtime_checkable +class Serializer(Protocol): + media_type: ClassVar[str] + + def stream( + self, + rows: AsyncIterator[Mapping[str, Any]], + columns: Sequence[ColumnSpec], + *, + next_cursor: str | None = None, + ) -> AsyncIterator[bytes]: + """Render ``rows`` as response bytes, yielded incrementally.""" + ... diff --git a/server/osa/application/api/v1/routes/data/tables.py b/server/osa/application/api/v1/routes/data/tables.py new file mode 100644 index 0000000..42c41a5 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/tables.py @@ -0,0 +1,89 @@ +"""Metaprogrammed table-route factory. + +One call to :func:`register_table_routes` registers the full GET/POST × format +matrix (6 routes) for a table-shaped resource — the records table and every +feature table share this exact shape. The caller supplies endpoint builders +(``make_get_endpoint`` / ``make_post_endpoint``) that capture the +:class:`DataResponseFormat` by closure and declare the FastAPI-visible +signature (path params + injected dependencies), because path params differ +between resources (``/{schema}/records`` vs ``/{schema}/{feature}``). + +Operation IDs are stable and match the OpenAPI contract: +``{resource}_{method}_{format_key}`` where ``format_key`` is ``json`` / ``csv`` +/ ``csv_gz``. 
+""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from fastapi import APIRouter + +from osa.application.api.v1.routes.data.formats import FORMATS, DataResponseFormat + +EndpointBuilder = Callable[[DataResponseFormat], Callable[..., Any]] + + +def format_key(fmt: DataResponseFormat) -> str: + """``""`` → ``json``; ``csv`` → ``csv``; ``csv.gz`` → ``csv_gz``.""" + return "json" if fmt.suffix == "" else fmt.suffix.replace(".", "_") + + +def _existing_operation_ids(router: APIRouter) -> set[str]: + """Operation IDs already registered on *router* (from prior factory calls).""" + return {op for r in router.routes if (op := getattr(r, "operation_id", None))} + + +def path_for(base_path: str, fmt: DataResponseFormat) -> str: + return base_path if fmt.suffix == "" else f"{base_path}.{fmt.suffix}" + + +def register_table_routes( + router: APIRouter, + base_path: str, + make_get_endpoint: EndpointBuilder, + make_post_endpoint: EndpointBuilder, + resource_name: str, +) -> None: + """Register GET + POST routes for every format under ``base_path``. + + Formats are registered longest-suffix-first so the bare (JSON) path is added + last. This matters when ``base_path`` ends in a path parameter (e.g. the + feature table's ``/{schema}/{feature}``): Starlette's ``{feature}`` matches + ``[^/]+``, which includes dots, so a bare ``/{schema}/{feature}`` registered + first would greedily capture ``feature.csv.gz`` as the feature name and the + suffixed routes would never match. Resources with a literal segment (records) + are unaffected by the ordering. + """ + # Seed with op IDs already on this router so a second factory call for a + # different resource (records + feature share one router) can't silently mint + # a colliding id — operation IDs must be globally unique for SDK codegen. 
+ seen_ids = _existing_operation_ids(router) + for fmt in sorted(FORMATS, key=lambda f: len(f.suffix), reverse=True): + key = format_key(fmt) + path = path_for(base_path, fmt) + get_id = f"{resource_name}_get_{key}" + post_id = f"{resource_name}_post_{key}" + for op_id in (get_id, post_id): + if op_id in seen_ids: + raise ValueError( + f"Duplicate operation_id generated by the table factory: {op_id!r}. " + f"Resource name {resource_name!r} collides with a route already on this " + "router — operation IDs must be globally unique for SDK codegen." + ) + seen_ids.add(op_id) + router.add_api_route( + path, + make_get_endpoint(fmt), + methods=["GET"], + operation_id=get_id, + response_model=None, + ) + router.add_api_route( + path, + make_post_endpoint(fmt), + methods=["POST"], + operation_id=post_id, + response_model=None, + ) diff --git a/server/osa/application/api/v1/routes/discovery.py b/server/osa/application/api/v1/routes/discovery.py deleted file mode 100644 index fb4295a..0000000 --- a/server/osa/application/api/v1/routes/discovery.py +++ /dev/null @@ -1,157 +0,0 @@ -"""Discovery API routes — search and filter records and features.""" - -from typing import Any - -from dishka.integrations.fastapi import DishkaRoute, FromDishka -from fastapi import APIRouter -from pydantic import BaseModel, Field - -from osa.domain.discovery.model.value import FilterExpr, SortOrder -from osa.domain.discovery.query.get_feature_catalog import ( - GetFeatureCatalog, - GetFeatureCatalogHandler, - GetFeatureCatalogResult, -) -from osa.domain.discovery.query.search_features import ( - SearchFeatures, - SearchFeaturesHandler, - SearchFeaturesResult, -) -from osa.domain.discovery.query.search_records import ( - SearchRecords, - SearchRecordsHandler, - SearchRecordsResult, -) -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.srn import ConventionSRN, SchemaId - -router = APIRouter( - prefix="/discovery", - tags=["discovery"], - route_class=DishkaRoute, -) - 
- -# ── Request / Response models ── - - -class RecordSearchRequest(BaseModel): - schema: str | None = None - """Short-form schema identity: ``"<id>@<version>"`` (e.g. ``"pdb-structure@1.0.0"``).""" - - convention_srn: ConventionSRN | None = None - filter: FilterExpr | None = None - q: str | None = None - sort: str = "published_at" - order: SortOrder = SortOrder.DESC - cursor: str | None = None - limit: int = Field(default=20, ge=1, le=100) - - -class RecordSearchResponse(BaseModel): - results: list[dict[str, Any]] - cursor: str | None - has_more: bool - - -class FeatureCatalogResponse(BaseModel): - tables: list[dict[str, Any]] - - -class FeatureSearchRequest(BaseModel): - schema: str | None = None - """Short-form schema identity, optional. See RecordSearchRequest.schema.""" - - filter: FilterExpr | None = None - record_srn: str | None = None - sort: str = "id" - order: SortOrder = SortOrder.DESC - cursor: str | None = None - limit: int = Field(default=50, ge=1, le=100) - - -class FeatureSearchResponse(BaseModel): - rows: list[dict[str, Any]] - cursor: str | None - has_more: bool - - -def _parse_schema(value: str | None) -> SchemaId | None: - if value is None: - return None - if "@" not in value: - raise ValidationError( - f"Schema {value!r} must be fully qualified as '<id>@<version>' " - "(e.g. 'pdb-structure@1.0.0'). 
Family-level scoping " - "(id alone, resolving to the latest version across a schema family) " - "is planned but not yet supported.", - field="schema", - code="cross_scope_not_yet_supported", - ) - try: - return SchemaId.parse(value) - except ValueError as exc: - raise ValidationError(str(exc), field="schema") from exc - - -# ── Routes ── - - -@router.post("/records") -async def search_records( - body: RecordSearchRequest, - handler: FromDishka[SearchRecordsHandler], -) -> RecordSearchResponse: - """Search and filter published records.""" - result: SearchRecordsResult = await handler.run( - SearchRecords( - filter_expr=body.filter, - schema_id=_parse_schema(body.schema), - convention_srn=body.convention_srn, - q=body.q, - sort=body.sort, - order=body.order, - cursor=body.cursor, - limit=body.limit, - ) - ) - return RecordSearchResponse( - results=result.results, - cursor=result.cursor, - has_more=result.has_more, - ) - - -@router.get("/features") -async def get_feature_catalog( - handler: FromDishka[GetFeatureCatalogHandler], -) -> FeatureCatalogResponse: - """List available feature tables with column schemas and record counts.""" - result: GetFeatureCatalogResult = await handler.run(GetFeatureCatalog()) - return FeatureCatalogResponse(tables=result.tables) - - -@router.post("/features/{hook_name}") -async def search_features( - hook_name: str, - body: FeatureSearchRequest, - handler: FromDishka[SearchFeaturesHandler], -) -> FeatureSearchResponse: - """Query and filter rows in a specific feature table.""" - result: SearchFeaturesResult = await handler.run( - SearchFeatures( - hook_name=hook_name, - filter_expr=body.filter, - schema_id=_parse_schema(body.schema), - record_srn=body.record_srn, - sort=body.sort, - order=body.order, - cursor=body.cursor, - limit=body.limit, - ) - ) - return FeatureSearchResponse( - rows=result.rows, - cursor=result.cursor, - has_more=result.has_more, - ) diff --git a/server/osa/application/api/v1/routes/records.py 
b/server/osa/application/api/v1/routes/records.py deleted file mode 100644 index 4182005..0000000 --- a/server/osa/application/api/v1/routes/records.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Records API routes.""" - -from typing import Any - -from dishka.integrations.fastapi import DishkaRoute, FromDishka -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel - -from osa.domain.record.query.get_record import GetRecord, GetRecordHandler, RecordDetail -from osa.domain.shared.model.srn import RecordSRN - -router = APIRouter( - prefix="/records", - tags=["records"], - route_class=DishkaRoute, -) - - -class RecordResponse(BaseModel): - """Single record response.""" - - record: dict[str, Any] - - -@router.get("/{srn:path}") -async def get_record( - srn: str, - handler: FromDishka[GetRecordHandler], -) -> RecordResponse: - """Get a record by SRN.""" - try: - record_srn = RecordSRN.parse(srn) - except ValueError as e: - raise HTTPException(status_code=400, detail=f"Invalid SRN: {e}") - - result: RecordDetail = await handler.run(GetRecord(srn=record_srn)) - - return RecordResponse( - record={ - "srn": str(result.srn), - "source": result.source.model_dump(), - "convention_srn": str(result.convention_srn), - "metadata": result.metadata, - "published_at": result.published_at.isoformat(), - "features": result.features, - } - ) diff --git a/server/osa/application/api/v1/routes/search.py b/server/osa/application/api/v1/routes/search.py deleted file mode 100644 index c71008e..0000000 --- a/server/osa/application/api/v1/routes/search.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Search API routes.""" - -from typing import Any - -from dishka.integrations.fastapi import DishkaRoute, FromDishka -from fastapi import APIRouter, HTTPException, Query -from pydantic import BaseModel - -from osa.domain.index.model.registry import IndexRegistry -from osa.sdk.index import QueryResult - -router = APIRouter( - prefix="/search", - tags=["search"], - route_class=DishkaRoute, -) - - 
-class SearchResponse(BaseModel): - """Search response model.""" - - query: str - index: str - total: int - has_more: bool - results: list[dict[str, Any]] - - -@router.get("/{index_name}") -async def search_index( - index_name: str, - indexes: FromDishka[IndexRegistry], - q: str = Query(..., description="Search query"), - offset: int = Query(0, ge=0, description="Number of results to skip"), - limit: int = Query(20, ge=1, le=100, description="Maximum number of results"), -) -> SearchResponse: - """Search a specific index by name.""" - if not len(indexes): - raise HTTPException(status_code=503, detail="No search indexes configured") - - backend = indexes.get(index_name) - if not backend: - raise HTTPException( - status_code=404, - detail=f"Index '{index_name}' not found. Available: {indexes.names()}", - ) - - # Naive pagination: fetch offset+limit+1 to determine has_more, then slice - result: QueryResult = await backend.query(q, limit=offset + limit + 1) - all_hits = result.hits[offset:] # Skip first `offset` results - has_more = len(all_hits) > limit - hits = all_hits[:limit] - - return SearchResponse( - query=q, - index=index_name, - total=len(hits), - has_more=has_more, - results=[ - { - "srn": hit.srn, - "score": hit.score, - "metadata": hit.metadata, - } - for hit in hits - ], - ) - - -@router.get("/") -async def list_indexes( - indexes: FromDishka[IndexRegistry], -) -> dict[str, list[str]]: - """List available search indexes.""" - return {"indexes": indexes.names()} diff --git a/server/osa/application/api/v1/routes/stats.py b/server/osa/application/api/v1/routes/stats.py index 7fe4c2e..64ee3ec 100644 --- a/server/osa/application/api/v1/routes/stats.py +++ b/server/osa/application/api/v1/routes/stats.py @@ -4,8 +4,7 @@ from fastapi import APIRouter from pydantic import BaseModel -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.record.port.repository import RecordRepository +from osa.domain.record.query.get_stats import GetStats, 
GetStatsHandler router = APIRouter( prefix="/stats", @@ -14,38 +13,23 @@ ) -class IndexStats(BaseModel): - """Stats for a single index.""" - - name: str - count: int - healthy: bool - - class StatsResponse(BaseModel): - """System statistics response.""" + """System statistics response. + + The legacy ``indexes`` field was removed with the index domain (the unified + ``/data/`` surface replaces vector/keyword index reads). Per-schema and + per-feature-table row counts now live in each schema manifest at + ``GET /api/v1/data/{schema}``; ``data_url`` points there. + """ records: int - indexes: list[IndexStats] + data_url: str = "/api/v1/data" @router.get("") async def get_stats( - record_repo: FromDishka[RecordRepository], - indexes: FromDishka[IndexRegistry], + handler: FromDishka[GetStatsHandler], ) -> StatsResponse: """Get system statistics.""" - record_count = await record_repo.count() - - index_stats = [] - for name, backend in indexes.items(): - try: - count = await backend.count() - healthy = await backend.health() - except Exception: - count = 0 - healthy = False - - index_stats.append(IndexStats(name=name, count=count, healthy=healthy)) - - return StatsResponse(records=record_count, indexes=index_stats) + result = await handler.run(GetStats()) + return StatsResponse(records=result.records) diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 5227635..a85074a 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -5,8 +5,8 @@ from osa.config import Config from osa.domain.auth.util.di import AuthProvider +from osa.domain.data.util.di import DataProvider from osa.domain.deposition.util.di import DepositionProvider -from osa.domain.discovery.util.di import DiscoveryProvider from osa.domain.feature.util.di import FeatureProvider from osa.domain.metadata.util.di import MetadataProvider from osa.domain.semantics.util.di.provider import SemanticsProvider @@ -15,9 +15,8 @@ from osa.infrastructure.auth import 
AuthInfraProvider from osa.infrastructure.event.di import EventProvider from osa.infrastructure.http.di import HttpProvider -from osa.infrastructure.index.di import IndexProvider from osa.infrastructure.k8s.di import RunnerProvider -from osa.infrastructure.persistence import PersistenceProvider +from osa.infrastructure.persistence.di import PersistenceProvider from osa.infrastructure.ingest.di import IngestProvider from osa.util.di.scope import Scope from osa.util.paths import OSAPaths @@ -44,7 +43,6 @@ def create_container( return make_async_container( PersistenceProvider(), RunnerProvider(), - IndexProvider(), IngestProvider(), EventProvider(extra_handlers=extra_handlers), HttpProvider(), @@ -55,7 +53,7 @@ def create_container( ValidationProvider(), AuthProvider(), AuthInfraProvider(), - DiscoveryProvider(), + DataProvider(), *extra_providers, context={Config: config, OSAPaths: paths}, scopes=Scope, # type: ignore[arg-type] # Custom scope class diff --git a/server/osa/config.py b/server/osa/config.py index c267c0a..69e15ee 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -260,6 +260,21 @@ def _normalize_pg_url(url: str) -> str: return url +class DataConfig(BaseModel): + """Bounds for the unified ``/data/`` read surface (nested in Config). + + Caps the cost of a filter tree before it is compiled to SQL (maximum tree + depth, total predicate count, distinct feature-hook joins) and the + paginated-JSON page size. Overridable via ``OSA_DATA__MAX_FILTER_DEPTH`` + etc. (FR-012, features 076/137). + """ + + max_filter_depth: int = 10 + max_predicates: int = 200 + max_feature_joins: int = 10 # distinct features. 
references in one filter + max_page_limit: int = 1000 # page-size ceiling; over-large requests are clamped, not 422d + + class Config(BaseSettings): # Archive identity (promoted from Server model) name: str = "Open Science Archive" @@ -283,13 +298,9 @@ class Config(BaseSettings): worker: WorkerConfig = WorkerConfig() # Background worker settings auth: AuthConfig = AuthConfig() # Defaults are dev-safe; boot check enforces prod-correctness runner: RunnerConfig = RunnerConfig() + data: DataConfig = DataConfig() # /data/ read-surface filter-tree bounds host_data_dir: str | None = None # Host path for OSA_DATA_DIR (sibling container mounts) - # Discovery filter-tree bounds (feature 076) - discovery_max_filter_depth: int = 10 - discovery_max_predicates: int = 200 - discovery_max_cross_domain_joins: int = 10 - model_config = { "env_prefix": "OSA_", "env_file": ".env", diff --git a/server/osa/domain/discovery/port/__init__.py b/server/osa/domain/data/__init__.py similarity index 100% rename from server/osa/domain/discovery/port/__init__.py rename to server/osa/domain/data/__init__.py diff --git a/server/osa/domain/discovery/query/__init__.py b/server/osa/domain/data/model/__init__.py similarity index 100% rename from server/osa/domain/discovery/query/__init__.py rename to server/osa/domain/data/model/__init__.py diff --git a/server/osa/domain/data/model/catalog.py b/server/osa/domain/data/model/catalog.py new file mode 100644 index 0000000..66f9d96 --- /dev/null +++ b/server/osa/domain/data/model/catalog.py @@ -0,0 +1,35 @@ +"""Node catalog response envelope. + +``GET /data`` returns the node's domain plus the published schemas, each with a +summary list of table resources (names + kinds only). Column-level detail comes +from the per-schema manifest at ``GET /data/{schema}``. 
+""" + +from __future__ import annotations + +from pydantic import BaseModel + +from osa.domain.data.model.query_plan import TableKind + + +class TableResourceSummary(BaseModel): + """Name + kind of an addressable table resource (no columns/counts).""" + + name: str + kind: TableKind + + +class CatalogEntry(BaseModel): + """One published schema in the node catalog.""" + + id: str # short schema id + version: str # SemVer + srn: str # full schema SRN + table_resources: list[TableResourceSummary] + + +class NodeCatalog(BaseModel): + """The node's published-schema catalog. Empty ``schemas`` is valid (200).""" + + node_domain: str + schemas: list[CatalogEntry] diff --git a/server/osa/domain/discovery/model/value.py b/server/osa/domain/data/model/filter.py similarity index 55% rename from server/osa/domain/discovery/model/value.py rename to server/osa/domain/data/model/filter.py index 65549a0..02d9708 100644 --- a/server/osa/domain/discovery/model/value.py +++ b/server/osa/domain/data/model/filter.py @@ -1,28 +1,95 @@ -"""Discovery domain value objects — filters, cursors, result types. +"""Filter DSL for the ``/data/`` read surface. -Feature 076 replaces the flat ``Filter`` list with a compound ``FilterExpr`` -discriminated union (``And``/``Or``/``Not``/``Predicate``). Field references -inside predicates are typed (:class:`MetadataFieldRef` or -:class:`FeatureFieldRef`); the dotted wire form is parsed at the API boundary. +Relocated from the ``discovery`` domain (research §3 / FR-041): the +``FilterExpr`` discriminated union, its typed field references, and the +operator-compatibility tables are reused wholesale with no behaviour change. +During the discovery→data coexistence window (research §10) the ``data`` +domain keeps its own copy so it has no dependency on ``discovery`` once the +latter is deleted. + +Two kinds of field references are supported: + +- :class:`MetadataFieldRef` — resolves to a column in + ``metadata._v``. 
+- :class:`FeatureFieldRef` — resolves to a column in ``features.<hook>``. + +Wire format for a reference is a dotted path (``metadata.<field>`` or +``features.<hook>.<column>``). :func:`parse_field_ref` parses the wire form +into a typed reference and validates identifier shape. """ from __future__ import annotations -import base64 -import json -from datetime import datetime +import re from enum import StrEnum from typing import Annotated, Any, Literal, Union from pydantic import BaseModel, Field, model_validator -from osa.domain.discovery.model.refs import ( - FeatureFieldRef, - MetadataFieldRef, - parse_field_ref, -) from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.model.srn import RecordSRN + +_IDENT = re.compile(r"^[a-z][a-z0-9_]*$") + + +# ── Field references ── + + +class MetadataFieldRef(BaseModel): + path: Literal["metadata"] = "metadata" + field: str + + def dotted(self) -> str: + return f"metadata.{self.field}" + + +class FeatureFieldRef(BaseModel): + path: Literal["features"] = "features" + hook: str + column: str + + def dotted(self) -> str: + return f"features.{self.hook}.{self.column}" + + +def parse_field_ref(dotted: str) -> "FieldRef": + """Parse a dotted-path field reference into its typed form. + + Raises :class:`ValueError` when the path shape or identifier doesn't match + the documented grammar. 
+ """ + if not isinstance(dotted, str): + raise ValueError(f"Expected dotted string, got {type(dotted).__name__}") + + parts = dotted.split(".") + if not parts: + raise ValueError(f"Empty field reference: {dotted!r}") + + head = parts[0] + if head == "metadata": + if len(parts) != 2: + raise ValueError(f"metadata.* refs must be exactly two dotted parts, got {dotted!r}") + field = parts[1] + if not _IDENT.match(field): + raise ValueError(f"Invalid metadata field identifier: {field!r}") + return MetadataFieldRef(field=field) + + if head == "features": + if len(parts) != 3: + raise ValueError(f"features.* refs must be exactly three dotted parts, got {dotted!r}") + hook, column = parts[1], parts[2] + if not _IDENT.match(hook): + raise ValueError(f"Invalid hook identifier: {hook!r}") + if not _IDENT.match(column): + raise ValueError(f"Invalid feature column identifier: {column!r}") + return FeatureFieldRef(hook=hook, column=column) + + raise ValueError( + f"Unknown field reference prefix {head!r} in {dotted!r}. " + "Expected 'metadata.<field>' or 'features.<hook>.<column>'." + ) + + +# ── Operators ── class FilterOperator(StrEnum): @@ -37,11 +104,6 @@ class FilterOperator(StrEnum): IS_NULL = "is_null" -class SortOrder(StrEnum): - ASC = "asc" - DESC = "desc" - - FieldRef = Annotated[ Union[MetadataFieldRef, FeatureFieldRef], Field(discriminator="path"), @@ -51,6 +113,9 @@ class SortOrder(StrEnum): PredicateValue = Union[str, int, float, bool, list[str], list[float], None] +# ── Filter expression tree ── + + class Predicate(BaseModel): kind: Literal["predicate"] = "predicate" field: FieldRef @@ -88,13 +153,14 @@ class Not(BaseModel): Field(discriminator="kind"), ] -# Resolve forward references And.model_rebuild() Or.model_rebuild() Not.model_rebuild() -# Operators valid per column type for metadata/feature column validation. +# ── Operator compatibility tables ── + +# Operators valid per semantic column type for metadata/feature validation. 
VALID_OPERATORS: dict[FieldType, set[FilterOperator]] = { FieldType.TEXT: { FilterOperator.EQ, @@ -139,8 +205,8 @@ class Not(BaseModel): FieldType.BOOLEAN: {FilterOperator.EQ, FilterOperator.IS_NULL}, } -# Operators valid against raw JSON-schema primitive types (used for feature columns -# whose Column.json_type is a JSON Schema primitive rather than a semantic FieldType). +# Operators valid against raw JSON-schema primitive types (feature columns +# whose Column.json_type is a JSON Schema primitive rather than a FieldType). JSON_TYPE_OPERATORS: dict[str, set[FilterOperator]] = { "string": { FilterOperator.EQ, @@ -173,61 +239,3 @@ class Not(BaseModel): "array": {FilterOperator.EQ, FilterOperator.IS_NULL}, "object": {FilterOperator.EQ, FilterOperator.IS_NULL}, } - - -def encode_cursor(sort_value: Any, id_value: Any) -> str: - """Encode a cursor as base64 JSON.""" - payload = {"s": sort_value, "id": id_value} - return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() - - -def decode_cursor(cursor: str) -> dict[str, Any]: - """Decode a base64 JSON cursor. 
Raises ValueError on malformed input.""" - try: - raw = base64.urlsafe_b64decode(cursor.encode()) - data = json.loads(raw) - except Exception as exc: - raise ValueError(f"Malformed cursor: {exc}") from exc - if not isinstance(data, dict) or "s" not in data or "id" not in data: - raise ValueError("Cursor must contain 's' and 'id' keys") - return data - - -class RecordSummary(BaseModel): - srn: RecordSRN - published_at: datetime - metadata: dict[str, Any] - - -class RecordSearchResult(BaseModel): - results: list[RecordSummary] - cursor: str | None - has_more: bool - - -class ColumnInfo(BaseModel): - name: str - type: str - required: bool - - -class FeatureCatalogEntry(BaseModel): - hook_name: str - columns: list[ColumnInfo] - record_count: int - - -class FeatureCatalog(BaseModel): - tables: list[FeatureCatalogEntry] - - -class FeatureRow(BaseModel): - row_id: int - record_srn: RecordSRN - data: dict[str, Any] - - -class FeatureSearchResult(BaseModel): - rows: list[FeatureRow] - cursor: str | None - has_more: bool diff --git a/server/osa/domain/data/model/manifest.py b/server/osa/domain/data/model/manifest.py new file mode 100644 index 0000000..749cb61 --- /dev/null +++ b/server/osa/domain/data/model/manifest.py @@ -0,0 +1,84 @@ +"""Schema manifest response envelope (FR-002, research §9). + +A stable, machine-readable cross-schema shape suitable for SDK code +generation and agent affordance: typed field definitions plus the set of +addressable table resources (the ``records`` table and every feature table), +each with its column schema, row count, and the URL-exposed format suffixes. 
+""" + +from __future__ import annotations + +from pydantic import BaseModel + +from osa.domain.data.model.query_plan import TableKind +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.model.srn import SchemaId + + +class FieldSpec(BaseModel): + """A schema-declared metadata field.""" + + name: str + type: FieldType + ontology_id: str | None = None # required iff type == TERM + ontology_version: str | None = None # required iff type == TERM + + +class ColumnSpec(BaseModel): + """A physical column on an addressable table resource.""" + + name: str + type: FieldType + + +# Implicit columns present on every records-table response, in wire order, +# ahead of the schema's declared metadata fields (data-model.md §ColumnSpec). +IMPLICIT_RECORD_COLUMN_SPECS: tuple[ColumnSpec, ...] = ( + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="srn", type=FieldType.TEXT), + ColumnSpec(name="schema_id", type=FieldType.TEXT), + ColumnSpec(name="version", type=FieldType.NUMBER), + ColumnSpec(name="created_at", type=FieldType.DATE), +) + +# Implicit columns present on every feature-table response, in wire order, ahead +# of the hook's declared output columns. Feature rows have no SRN in v1 (data- +# model.md §RecordSummary); ``record_srn`` is the join key back to ``records``. +IMPLICIT_FEATURE_COLUMN_SPECS: tuple[ColumnSpec, ...] = ( + ColumnSpec(name="id", type=FieldType.NUMBER), + ColumnSpec(name="record_srn", type=FieldType.TEXT), + ColumnSpec(name="created_at", type=FieldType.DATE), +) + + +class TableResource(BaseModel): + """One addressable table under a schema: the records table or a feature table.""" + + name: str # "records" or the feature table name + kind: TableKind + columns: list[ColumnSpec] + row_count: int + formats: list[str] # URL suffixes, e.g. ["", "csv", "csv.gz"] + + +class SchemaManifest(BaseModel): + """Full machine-readable manifest for a single schema version.""" + + id: str # short schema id, e.g. 
"compound" + version: str # SemVer, e.g. "1.0.0" + srn: str # full schema SRN + fields: list[FieldSpec] + table_resources: list[TableResource] + + +class ResolvedTable(BaseModel): + """A table resolved for reading: the owning schema plus its column schema. + + Produced by ``DataCatalogService.resolve_table`` — the single owner of + manifest-structure knowledge (the records resource is named ``records``; + feature resources carry kind ``FEATURE``). Route code consumes this + instead of picking through ``table_resources`` itself. + """ + + schema_id: SchemaId + columns: list[ColumnSpec] diff --git a/server/osa/domain/data/model/query_plan.py b/server/osa/domain/data/model/query_plan.py new file mode 100644 index 0000000..72e8295 --- /dev/null +++ b/server/osa/domain/data/model/query_plan.py @@ -0,0 +1,175 @@ +"""Query IR for the ``/data/`` read surface. + +``QueryPlan`` is the internal intermediate representation produced by both URL +parsing and POST-body parsing and consumed by the read engine. It is the pure +query description; the *output format* (which serializer, which statement +timeout) is intentionally NOT part of the plan — it is a route/serialization +concern carried alongside the plan, which also avoids a query_plan→format→ +serializer→manifest→query_plan import cycle. + +Also hosts the opaque cursor codec relocated from the discovery domain +(research §3 / FR-041) with no behaviour change. 
+""" + +from __future__ import annotations + +import base64 +import json +from collections.abc import Mapping +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field, model_validator + +from osa.domain.data.model.filter import FilterExpr +from osa.domain.shared.model.ids import HookName +from osa.domain.shared.model.srn import SchemaId + + +class TableKind(StrEnum): + RECORDS = "records" + FEATURE = "feature" + + +class SortDirection(StrEnum): + ASC = "asc" + DESC = "desc" + + +class SortSpec(BaseModel): + """A single sort key — column plus direction (no bare tuples at boundaries).""" + + column: str + direction: SortDirection + + +class PaginationCursor(BaseModel): + """Opaque base64 wrapper around the last row's ``(sort_value, id)`` pair.""" + + value: str + + def __str__(self) -> str: + return self.value + + +class PaginationParams(BaseModel): + cursor: PaginationCursor | None = None + limit: int = Field(default=50, ge=1) + + @classmethod + def clamped( + cls, + *, + cursor: PaginationCursor | None = None, + limit: int, + max_limit: int, + ) -> "PaginationParams": + """Build params with ``limit`` clamped into ``[1, max_limit]``. + + Clamp, don't reject: a consumer asking for "everything" with a big + number gets the max page, not a 422. The ceiling is operator + configuration (``DataConfig.max_page_limit``), so it arrives as an + argument rather than living here as a constant. + """ + return cls(cursor=cursor, limit=max(1, min(limit, max_limit))) + + +class Keyset(BaseModel): + """The keyset-pagination contract for a plan — the single source of truth + for which column is the effective primary sort and which breaks ties. + + Consumed by BOTH sides of the pagination round-trip: the Postgres adapter + (ORDER BY + the cursor's ``after`` condition) and the response assembly + (encoding ``next_cursor`` from the last page row). Keeping the two sides + on one object is what makes encode/decode asymmetries unrepresentable. 
+ """ + + sort_column: str # effective primary sort, post-aliasing (``id`` → tiebreak) + tiebreak_column: str # records: ``srn`` (the PK); features: ``id`` (BIGSERIAL) + + def cursor_from_row(self, row: Mapping[str, Any]) -> str: + """Encode the opaque ``next_cursor`` from the last row of a page.""" + tiebreak = row.get(self.tiebreak_column) + sort_value = ( + tiebreak if self.sort_column == self.tiebreak_column else row.get(self.sort_column) + ) + return encode_cursor(sort_value, tiebreak) + + +# Keyset tiebreak column per table kind. Records tiebreak on ``srn`` — the PK; +# a bare record id is NOT unique across versions. Features tiebreak on the +# BIGSERIAL ``id``. A hook may legally declare a data column named ``srn``, +# which must never hijack the feature tiebreaker — hence by-kind, not by-key. +_TIEBREAK_COLUMNS: dict[TableKind, str] = { + TableKind.RECORDS: "srn", + TableKind.FEATURE: "id", +} + +# Default sort keys per table kind (data-model.md §PaginationParams). +_DEFAULT_SORTS: dict[TableKind, list[SortSpec]] = { + TableKind.RECORDS: [ + SortSpec(column="created_at", direction=SortDirection.DESC), + SortSpec(column="id", direction=SortDirection.DESC), + ], + TableKind.FEATURE: [SortSpec(column="id", direction=SortDirection.ASC)], +} + + +class QueryPlan(BaseModel): + schema_id: SchemaId + table_kind: TableKind + feature_name: HookName | None = None + filter: FilterExpr | None = None + pagination: PaginationParams = Field(default_factory=PaginationParams) + sort: list[SortSpec] = Field(default_factory=list) + + @model_validator(mode="after") + def _validate_and_default(self) -> "QueryPlan": + # feature_name present iff FEATURE + if self.table_kind == TableKind.FEATURE and self.feature_name is None: + raise ValueError("feature_name is required when table_kind is FEATURE") + if self.table_kind == TableKind.RECORDS and self.feature_name is not None: + raise ValueError("feature_name must be None when table_kind is RECORDS") + # Apply default sort per table 
kind when none was supplied. + if not self.sort: + self.sort = list(_DEFAULT_SORTS[self.table_kind]) + return self + + @property + def keyset(self) -> Keyset: + """The pagination contract for this plan. + + ``sort=id`` aliases to the tiebreak column: for records the addressable + ``id`` is a projection of the srn PK (sorting by it would be ambiguous + across versions), and for features ``id`` already is the tiebreaker. + """ + tiebreak = _TIEBREAK_COLUMNS[self.table_kind] + primary = self.sort[0].column + return Keyset( + sort_column=tiebreak if primary == "id" else primary, + tiebreak_column=tiebreak, + ) + + +def encode_cursor(sort_value: Any, id_value: Any) -> str: + """Encode a cursor as urlsafe base64 of ``{"s": sort_value, "id": id_value}``. + + ``default=str`` because feature rows are raw DB mappings — a datetime sort + value (``created_at``, or a date-time hook column) arrives unrendered. The + sort decoders coerce the ISO string back to a datetime before binding. + """ + payload = {"s": sort_value, "id": id_value} + return base64.urlsafe_b64encode(json.dumps(payload, default=str).encode()).decode() + + +def decode_cursor(cursor: str) -> dict[str, Any]: + """Decode a base64 JSON cursor. Raises ``ValueError`` on malformed input.""" + try: + raw = base64.urlsafe_b64decode(cursor.encode()) + data = json.loads(raw) + except ValueError as exc: + # binascii.Error and json.JSONDecodeError are both ValueError subclasses. + raise ValueError(f"Malformed cursor: {exc}") from exc + if not isinstance(data, dict) or "s" not in data or "id" not in data: + raise ValueError("Cursor must contain 's' and 'id' keys") + return data diff --git a/server/osa/domain/data/model/record_summary.py b/server/osa/domain/data/model/record_summary.py new file mode 100644 index 0000000..cae4507 --- /dev/null +++ b/server/osa/domain/data/model/record_summary.py @@ -0,0 +1,50 @@ +"""Row types yielded by the read engine. + +``RecordSummary`` is the records-table row. 
Per data-model.md every records +response carries BOTH the bare internal ``id`` and the full ``srn``; ``id`` and +``version`` are derivable from the SRN, ``schema_id`` and ``created_at`` from +the ``records`` table columns. Feature-table rows do not have SRNs in v1 — they +flow through the engine as plain column→value mappings, not ``RecordSummary``. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel + +from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.srn import RecordSRN, SchemaId + +# Implicit columns present on every records-table response, in wire order, +# ahead of the schema's declared metadata fields. +IMPLICIT_RECORD_COLUMNS: tuple[str, ...] = ("id", "srn", "schema_id", "version", "created_at") + + +class RecordSummary(BaseModel): + """A single published record as projected by the read engine.""" + + id: RecordId + srn: RecordSRN + schema_id: SchemaId + version: int + metadata: dict[str, Any] + created_at: datetime + + def flatten(self) -> dict[str, Any]: + """Flatten into a column→value mapping for serialization. + + Implicit columns come first (in :data:`IMPLICIT_RECORD_COLUMNS` order), + then the metadata fields. Metadata keys never shadow implicit columns + because schema field names cannot collide with the reserved implicit + set at registration time. 
+ """ + return { + "id": str(self.id), + "srn": str(self.srn), + "schema_id": self.schema_id.render(), + "version": self.version, + "created_at": self.created_at.isoformat(), + **self.metadata, + } diff --git a/server/osa/domain/discovery/service/__init__.py b/server/osa/domain/data/port/__init__.py similarity index 100% rename from server/osa/domain/discovery/service/__init__.py rename to server/osa/domain/data/port/__init__.py diff --git a/server/osa/domain/data/port/data_read_store.py b/server/osa/domain/data/port/data_read_store.py new file mode 100644 index 0000000..dc77068 --- /dev/null +++ b/server/osa/domain/data/port/data_read_store.py @@ -0,0 +1,57 @@ +"""Read-store ports feeding the ``/data/`` surface — split along the service seam. + +``DataTableReadStore`` is the streaming primitive behind both the paginated +JSON path and the unbounded CSV/CSV.gz path; the concrete adapter backs it with +a Postgres server-side cursor (research §2) so memory stays bounded regardless +of result size. Rows are yielded as column→value mappings (already projected): +records-table rows include the implicit ``id``/``srn``/``schema_id``/ +``version``/``created_at`` columns plus metadata fields; feature-table rows +carry the hook's declared columns. + +``DataCatalogReadStore`` serves the non-streaming reads: node catalog, schema +manifest, latest-schema resolution, and single-record-by-id. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from osa.domain.data.model.catalog import NodeCatalog + from osa.domain.data.model.manifest import SchemaManifest + from osa.domain.data.model.query_plan import QueryPlan + from osa.domain.data.model.record_summary import RecordSummary + from osa.domain.shared.model.ids import RecordId + from osa.domain.shared.model.srn import SchemaId + + +class DataTableReadStore(Protocol): + def stream_rows( + self, plan: "QueryPlan", timeout: timedelta | None = None + ) -> AsyncIterator[Mapping[str, Any]]: + """Stream projected rows for the plan via a server-side cursor. + + ``timeout`` is the caller's execution budget for the read; the adapter + enforces it on its own connection (the routes hold no SQL session). + """ + ... + + +class DataCatalogReadStore(Protocol): + async def get_record_by_id(self, id: "RecordId", version: int | None) -> "RecordSummary | None": + """Resolve a single record by bare ID (schema resolved via PK). ``None`` if absent.""" + ... + + async def get_node_catalog(self) -> "NodeCatalog": + """List published schemas with summary table resources.""" + ... + + async def get_schema_manifest(self, schema_id: "SchemaId") -> "SchemaManifest | None": + """Full manifest for a schema. ``None`` if unknown.""" + ... + + async def get_latest_schema_id(self, schema_short_id: str) -> "SchemaId | None": + """Resolve a bare schema id to its latest published version. ``None`` if unknown.""" + ... 
diff --git a/server/osa/domain/discovery/util/__init__.py b/server/osa/domain/data/query/__init__.py similarity index 100% rename from server/osa/domain/discovery/util/__init__.py rename to server/osa/domain/data/query/__init__.py diff --git a/server/osa/domain/data/query/catalog.py b/server/osa/domain/data/query/catalog.py new file mode 100644 index 0000000..85d864b --- /dev/null +++ b/server/osa/domain/data/query/catalog.py @@ -0,0 +1,48 @@ +"""Catalog-shaped query handlers — node catalog, schema manifest, record by id.""" + +from __future__ import annotations + +from osa.domain.data.model.catalog import NodeCatalog +from osa.domain.data.model.manifest import SchemaManifest +from osa.domain.data.model.record_summary import RecordSummary +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.shared.authorization.gate import public +from osa.domain.shared.model.ids import RecordRef +from osa.domain.shared.query import Query, QueryHandler + + +class GetNodeCatalog(Query): + pass + + +class GetNodeCatalogHandler(QueryHandler[GetNodeCatalog, NodeCatalog]): + __auth__ = public() + catalog_service: DataCatalogService + + async def run(self, cmd: GetNodeCatalog) -> NodeCatalog: + return await self.catalog_service.get_node_catalog() + + +class GetSchemaManifest(Query): + schema: str # URL segment: ```` or ``@`` + + +class GetSchemaManifestHandler(QueryHandler[GetSchemaManifest, SchemaManifest]): + __auth__ = public() + catalog_service: DataCatalogService + + async def run(self, cmd: GetSchemaManifest) -> SchemaManifest: + schema_id = await self.catalog_service.resolve_schema(cmd.schema) + return await self.catalog_service.get_schema_manifest(schema_id) + + +class GetDataRecord(Query): + ref: RecordRef + + +class GetDataRecordHandler(QueryHandler[GetDataRecord, RecordSummary]): + __auth__ = public() + catalog_service: DataCatalogService + + async def run(self, cmd: GetDataRecord) -> RecordSummary: + return await 
self.catalog_service.get_record_by_id(cmd.ref.id, cmd.ref.version) diff --git a/server/osa/domain/data/query/read_table.py b/server/osa/domain/data/query/read_table.py new file mode 100644 index 0000000..38b19d1 --- /dev/null +++ b/server/osa/domain/data/query/read_table.py @@ -0,0 +1,107 @@ +"""Table-read query handlers — one entry point per ``/data/`` table request. + +The handler owns everything between HTTP parsing and row streaming: table +resolution (404 before bytes), plan construction with the config-clamped page +limit, and delegation to the query service. Routes only translate HTTP +primitives into the Query DTO and wrap the result in a response. + +``__auth__ = public()`` is the seam where the deferred private-schema auth +model will land. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping +from dataclasses import dataclass +from datetime import timedelta +from typing import Any + +from pydantic import Field + +from osa.config import Config +from osa.domain.data.model.filter import FilterExpr +from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.model.query_plan import ( + PaginationCursor, + PaginationParams, + QueryPlan, + SortSpec, + TableKind, +) +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.data.service.data_query import DataQueryService +from osa.domain.shared.authorization.gate import public +from osa.domain.shared.model.ids import HookName +from osa.domain.shared.query import Query, QueryHandler + + +class ReadRecordsTable(Query): + schema: str # URL segment: ```` or ``@`` + filter: FilterExpr | None = None + cursor: str | None = None + limit: int = 50 + sort: list[SortSpec] = Field(default_factory=list) + timeout: timedelta | None = None # execution budget chosen by the response format + + +class ReadFeatureTable(ReadRecordsTable): + feature: HookName + + +@dataclass +class TableRead: + """A resolved table read: the plan (pagination 
contract), the column + schema (wire order), and the lazily-evaluated row stream.""" + + plan: QueryPlan + columns: list[ColumnSpec] + rows: AsyncIterator[Mapping[str, Any]] + + +def _pagination(cmd: ReadRecordsTable, config: Config) -> PaginationParams: + return PaginationParams.clamped( + cursor=PaginationCursor(value=cmd.cursor) if cmd.cursor else None, + limit=cmd.limit, + max_limit=config.data.max_page_limit, + ) + + +class ReadRecordsTableHandler(QueryHandler[ReadRecordsTable, TableRead]): + __auth__ = public() + catalog_service: DataCatalogService + query_service: DataQueryService + config: Config + + async def run(self, cmd: ReadRecordsTable) -> TableRead: + table = await self.catalog_service.resolve_table(cmd.schema, TableKind.RECORDS) + plan = QueryPlan( + schema_id=table.schema_id, + table_kind=TableKind.RECORDS, + filter=cmd.filter, + pagination=_pagination(cmd, self.config), + sort=cmd.sort, + ) + rows = self.query_service.stream_records(plan, cmd.timeout) + return TableRead(plan=plan, columns=table.columns, rows=rows) + + +class ReadFeatureTableHandler(QueryHandler[ReadFeatureTable, TableRead]): + __auth__ = public() + catalog_service: DataCatalogService + query_service: DataQueryService + config: Config + + async def run(self, cmd: ReadFeatureTable) -> TableRead: + table = await self.catalog_service.resolve_table( + cmd.schema, TableKind.FEATURE, feature_name=cmd.feature + ) + plan = QueryPlan( + schema_id=table.schema_id, + table_kind=TableKind.FEATURE, + feature_name=cmd.feature, + filter=cmd.filter, + pagination=_pagination(cmd, self.config), + sort=cmd.sort, + ) + rows = self.query_service.stream_features(plan, cmd.timeout) + return TableRead(plan=plan, columns=table.columns, rows=rows) diff --git a/server/osa/domain/export/__init__.py b/server/osa/domain/data/service/__init__.py similarity index 100% rename from server/osa/domain/export/__init__.py rename to server/osa/domain/data/service/__init__.py diff --git 
a/server/osa/domain/data/service/data_catalog.py b/server/osa/domain/data/service/data_catalog.py new file mode 100644 index 0000000..f461b33 --- /dev/null +++ b/server/osa/domain/data/service/data_catalog.py @@ -0,0 +1,104 @@ +"""DataCatalogService — catalog, manifest, and single-record-by-ID reads. + +Read-only orchestration over the :class:`DataCatalogReadStore` port. Catalog + manifest +(US3) and record-by-ID (US4). The reserved-name 404 is enforced defensively +here in addition to the write-side aggregate invariants (research §6). +""" + +from __future__ import annotations + +from osa.domain.data.model.catalog import NodeCatalog +from osa.domain.data.model.manifest import ResolvedTable, SchemaManifest +from osa.domain.data.model.query_plan import TableKind +from osa.domain.data.model.record_summary import RecordSummary +from osa.domain.data.port.data_read_store import DataCatalogReadStore +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.model.ids import HookName, RecordId +from osa.domain.shared.model.reserved import RESERVED_NAMES +from osa.domain.shared.model.srn import SchemaId +from osa.domain.shared.service import Service + + +class DataCatalogService(Service): + read_store: DataCatalogReadStore + + async def resolve_schema(self, raw: str) -> SchemaId: + """Resolve a URL schema segment (```` or ``@``) to a SchemaId. + + A bare id resolves to the latest published version. Reserved names and + unknown schemas raise ``NotFoundError`` (404 at the route). + """ + short_id = raw.split("@", 1)[0] + if short_id in RESERVED_NAMES: + raise NotFoundError( + f"The path '{short_id}' is reserved under /data/.", code="reserved_name" + ) + if "@" in raw: + try: + return SchemaId.parse(raw) + except ValueError as exc: + raise NotFoundError( + f"No schema '{raw}'. 
See /api/v1/data for the catalog.", + code="schema_not_found", + ) from exc + resolved = await self.read_store.get_latest_schema_id(short_id) + if resolved is None: + raise NotFoundError( + f"No schema '{raw}'. See /api/v1/data for the catalog.", + code="schema_not_found", + ) + return resolved + + async def get_node_catalog(self) -> NodeCatalog: + return await self.read_store.get_node_catalog() + + async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest: + # Defense in depth: a reserved id can't be registered (write-side + # invariant) but the read side handles the URL slot deliberately. + if schema_id.id.root in RESERVED_NAMES: + raise NotFoundError( + f"The path '{schema_id.id.root}' is reserved under /data/.", + code="reserved_name", + ) + manifest = await self.read_store.get_schema_manifest(schema_id) + if manifest is None: + raise NotFoundError( + f"No schema '{schema_id.render()}'. See /api/v1/data for the catalog.", + code="schema_not_found", + ) + return manifest + + async def resolve_table( + self, + schema: str, + table_kind: TableKind, + feature_name: HookName | None = None, + ) -> ResolvedTable: + """Resolve a URL schema segment + table selector to its column schema. + + Owns the manifest-structure facts the routes must not re-derive: the + records resource is always named ``records``; a feature resource is + matched by name AND kind (a hook cannot shadow the records slot). + Unknown schema or table raises ``NotFoundError`` (404 before bytes). + """ + schema_id = await self.resolve_schema(schema) + manifest = await self.get_schema_manifest(schema_id) + name = "records" if table_kind == TableKind.RECORDS else feature_name + resource = next( + (tr for tr in manifest.table_resources if tr.name == name and tr.kind == table_kind), + None, + ) + if resource is None: + raise NotFoundError( + f"No table '{name}' on schema '{schema}'. 
" + f"See /api/v1/data/{schema_id.render()} for its table resources.", + code="table_not_found", + ) + return ResolvedTable(schema_id=schema_id, columns=resource.columns) + + async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary: + record = await self.read_store.get_record_by_id(id, version) + if record is None: + suffix = f"@{version}" if version is not None else "" + raise NotFoundError(f"No record with id '{id}{suffix}'.", code="record_not_found") + return record diff --git a/server/osa/domain/data/service/data_query.py b/server/osa/domain/data/service/data_query.py new file mode 100644 index 0000000..0f595d8 --- /dev/null +++ b/server/osa/domain/data/service/data_query.py @@ -0,0 +1,100 @@ +"""DataQueryService — streaming read business logic for records and features. + +Validates the filter tree's bounds (depth, predicate count, distinct feature +joins) against config, then delegates row streaming to the +:class:`DataTableReadStore` port. The returned async iterators are wrapped by the +route layer in a serializer + ``StreamingResponse``. Validation runs at the top +of the generator body, so it surfaces on the route's pre-flight ``__anext__`` +pull — before any response bytes (research §4). + +Field/operator-compatibility validation against the resolved table's columns is +performed by the read-store adapter during SQL compilation (where the column +types are known); that path raises ``ValidationError`` before the first row. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Iterator, Mapping +from datetime import timedelta +from typing import Any + +from osa.config import Config +from osa.domain.data.model.filter import And, FeatureFieldRef, FilterExpr, Not, Or, Predicate +from osa.domain.data.model.query_plan import QueryPlan, TableKind +from osa.domain.data.port.data_read_store import DataTableReadStore +from osa.domain.shared.error import ValidationError +from osa.domain.shared.service import Service + + +class DataQueryService(Service): + read_store: DataTableReadStore + config: Config + + async def stream_records( + self, plan: QueryPlan, timeout: timedelta | None = None + ) -> AsyncIterator[Mapping[str, Any]]: + if plan.table_kind != TableKind.RECORDS: + raise ValidationError("stream_records requires a RECORDS plan", field="table_kind") + self._validate_filter_bounds(plan.filter) + async for row in self.read_store.stream_rows(plan, timeout): + yield row + + async def stream_features( + self, plan: QueryPlan, timeout: timedelta | None = None + ) -> AsyncIterator[Mapping[str, Any]]: + if plan.table_kind != TableKind.FEATURE: + raise ValidationError("stream_features requires a FEATURE plan", field="table_kind") + self._validate_filter_bounds(plan.filter) + async for row in self.read_store.stream_rows(plan, timeout): + yield row + + # ------------------------------------------------------------------ # + # Filter-tree bounds (ported from DiscoveryService, config-driven) + # ------------------------------------------------------------------ # + + def _validate_filter_bounds(self, expr: FilterExpr | None) -> None: + if expr is None: + return + depth = _tree_depth(expr) + if depth > self.config.data.max_filter_depth: + raise ValidationError( + f"Filter tree depth {depth} exceeds maximum {self.config.data.max_filter_depth}.", + field="filter", + code="filter_depth_exceeded", + ) + predicates = list(_iter_predicates(expr)) + if len(predicates) > 
self.config.data.max_predicates: + raise ValidationError( + f"Filter tree has {len(predicates)} predicates, exceeds maximum " + f"{self.config.data.max_predicates}.", + field="filter", + code="filter_predicates_exceeded", + ) + distinct_hooks = {p.field.hook for p in predicates if isinstance(p.field, FeatureFieldRef)} + if len(distinct_hooks) > self.config.data.max_feature_joins: + raise ValidationError( + f"Filter joins {len(distinct_hooks)} feature hooks, exceeds maximum " + f"{self.config.data.max_feature_joins}.", + field="filter", + code="filter_joins_exceeded", + ) + + +def _tree_depth(expr: FilterExpr) -> int: + if isinstance(expr, Predicate): + return 1 + if isinstance(expr, Not): + return 1 + _tree_depth(expr.operand) + if isinstance(expr, (And, Or)): + return 1 + max(_tree_depth(op) for op in expr.operands) + return 1 + + +def _iter_predicates(expr: FilterExpr) -> Iterator[Predicate]: + if isinstance(expr, Predicate): + yield expr + elif isinstance(expr, Not): + yield from _iter_predicates(expr.operand) + elif isinstance(expr, (And, Or)): + for op in expr.operands: + yield from _iter_predicates(op) diff --git a/server/osa/domain/export/adapter/__init__.py b/server/osa/domain/data/util/__init__.py similarity index 100% rename from server/osa/domain/export/adapter/__init__.py rename to server/osa/domain/data/util/__init__.py diff --git a/server/osa/domain/data/util/di/__init__.py b/server/osa/domain/data/util/di/__init__.py new file mode 100644 index 0000000..1fd666e --- /dev/null +++ b/server/osa/domain/data/util/di/__init__.py @@ -0,0 +1,3 @@ +from osa.domain.data.util.di.provider import DataProvider + +__all__ = ["DataProvider"] diff --git a/server/osa/domain/data/util/di/provider.py b/server/osa/domain/data/util/di/provider.py new file mode 100644 index 0000000..1c5d236 --- /dev/null +++ b/server/osa/domain/data/util/di/provider.py @@ -0,0 +1,40 @@ +"""Dishka DI provider for the data domain (services + query handlers).""" + +from dishka import provide + 
+from osa.config import Config +from osa.domain.data.port.data_read_store import ( + DataCatalogReadStore, + DataTableReadStore, +) +from osa.domain.data.query.catalog import ( + GetDataRecordHandler, + GetNodeCatalogHandler, + GetSchemaManifestHandler, +) +from osa.domain.data.query.read_table import ( + ReadFeatureTableHandler, + ReadRecordsTableHandler, +) +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.data.service.data_query import DataQueryService +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + + +class DataProvider(Provider): + @provide(scope=Scope.UOW) + def get_data_query_service( + self, read_store: DataTableReadStore, config: Config + ) -> DataQueryService: + return DataQueryService(read_store=read_store, config=config) + + @provide(scope=Scope.UOW) + def get_data_catalog_service(self, read_store: DataCatalogReadStore) -> DataCatalogService: + return DataCatalogService(read_store=read_store) + + read_records_table_handler = provide(ReadRecordsTableHandler, scope=Scope.UOW) + read_feature_table_handler = provide(ReadFeatureTableHandler, scope=Scope.UOW) + get_node_catalog_handler = provide(GetNodeCatalogHandler, scope=Scope.UOW) + get_schema_manifest_handler = provide(GetSchemaManifestHandler, scope=Scope.UOW) + get_data_record_handler = provide(GetDataRecordHandler, scope=Scope.UOW) diff --git a/server/osa/domain/discovery/__init__.py b/server/osa/domain/discovery/__init__.py deleted file mode 100644 index c6d5d02..0000000 --- a/server/osa/domain/discovery/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Discovery domain — read-only search and filter API for records and features.""" diff --git a/server/osa/domain/discovery/model/refs.py b/server/osa/domain/discovery/model/refs.py deleted file mode 100644 index f5783a0..0000000 --- a/server/osa/domain/discovery/model/refs.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Typed field references used inside Predicate.field. 
- -Two kinds of references are supported: - -- :class:`MetadataFieldRef` — resolves to a column in ``metadata._v``. -- :class:`FeatureFieldRef` — resolves to a column in ``features.``. - -Wire format is a dotted path (``metadata.`` or -``features..``). :func:`parse_field_ref` parses the wire form -into a typed reference and validates identifier shape. -""" - -from __future__ import annotations - -import re -from typing import Literal, Union - -from pydantic import BaseModel - -_IDENT = re.compile(r"^[a-z][a-z0-9_]*$") - - -class MetadataFieldRef(BaseModel): - path: Literal["metadata"] = "metadata" - field: str - - def dotted(self) -> str: - return f"metadata.{self.field}" - - -class FeatureFieldRef(BaseModel): - path: Literal["features"] = "features" - hook: str - column: str - - def dotted(self) -> str: - return f"features.{self.hook}.{self.column}" - - -FieldRef = Union[MetadataFieldRef, FeatureFieldRef] - - -def parse_field_ref(dotted: str) -> FieldRef: - """Parse a dotted-path field reference into its typed form. - - Raises :class:`ValueError` when the path shape or identifier doesn't match - the documented grammar. 
- """ - if not isinstance(dotted, str): - raise ValueError(f"Expected dotted string, got {type(dotted).__name__}") - - parts = dotted.split(".") - if not parts: - raise ValueError(f"Empty field reference: {dotted!r}") - - head = parts[0] - if head == "metadata": - if len(parts) != 2: - raise ValueError(f"metadata.* refs must be exactly two dotted parts, got {dotted!r}") - field = parts[1] - if not _IDENT.match(field): - raise ValueError(f"Invalid metadata field identifier: {field!r}") - return MetadataFieldRef(field=field) - - if head == "features": - if len(parts) != 3: - raise ValueError(f"features.* refs must be exactly three dotted parts, got {dotted!r}") - hook, column = parts[1], parts[2] - if not _IDENT.match(hook): - raise ValueError(f"Invalid hook identifier: {hook!r}") - if not _IDENT.match(column): - raise ValueError(f"Invalid feature column identifier: {column!r}") - return FeatureFieldRef(hook=hook, column=column) - - raise ValueError( - f"Unknown field reference prefix {head!r} in {dotted!r}. " - "Expected 'metadata.' or 'features..'." - ) diff --git a/server/osa/domain/discovery/port/field_definition_reader.py b/server/osa/domain/discovery/port/field_definition_reader.py deleted file mode 100644 index c2a6bfa..0000000 --- a/server/osa/domain/discovery/port/field_definition_reader.py +++ /dev/null @@ -1,28 +0,0 @@ -"""FieldDefinitionReader port — cross-domain read port for schema field lookups.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Protocol - -if TYPE_CHECKING: - from osa.domain.semantics.model.value import FieldType - from osa.domain.shared.model.srn import SchemaId - - -class FieldDefinitionReader(Protocol): - async def get_all_field_types(self) -> dict[str, FieldType]: - """Return global field_name -> FieldType map across all schemas. - - Raises ValidationError if same field name has conflicting types across schemas. - """ - ... 
- - async def get_fields_for_schema(self, schema_id: "SchemaId") -> dict[str, FieldType]: - """Return field_name -> FieldType for a specific schema's current major version. - - Returns an empty dict when the schema is unknown to the node. Callers - that treat "unknown schema" as an error condition must check for an - empty map and raise ``NotFoundError`` themselves — the port stays - neutral so that non-user-facing callers can handle absence explicitly. - """ - ... diff --git a/server/osa/domain/discovery/port/read_store.py b/server/osa/domain/discovery/port/read_store.py deleted file mode 100644 index 762364e..0000000 --- a/server/osa/domain/discovery/port/read_store.py +++ /dev/null @@ -1,50 +0,0 @@ -"""DiscoveryReadStore port — read-only access to records, features, metadata.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Protocol - -if TYPE_CHECKING: - from osa.domain.discovery.model.value import ( - FeatureCatalogEntry, - FeatureRow, - FilterExpr, - RecordSummary, - SortOrder, - ) - from osa.domain.semantics.model.value import FieldType - from osa.domain.shared.model.srn import ConventionSRN, RecordSRN, SchemaId - - -class DiscoveryReadStore(Protocol): - async def search_records( - self, - filter_expr: "FilterExpr | None", - schema_id: "SchemaId | None", - convention_srn: "ConventionSRN | None", - text_fields: list[str], - q: str | None, - sort: str, - order: "SortOrder", - cursor: dict | None, - limit: int, - field_types: "dict[str, FieldType] | None" = None, - ) -> "list[RecordSummary]": - """Search published records with a compound filter.""" - ... - - async def get_feature_catalog(self) -> "list[FeatureCatalogEntry]": ... - - async def get_feature_table_schema(self, hook_name: str) -> "FeatureCatalogEntry | None": ... 
- - async def search_features( - self, - hook_name: str, - filter_expr: "FilterExpr | None", - schema_id: "SchemaId | None", - record_srn: "RecordSRN | None", - sort: str, - order: "SortOrder", - cursor: dict | None, - limit: int, - ) -> "list[FeatureRow]": ... diff --git a/server/osa/domain/discovery/query/get_feature_catalog.py b/server/osa/domain/discovery/query/get_feature_catalog.py deleted file mode 100644 index 30f891d..0000000 --- a/server/osa/domain/discovery/query/get_feature_catalog.py +++ /dev/null @@ -1,32 +0,0 @@ -"""GetFeatureCatalog query — list available feature tables with schemas and counts.""" - -from osa.domain.discovery.model.value import FeatureCatalog -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.shared.authorization.gate import public -from osa.domain.shared.query import Query, QueryHandler, Result - - -class GetFeatureCatalog(Query): - pass - - -class GetFeatureCatalogResult(Result): - tables: list[dict] - - -class GetFeatureCatalogHandler(QueryHandler[GetFeatureCatalog, GetFeatureCatalogResult]): - __auth__ = public() - discovery_service: DiscoveryService - - async def run(self, cmd: GetFeatureCatalog) -> GetFeatureCatalogResult: - catalog: FeatureCatalog = await self.discovery_service.get_feature_catalog() - return GetFeatureCatalogResult( - tables=[ - { - "hook_name": entry.hook_name, - "columns": [c.model_dump() for c in entry.columns], - "record_count": entry.record_count, - } - for entry in catalog.tables - ] - ) diff --git a/server/osa/domain/discovery/query/search_features.py b/server/osa/domain/discovery/query/search_features.py deleted file mode 100644 index 42dde9a..0000000 --- a/server/osa/domain/discovery/query/search_features.py +++ /dev/null @@ -1,57 +0,0 @@ -"""SearchFeatures query — query and filter rows in a specific feature table.""" - -from osa.domain.discovery.model.value import ( - FeatureSearchResult, - FilterExpr, - SortOrder, -) -from osa.domain.discovery.service.discovery 
import DiscoveryService -from osa.domain.shared.authorization.gate import public -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.srn import RecordSRN, SchemaId -from osa.domain.shared.query import Query, QueryHandler, Result - - -class SearchFeatures(Query): - hook_name: str - filter_expr: FilterExpr | None = None - schema_id: SchemaId | None = None - record_srn: str | None = None - sort: str = "id" - order: SortOrder = SortOrder.DESC - cursor: str | None = None - limit: int = 50 - - -class SearchFeaturesResult(Result): - rows: list[dict] - cursor: str | None - has_more: bool - - -class SearchFeaturesHandler(QueryHandler[SearchFeatures, SearchFeaturesResult]): - __auth__ = public() - discovery_service: DiscoveryService - - async def run(self, cmd: SearchFeatures) -> SearchFeaturesResult: - record_srn: RecordSRN | None = None - if cmd.record_srn: - try: - record_srn = RecordSRN.parse(cmd.record_srn) - except ValueError as exc: - raise ValidationError(str(exc), field="record_srn") from exc - result: FeatureSearchResult = await self.discovery_service.search_features( - hook_name=cmd.hook_name, - filter_expr=cmd.filter_expr, - schema_id=cmd.schema_id, - record_srn=record_srn, - sort=cmd.sort, - order=cmd.order, - cursor=cmd.cursor, - limit=cmd.limit, - ) - return SearchFeaturesResult( - rows=[{"record_srn": str(r.record_srn), **r.data} for r in result.rows], - cursor=result.cursor, - has_more=result.has_more, - ) diff --git a/server/osa/domain/discovery/query/search_records.py b/server/osa/domain/discovery/query/search_records.py deleted file mode 100644 index 515d980..0000000 --- a/server/osa/domain/discovery/query/search_records.py +++ /dev/null @@ -1,59 +0,0 @@ -"""SearchRecords query — search and filter published records.""" - -from typing import Any - -from osa.domain.discovery.model.value import ( - FilterExpr, - RecordSearchResult, - SortOrder, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from 
osa.domain.shared.authorization.gate import public -from osa.domain.shared.model.srn import ConventionSRN, SchemaId -from osa.domain.shared.query import Query, QueryHandler, Result - - -class SearchRecords(Query): - filter_expr: FilterExpr | None = None - schema_id: SchemaId | None = None - convention_srn: ConventionSRN | None = None - q: str | None = None - sort: str = "published_at" - order: SortOrder = SortOrder.DESC - cursor: str | None = None - limit: int = 20 - - -class SearchRecordsResult(Result): - results: list[dict[str, Any]] - cursor: str | None - has_more: bool - - -class SearchRecordsHandler(QueryHandler[SearchRecords, SearchRecordsResult]): - __auth__ = public() - discovery_service: DiscoveryService - - async def run(self, cmd: SearchRecords) -> SearchRecordsResult: - result: RecordSearchResult = await self.discovery_service.search_records( - filter_expr=cmd.filter_expr, - schema_id=cmd.schema_id, - convention_srn=cmd.convention_srn, - q=cmd.q, - sort=cmd.sort, - order=cmd.order, - cursor=cmd.cursor, - limit=cmd.limit, - ) - return SearchRecordsResult( - results=[ - { - "srn": str(r.srn), - "published_at": r.published_at.isoformat(), - "metadata": r.metadata, - } - for r in result.results - ], - cursor=result.cursor, - has_more=result.has_more, - ) diff --git a/server/osa/domain/discovery/service/discovery.py b/server/osa/domain/discovery/service/discovery.py deleted file mode 100644 index cd8db52..0000000 --- a/server/osa/domain/discovery/service/discovery.py +++ /dev/null @@ -1,423 +0,0 @@ -"""DiscoveryService — read-only business logic for record and feature search. - -Validates the compound ``FilterExpr`` tree (bounds, field resolution, operator -compatibility) before handing it to the read store for SQL compilation. 
-""" - -from __future__ import annotations - -import logging -from typing import Any - -from osa.config import Config -from osa.domain.discovery.model.refs import ( - FeatureFieldRef, - MetadataFieldRef, -) -from osa.domain.discovery.model.value import ( - JSON_TYPE_OPERATORS, - VALID_OPERATORS, - And, - FeatureCatalog, - FeatureSearchResult, - FilterExpr, - FilterOperator, - Not, - Or, - Predicate, - RecordSearchResult, - SortOrder, - decode_cursor, - encode_cursor, -) -from osa.domain.discovery.port.field_definition_reader import FieldDefinitionReader -from osa.domain.discovery.port.read_store import DiscoveryReadStore -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import NotFoundError, ValidationError -from osa.domain.shared.model.srn import ConventionSRN, RecordSRN, SchemaId -from osa.domain.shared.service import Service - -logger = logging.getLogger(__name__) - - -class DiscoveryService(Service): - """Orchestrates validation and delegation for discovery queries.""" - - read_store: DiscoveryReadStore - field_reader: FieldDefinitionReader - config: Config - - async def search_records( - self, - filter_expr: FilterExpr | None, - schema_id: SchemaId | None, - convention_srn: ConventionSRN | None, - q: str | None, - sort: str, - order: SortOrder, - cursor: str | None, - limit: int, - *, - allow_compound: bool = True, - ) -> RecordSearchResult: - """Validate the filter tree and delegate record search to the read store. - - ``allow_compound`` is a staged flag — US1 delivers AND-only + Predicate - support; US2 flips this to allow OR/NOT. Callers should leave it True - once US2 lands. - """ - if limit < 1 or limit > 100: - raise ValidationError("limit must be between 1 and 100", field="limit") - - if sort != "published_at" and schema_id is None: - raise ValidationError( - f"Sorting by '{sort}' requires the request to pin a 'schema' " - "('@'). 
Plain listings must sort by 'published_at'.", - field="sort", - code="schema_required_for_metadata_sort", - ) - - if q and schema_id is None: - raise ValidationError( - "Free-text search ('q') requires the request to pin a 'schema' " - "('@'). Without a schema, the server cannot resolve " - "which metadata fields are text-indexed.", - field="q", - code="schema_required_for_free_text_search", - ) - - schema_field_map: dict[str, FieldType] = {} - if schema_id is not None: - schema_field_map = await self.field_reader.get_fields_for_schema(schema_id) - if not schema_field_map: - raise NotFoundError( - f"Schema not found: {schema_id.render()}. " - "Pin an '@' that matches a registered schema." - ) - - if filter_expr is not None: - self._validate_tree(filter_expr, allow_compound=allow_compound) - await self._validate_refs(filter_expr, schema_id, schema_field_map) - - # Sort field validation (against pinned schema) - if sort != "published_at" and sort not in schema_field_map: - raise ValidationError( - f"Unknown sort field '{sort}': not defined in the pinned schema.", - field="sort", - code="unknown_sort_field", - ) - - decoded_cursor: dict[str, Any] | None = None - if cursor is not None: - try: - decoded_cursor = decode_cursor(cursor) - except ValueError as exc: - raise ValidationError(str(exc), field="cursor") from exc - - text_fields = [ - name for name, ft in schema_field_map.items() if ft in (FieldType.TEXT, FieldType.URL) - ] - if q and not text_fields: - raise ValidationError( - "Free-text search is unavailable: the pinned schema defines no text or URL fields.", - field="q", - code="no_text_fields_in_schema", - ) - - results = await self.read_store.search_records( - filter_expr=filter_expr, - schema_id=schema_id, - convention_srn=convention_srn, - text_fields=text_fields, - q=q, - sort=sort, - order=order, - cursor=decoded_cursor, - limit=limit + 1, - field_types=schema_field_map, - ) - - has_more = len(results) > limit - results = results[:limit] - next_cursor = 
None - if has_more and results: - last = results[-1] - if sort == "published_at": - sort_val = last.published_at.isoformat() - else: - sort_val = last.metadata.get(sort) - next_cursor = encode_cursor(sort_val, str(last.srn)) - - return RecordSearchResult( - results=results, - cursor=next_cursor, - has_more=has_more, - ) - - async def get_feature_catalog(self) -> FeatureCatalog: - entries = await self.read_store.get_feature_catalog() - return FeatureCatalog(tables=entries) - - async def search_features( - self, - hook_name: str, - filter_expr: FilterExpr | None, - schema_id: SchemaId | None, - record_srn: RecordSRN | None, - sort: str, - order: SortOrder, - cursor: str | None, - limit: int, - *, - allow_compound: bool = True, - ) -> FeatureSearchResult: - if limit < 1 or limit > 100: - raise ValidationError("limit must be between 1 and 100", field="limit") - - entry = await self.read_store.get_feature_table_schema(hook_name) - if entry is None: - raise NotFoundError(f"Feature table not found: {hook_name}") - - col_map: dict[str, str] = {col.name: col.type for col in entry.columns} - col_map["record_srn"] = "string" - - schema_field_map: dict[str, FieldType] = {} - if schema_id is not None: - schema_field_map = await self.field_reader.get_fields_for_schema(schema_id) - if not schema_field_map: - raise NotFoundError( - f"Schema not found: {schema_id.render()}. " - "Pin an '@' that matches a registered schema." 
- ) - - if filter_expr is not None: - self._validate_tree(filter_expr, allow_compound=allow_compound) - self._validate_feature_refs( - filter_expr, - this_hook=hook_name, - feature_col_map=col_map, - schema_field_map=schema_field_map, - ) - - if sort != "id" and sort not in col_map: - raise ValidationError( - f"Unknown sort column '{sort}' in feature table '{hook_name}'", - field="sort", - ) - - try: - decoded_cursor = decode_cursor(cursor) if cursor else None - except ValueError as exc: - raise ValidationError(str(exc), field="cursor") from exc - - rows = await self.read_store.search_features( - hook_name=hook_name, - filter_expr=filter_expr, - schema_id=schema_id, - record_srn=record_srn, - sort=sort, - order=order, - cursor=decoded_cursor, - limit=limit + 1, - ) - - has_more = len(rows) > limit - rows = rows[:limit] - next_cursor = None - if has_more and rows: - last = rows[-1] - if sort == "id": - sort_val = last.row_id - else: - sort_val = last.data.get(sort) - next_cursor = encode_cursor(sort_val, last.row_id) - - return FeatureSearchResult(rows=rows, cursor=next_cursor, has_more=has_more) - - # ------------------------- internal helpers ------------------------- - - def _validate_tree(self, expr: FilterExpr, *, allow_compound: bool) -> None: - """Enforce tree bounds (depth, predicate count, joins) + compound gating.""" - depth = _tree_depth(expr) - predicates = list(_iter_predicates(expr)) - - if depth > self.config.discovery_max_filter_depth: - raise ValidationError( - f"Filter tree depth {depth} exceeds configured maximum " - f"{self.config.discovery_max_filter_depth} (OSA_DISCOVERY_MAX_FILTER_DEPTH).", - field="filter", - code="filter_depth_exceeded", - ) - if len(predicates) > self.config.discovery_max_predicates: - raise ValidationError( - f"Filter tree has {len(predicates)} predicate leaves, exceeds " - f"configured maximum {self.config.discovery_max_predicates} " - "(OSA_DISCOVERY_MAX_PREDICATES).", - field="filter", - 
code="filter_predicates_exceeded", - ) - - distinct_hooks: set[str] = set() - for p in predicates: - if isinstance(p.field, FeatureFieldRef): - distinct_hooks.add(p.field.hook) - if len(distinct_hooks) > self.config.discovery_max_cross_domain_joins: - raise ValidationError( - f"Filter tree joins {len(distinct_hooks)} distinct feature hooks, " - f"exceeds configured maximum " - f"{self.config.discovery_max_cross_domain_joins} " - "(OSA_DISCOVERY_MAX_CROSS_DOMAIN_JOINS).", - field="filter", - code="filter_joins_exceeded", - ) - - if not allow_compound: - for node in _iter_nodes(expr): - if isinstance(node, (Or, Not)): - raise ValidationError( - "Compound OR/NOT filters are not enabled in this build.", - field="filter", - code="compound_disabled", - ) - - async def _validate_refs( - self, - expr: FilterExpr, - schema_id: SchemaId | None, - field_map: dict[str, FieldType], - ) -> None: - """Resolve each predicate's field and check operator compatibility.""" - feature_catalog: dict[str, dict[str, str]] | None = None - for p in _iter_predicates(expr): - if isinstance(p.field, MetadataFieldRef): - if schema_id is None: - raise ValidationError( - f"Metadata predicate on {p.field.dotted()!r} requires " - "the request to pin a 'schema' ('@'). 
" - "Unscoped metadata filtering is not supported — the typed " - "metadata table is the only filter path.", - field=p.field.dotted(), - code="schema_required_for_metadata_query", - ) - field_name = p.field.field - if field_name not in field_map: - raise ValidationError( - f"Unknown metadata field '{field_name}' for the pinned schema.", - field=p.field.dotted(), - code="unknown_field", - ) - self._check_operator_for_field_type( - p, field_type=field_map[field_name], path=p.field.dotted() - ) - elif isinstance(p.field, FeatureFieldRef): - if feature_catalog is None: - feature_catalog = await self._load_feature_catalog() - cols = feature_catalog.get(p.field.hook) - if cols is None: - raise ValidationError( - f"Unknown feature hook '{p.field.hook}'.", - field=p.field.dotted(), - code="unknown_hook", - ) - if p.field.column not in cols: - raise ValidationError( - f"Unknown feature column '{p.field.column}' on hook '{p.field.hook}'.", - field=p.field.dotted(), - code="unknown_field", - ) - json_type = cols[p.field.column] - self._check_operator_for_json_type(p, json_type=json_type, path=p.field.dotted()) - - def _validate_feature_refs( - self, - expr: FilterExpr, - *, - this_hook: str, - feature_col_map: dict[str, str], - schema_field_map: dict[str, FieldType], - ) -> None: - """Variant of ref validation for feature search — local hook columns by default.""" - for p in _iter_predicates(expr): - if isinstance(p.field, MetadataFieldRef): - if p.field.field not in schema_field_map: - raise ValidationError( - f"Unknown metadata field '{p.field.field}' for the provided schema.", - field=p.field.dotted(), - code="unknown_field", - ) - self._check_operator_for_field_type( - p, field_type=schema_field_map[p.field.field], path=p.field.dotted() - ) - elif isinstance(p.field, FeatureFieldRef): - if p.field.hook != this_hook: - # Cross-hook joins handled by US3 — accepted here, resolved in adapter. 
- continue - if p.field.column not in feature_col_map: - raise ValidationError( - f"Unknown feature column '{p.field.column}' on hook '{this_hook}'.", - field=p.field.dotted(), - code="unknown_field", - ) - self._check_operator_for_json_type( - p, json_type=feature_col_map[p.field.column], path=p.field.dotted() - ) - - async def _load_feature_catalog(self) -> dict[str, dict[str, str]]: - """Build hook_name → {column_name → json_type} map from the catalog.""" - catalog = await self.read_store.get_feature_catalog() - return {entry.hook_name: {col.name: col.type for col in entry.columns} for entry in catalog} - - @staticmethod - def _check_operator_for_field_type( - predicate: Predicate, *, field_type: FieldType, path: str - ) -> None: - valid = VALID_OPERATORS.get(field_type, set()) - if predicate.op not in valid: - raise ValidationError( - f"Operator '{predicate.op}' is not valid for field '{path}' " - f"(type '{field_type}'). Valid: {sorted(valid)}.", - field=path, - code="operator_not_valid_for_type", - ) - - @staticmethod - def _check_operator_for_json_type(predicate: Predicate, *, json_type: str, path: str) -> None: - valid = JSON_TYPE_OPERATORS.get(json_type, {FilterOperator.EQ}) - if predicate.op not in valid: - raise ValidationError( - f"Operator '{predicate.op}' is not valid for column '{path}' " - f"(json type '{json_type}'). 
Valid: {sorted(valid)}.", - field=path, - code="operator_not_valid_for_type", - ) - - -def _tree_depth(expr: FilterExpr) -> int: - if isinstance(expr, Predicate): - return 1 - if isinstance(expr, Not): - return 1 + _tree_depth(expr.operand) - if isinstance(expr, (And, Or)): - return 1 + max(_tree_depth(op) for op in expr.operands) - return 1 - - -def _iter_predicates(expr: FilterExpr): - if isinstance(expr, Predicate): - yield expr - return - if isinstance(expr, Not): - yield from _iter_predicates(expr.operand) - return - if isinstance(expr, (And, Or)): - for op in expr.operands: - yield from _iter_predicates(op) - - -def _iter_nodes(expr: FilterExpr): - yield expr - if isinstance(expr, Not): - yield from _iter_nodes(expr.operand) - elif isinstance(expr, (And, Or)): - for op in expr.operands: - yield from _iter_nodes(op) diff --git a/server/osa/domain/discovery/util/di/__init__.py b/server/osa/domain/discovery/util/di/__init__.py deleted file mode 100644 index 8672799..0000000 --- a/server/osa/domain/discovery/util/di/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from osa.domain.discovery.util.di.provider import DiscoveryProvider - -__all__ = ["DiscoveryProvider"] diff --git a/server/osa/domain/discovery/util/di/provider.py b/server/osa/domain/discovery/util/di/provider.py deleted file mode 100644 index 9715b8c..0000000 --- a/server/osa/domain/discovery/util/di/provider.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Dishka DI provider for the discovery domain.""" - -from dishka import provide - -from osa.config import Config -from osa.domain.discovery.port.field_definition_reader import FieldDefinitionReader -from osa.domain.discovery.port.read_store import DiscoveryReadStore -from osa.domain.discovery.query.get_feature_catalog import GetFeatureCatalogHandler -from osa.domain.discovery.query.search_features import SearchFeaturesHandler -from osa.domain.discovery.query.search_records import SearchRecordsHandler -from osa.domain.discovery.service.discovery import 
DiscoveryService -from osa.util.di.base import Provider -from osa.util.di.scope import Scope - - -class DiscoveryProvider(Provider): - @provide(scope=Scope.UOW) - def get_discovery_service( - self, - read_store: DiscoveryReadStore, - field_reader: FieldDefinitionReader, - config: Config, - ) -> DiscoveryService: - return DiscoveryService( - read_store=read_store, - field_reader=field_reader, - config=config, - ) - - # Query Handlers - search_records_handler = provide(SearchRecordsHandler, scope=Scope.UOW) - get_feature_catalog_handler = provide(GetFeatureCatalogHandler, scope=Scope.UOW) - search_features_handler = provide(SearchFeaturesHandler, scope=Scope.UOW) diff --git a/server/osa/domain/index/__init__.py b/server/osa/domain/index/__init__.py deleted file mode 100644 index 96aacc8..0000000 --- a/server/osa/domain/index/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Index domain - configuration and orchestration for storage backends.""" diff --git a/server/osa/domain/index/event/__init__.py b/server/osa/domain/index/event/__init__.py deleted file mode 100644 index b350a4e..0000000 --- a/server/osa/domain/index/event/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Index domain events.""" - -from osa.domain.index.event.index_record import IndexRecord - -__all__ = ["IndexRecord"] diff --git a/server/osa/domain/index/event/index_record.py b/server/osa/domain/index/event/index_record.py deleted file mode 100644 index 7f438e0..0000000 --- a/server/osa/domain/index/event/index_record.py +++ /dev/null @@ -1,26 +0,0 @@ -"""IndexRecord event - per-backend indexing request for a single record.""" - -from typing import Any - -from osa.domain.shared.event import Event, EventId -from osa.domain.shared.model.srn import RecordSRN - - -class IndexRecord(Event): - """Request to index a single record into a specific backend. - - This event is created by FanOutToIndexBackends when a RecordPublished - event is received. 
Each backend gets its own IndexRecord event, - enabling independent retry and failure isolation. - - Attributes: - id: Unique event identifier (inherited from Event). - backend_name: Target backend name (e.g., "vector", "keyword"). - record_srn: Structured Resource Name of the record. - metadata: Record metadata to index. - """ - - id: EventId - backend_name: str - record_srn: RecordSRN - metadata: dict[str, Any] diff --git a/server/osa/domain/index/handler/__init__.py b/server/osa/domain/index/handler/__init__.py deleted file mode 100644 index 75cfab4..0000000 --- a/server/osa/domain/index/handler/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Index domain event handlers.""" - -from osa.domain.index.handler.fanout_to_index_backends import FanOutToIndexBackends -from osa.domain.index.handler.keyword_index_handler import KeywordIndexHandler -from osa.domain.index.handler.vector_index_handler import VectorIndexHandler - -__all__ = ["FanOutToIndexBackends", "KeywordIndexHandler", "VectorIndexHandler"] diff --git a/server/osa/domain/index/handler/fanout_to_index_backends.py b/server/osa/domain/index/handler/fanout_to_index_backends.py deleted file mode 100644 index 4cb265d..0000000 --- a/server/osa/domain/index/handler/fanout_to_index_backends.py +++ /dev/null @@ -1,47 +0,0 @@ -"""FanOutToIndexBackends - creates per-backend IndexRecord events from RecordPublished.""" - -import logging -from typing import ClassVar -from uuid import uuid4 - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.record.event.record_published import RecordPublished -from osa.domain.shared.event import EventHandler, EventId -from osa.domain.shared.outbox import Outbox - -logger = logging.getLogger(__name__) - - -class FanOutToIndexBackends(EventHandler[RecordPublished]): - """Creates per-backend IndexRecord events from RecordPublished. 
- - When a record is published, this handler creates one IndexRecord event - per registered backend. Each IndexRecord is stored in the outbox, - enabling independent retry and failure isolation per backend. - - This replaces the previous pattern where a single RecordPublished event - triggered immediate indexing to all backends in a single transaction. - - Batch processing is used for efficiency when multiple records are - published in quick succession. - """ - - __batch_size__: ClassVar[int] = 10 - - indexes: IndexRegistry - outbox: Outbox - - async def handle(self, event: RecordPublished) -> None: - """Create IndexRecord events for each registered backend.""" - backend_names = list(self.indexes) - logger.debug(f"FanOut: {event.record_srn} -> {len(backend_names)} backends") - - for backend_name in backend_names: - index_event = IndexRecord( - id=EventId(uuid4()), - backend_name=backend_name, - record_srn=event.record_srn, - metadata=event.metadata, - ) - await self.outbox.append(index_event) diff --git a/server/osa/domain/index/handler/keyword_index_handler.py b/server/osa/domain/index/handler/keyword_index_handler.py deleted file mode 100644 index a6f5791..0000000 --- a/server/osa/domain/index/handler/keyword_index_handler.py +++ /dev/null @@ -1,38 +0,0 @@ -"""KeywordIndexHandler - processes IndexRecord events for keyword backends.""" - -import logging -from typing import ClassVar - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.error import SkippedEvents -from osa.domain.shared.event import EventHandler - -logger = logging.getLogger(__name__) - - -class KeywordIndexHandler(EventHandler[IndexRecord]): - """Processes IndexRecord events for the keyword backend. - - Processes events immediately (batch_size=1) since keyword indexing - doesn't benefit from batching. 
- """ - - __batch_size__: ClassVar[int] = 1 - - indexes: IndexRegistry - - async def handle(self, event: IndexRecord) -> None: - """Process a single IndexRecord event.""" - backend = self.indexes.get("keyword") - if backend is None: - raise SkippedEvents( - event_ids=[event.id], - reason="Keyword backend not available", - ) - - record = (str(event.record_srn), event.metadata) - - await backend.ingest_batch([record]) - - logger.debug(f"KeywordIndexHandler: indexed event {event.id}") diff --git a/server/osa/domain/index/handler/vector_index_handler.py b/server/osa/domain/index/handler/vector_index_handler.py deleted file mode 100644 index aa3d293..0000000 --- a/server/osa/domain/index/handler/vector_index_handler.py +++ /dev/null @@ -1,47 +0,0 @@ -"""VectorIndexHandler - processes IndexRecord events for vector backends.""" - -import logging -from typing import ClassVar - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.error import SkippedEvents -from osa.domain.shared.event import EventHandler - -logger = logging.getLogger(__name__) - - -class VectorIndexHandler(EventHandler[IndexRecord]): - """Processes IndexRecord events for the vector backend. - - Processes events in batches for efficient embedding generation. - """ - - __batch_size__: ClassVar[int] = 100 - __batch_timeout__: ClassVar[float] = 5.0 - - indexes: IndexRegistry - - async def handle_batch(self, events: list[IndexRecord]) -> None: - """Process a batch of IndexRecord events. - - Converts events to records and calls ingest_batch on the backend. 
- """ - if not events: - return - - backend = self.indexes.get("vector") - if backend is None: - raise SkippedEvents( - event_ids=[e.id for e in events], - reason="Vector backend not available", - ) - - # Prepare records for batch ingestion - records = [(str(e.record_srn), e.metadata) for e in events] - - logger.debug(f"VectorIndexHandler: ingesting {len(records)} records to backend") - - await backend.ingest_batch(records) - - logger.debug(f"VectorIndexHandler: ingested {len(events)} records") diff --git a/server/osa/domain/index/model/__init__.py b/server/osa/domain/index/model/__init__.py deleted file mode 100644 index 4f64a80..0000000 --- a/server/osa/domain/index/model/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Index domain models.""" - -from osa.domain.index.model.registry import IndexRegistry - -__all__ = ["IndexRegistry"] diff --git a/server/osa/domain/index/model/registry.py b/server/osa/domain/index/model/registry.py deleted file mode 100644 index 9bb01bf..0000000 --- a/server/osa/domain/index/model/registry.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Index registry - typed container for available storage backends.""" - -from collections.abc import Iterator - -from osa.sdk.index.backend import StorageBackend - - -class IndexRegistry: - """Registry of available index backends.""" - - def __init__(self, backends: dict[str, StorageBackend]) -> None: - self._backends = backends - - def get(self, name: str) -> StorageBackend | None: - """Get a backend by name.""" - return self._backends.get(name) - - def __contains__(self, name: str) -> bool: - return name in self._backends - - def __iter__(self) -> Iterator[str]: - return iter(self._backends) - - def __len__(self) -> int: - return len(self._backends) - - def names(self) -> list[str]: - """List all available index names.""" - return list(self._backends.keys()) - - def items(self) -> Iterator[tuple[str, StorageBackend]]: - """Iterate over (name, backend) pairs.""" - return iter(self._backends.items()) diff --git 
a/server/osa/domain/index/service/__init__.py b/server/osa/domain/index/service/__init__.py deleted file mode 100644 index 852bec9..0000000 --- a/server/osa/domain/index/service/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Index service module.""" - -from osa.domain.index.service.index import IndexService - -__all__ = ["IndexService"] diff --git a/server/osa/domain/index/service/index.py b/server/osa/domain/index/service/index.py deleted file mode 100644 index b4a485b..0000000 --- a/server/osa/domain/index/service/index.py +++ /dev/null @@ -1,48 +0,0 @@ -"""IndexService - orchestrates indexing of records into storage backends.""" - -import logging - -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.service import Service - -logger = logging.getLogger(__name__) - - -class IndexService(Service): - """Service for index-related operations. - - Note: Direct indexing (index_record, flush_all) has been replaced by the - event-driven approach using FanOutToIndexBackends and IndexRecordBatch - listeners. This service is retained for query operations and future - index management commands. - """ - - indexes: IndexRegistry - - async def get_count(self, backend_name: str) -> int | None: - """Get the document count for a specific backend. - - Args: - backend_name: Name of the backend to query. - - Returns: - Document count, or None if backend not found. - """ - backend = self.indexes.get(backend_name) - if backend is None: - return None - return await backend.count() - - async def check_health(self, backend_name: str) -> bool | None: - """Check health of a specific backend. - - Args: - backend_name: Name of the backend to check. - - Returns: - True if healthy, False if unhealthy, None if backend not found. 
- """ - backend = self.indexes.get(backend_name) - if backend is None: - return None - return await backend.health() diff --git a/server/osa/domain/record/query/get_stats.py b/server/osa/domain/record/query/get_stats.py new file mode 100644 index 0000000..92b04f9 --- /dev/null +++ b/server/osa/domain/record/query/get_stats.py @@ -0,0 +1,21 @@ +"""GetStats query handler — public node statistics (record count).""" + +from osa.domain.record.service.record import RecordService +from osa.domain.shared.authorization.gate import public +from osa.domain.shared.query import Query, QueryHandler, Result + + +class GetStats(Query): + pass + + +class StatsResult(Result): + records: int + + +class GetStatsHandler(QueryHandler[GetStats, StatsResult]): + __auth__ = public() + record_service: RecordService + + async def run(self, cmd: GetStats) -> StatsResult: + return StatsResult(records=await self.record_service.count()) diff --git a/server/osa/domain/record/service/record.py b/server/osa/domain/record/service/record.py index e84512b..9c7d1f0 100644 --- a/server/osa/domain/record/service/record.py +++ b/server/osa/domain/record/service/record.py @@ -55,6 +55,10 @@ async def get(self, srn: RecordSRN) -> Record: raise NotFoundError(f"Record not found: {srn}") return record + async def count(self) -> int: + """Total published records on this node.""" + return await self.record_repo.count() + async def _resolve_schema_id(self, convention_srn: ConventionSRN) -> SchemaId: """Resolve a convention to its schema id at publication time.""" convention = await self.convention_repo.get(convention_srn) diff --git a/server/osa/domain/search/__init__.py b/server/osa/domain/search/__init__.py deleted file mode 100644 index cec05f3..0000000 --- a/server/osa/domain/search/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Search domain - provides search capabilities over indexed records.""" diff --git a/server/osa/domain/search/event/__init__.py b/server/osa/domain/search/event/__init__.py deleted file 
class ReservedNameError(DomainError):
    """Signal that a schema ID or hook/feature name equals a reserved name.

    Raised at aggregate construction (``Schema``, ``HookDefinition``) when the
    name is a member of
    :data:`osa.domain.shared.model.reserved.RESERVED_NAMES`. The error carries
    the structured code ``reserved_name`` and is mapped to HTTP 400.
    """

    def __init__(self, name: str, kind: str) -> None:
        """Record the offending ``name`` and its ``kind`` and build the message.

        Args:
            name: The rejected schema ID or hook name.
            kind: What the name identifies (e.g. ``"schema"`` or ``"hook"``).
        """
        # Imported at call time, exactly as the original block does.
        from osa.domain.shared.model.reserved import RESERVED_NAMES

        reserved_listing = sorted(RESERVED_NAMES)
        message = (
            f"The {kind} name '{name}' is reserved for a URL slot under /data/. "
            f"Reserved names: {reserved_listing}."
        )
        self.name = name
        self.kind = kind
        super().__init__(message, code="reserved_name")
class RecordRef(ValueObject):
    """A record reference: bare internal id plus optional integer version.

    Wire form is the URL segment ``{id}`` or ``{id}@{version}`` — the record
    analogue of :class:`~osa.domain.shared.model.srn.SchemaId`'s
    ``{id}@{version}`` form. ``version is None`` means "latest published".
    """

    id: RecordId
    version: int | None = None

    @classmethod
    def parse(cls, raw: str) -> RecordRef:
        """Parse ``{id}`` or ``{id}@{version}``.

        Everything after the first ``@`` is the version and must consist of
        ASCII decimal digits only.

        Raises:
            ValidationError: If the version part is not a plain non-negative
                decimal integer.
        """
        id_part, separator, version_part = raw.partition("@")
        if not separator:
            return cls(id=RecordId(raw))
        # ``int()`` alone is too lenient for a URL segment: it accepts a sign,
        # surrounding whitespace, digit-group underscores ("1_0" == 10) and
        # non-ASCII digits, so several spellings would alias one version.
        if not (version_part.isascii() and version_part.isdigit()):
            raise ValidationError(
                f"Invalid record version in {raw!r}; expected an integer.",
                field="id",
            )
        # Construction happens outside any ``except`` so that an unrelated
        # ValueError is never misreported as a bad version.
        return cls(id=RecordId(id_part), version=int(version_part))

    def render(self) -> str:
        """Return the wire form: ``{id}``, or ``{id}@{version}`` when pinned."""
        return self.id if self.version is None else f"{self.id}@{self.version}"

    def __str__(self) -> str:
        return self.render()
C = TypeVar("C", bound=Query) -R = TypeVar("R", bound=Result) +# Unbound: query results may be Result DTOs, domain read models (e.g. a +# catalog/manifest that IS the wire shape), or streaming reads. Result remains +# the conventional base for handler-specific DTOs. +R = TypeVar("R") # Unbound async handler method: (self, cmd) -> Coroutine -> Result _HandlerMethod = Callable[..., Coroutine[Any, Any, Any]] diff --git a/server/osa/domain/export/command/__init__.py b/server/osa/infrastructure/data/__init__.py similarity index 100% rename from server/osa/domain/export/command/__init__.py rename to server/osa/infrastructure/data/__init__.py diff --git a/server/osa/infrastructure/data/postgres_catalog_read_store.py b/server/osa/infrastructure/data/postgres_catalog_read_store.py new file mode 100644 index 0000000..6ea0ed8 --- /dev/null +++ b/server/osa/infrastructure/data/postgres_catalog_read_store.py @@ -0,0 +1,225 @@ +"""Postgres adapter for the ``DataCatalogReadStore`` port. + +Catalog, manifest, latest-schema resolution, and single-record-by-id — the +non-streaming reads behind ``GET /data``, ``GET /data/{schema}``, and +``GET /data/records/{id}``. Table streaming lives in +:class:`~osa.infrastructure.data.postgres_table_read_store.PostgresTableReadStore`. 
class PostgresCatalogReadStore:
    """Non-streaming Postgres reads behind the ``/data`` catalog endpoints.

    Serves the node catalog, per-schema manifests, latest-schema resolution and
    single-record lookup by bare id. Feature-table discovery is delegated to
    ``SchemaFeatureReader``.
    """

    def __init__(self, session: AsyncSession, node_domain: Domain) -> None:
        self.session = session
        # Only the node's DNS domain is needed (to render SRNs in the catalog /
        # manifest) — not the whole Config.
        self.node_domain = node_domain
        self._features = SchemaFeatureReader(session)

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape the LIKE metacharacters (backslash, %, _) so ``value`` matches literally."""
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    # ------------------------------------------------------------------ #
    # Single record by ID
    # ------------------------------------------------------------------ #

    async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary | None:
        """Look up one record by bare id, pinned to ``version`` or latest published.

        Returns:
            The matching ``RecordSummary``, or ``None`` when no row matches the
            id (or the pinned version).
        """
        # The records PK is the SRN ``urn:osa:{domain}:rec:{id}@{version}``.
        # Match the id segment; resolve version (pin or latest published).
        pattern = f"urn:osa:%:rec:{self._escape_like(str(id))}@%"
        t = records_table
        stmt = (
            select(t.c.srn, t.c.schema_id, t.c.schema_version, t.c.published_at, t.c.metadata)
            .where(t.c.srn.like(pattern, escape="\\"))
            .order_by(t.c.published_at.desc())
        )
        result = await self.session.execute(stmt)
        rows = result.mappings().all()
        if not rows:
            return None

        chosen = None
        for row in rows:
            srn = RecordSRN.parse(row["srn"])
            # The ``%`` domain wildcard can over-match; confirm the parsed id.
            if srn.id.root != str(id):
                continue
            # Rows are newest-first, so with no pin the first id match wins.
            if version is None or int(srn.version.root) == version:
                chosen = (srn, row)
                break
        if chosen is None:
            return None
        srn, row = chosen
        return RecordSummary(
            id=RecordId(srn.id.root),
            srn=srn,
            schema_id=SchemaId.parse(f"{row['schema_id']}@{row['schema_version']}"),
            version=int(srn.version.root),
            metadata=row["metadata"] or {},
            created_at=row["published_at"],
        )

    # ------------------------------------------------------------------ #
    # Catalog & manifest
    # ------------------------------------------------------------------ #

    async def get_node_catalog(self) -> NodeCatalog:
        """Build the node catalog: one entry per schema row with its table resources."""
        stmt = select(schemas_table.c.id, schemas_table.c.version)
        result = await self.session.execute(stmt)
        schema_rows = [(row["id"], row["version"]) for row in result.mappings()]
        entries: list[CatalogEntry] = []
        for short_id, version in schema_rows:
            schema_id = SchemaId.parse(f"{short_id}@{version}")
            # Every schema exposes ``records``; feature tables follow.
            resources = [TableResourceSummary(name="records", kind=TableKind.RECORDS)]
            for hook_name, _ in await self._features.feature_tables(schema_id):
                resources.append(TableResourceSummary(name=hook_name, kind=TableKind.FEATURE))
            entries.append(
                CatalogEntry(
                    id=short_id,
                    version=version,
                    srn=schema_id.to_srn(self.node_domain).render(),
                    table_resources=resources,
                )
            )
        return NodeCatalog(node_domain=self.node_domain.root, schemas=entries)

    async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest | None:
        """Build the manifest for one schema version, or ``None`` if it does not exist."""
        stmt = select(schemas_table.c.fields).where(
            schemas_table.c.id == schema_id.id.root,
            schemas_table.c.version == schema_id.version.root,
        )
        result = await self.session.execute(stmt)
        row = result.mappings().first()
        if row is None:
            return None

        field_specs: list[FieldSpec] = []
        column_specs: list[ColumnSpec] = []
        for f in row["fields"]:
            ftype = FieldType(f["type"])
            field_specs.append(
                FieldSpec(
                    name=f["name"],
                    type=ftype,
                    ontology_id=f.get("ontology_id"),
                    ontology_version=f.get("ontology_version"),
                )
            )
            column_specs.append(ColumnSpec(name=f["name"], type=ftype))

        record_count = await self._records_count(schema_id)
        records_resource = TableResource(
            name="records",
            kind=TableKind.RECORDS,
            # Implicit columns (id, srn, schema_id, version, created_at) precede
            # the schema's declared metadata fields — this is the CSV header order.
            columns=[*IMPLICIT_RECORD_COLUMN_SPECS, *column_specs],
            row_count=record_count,
            formats=list(_ALL_FORMATS),
        )
        feature_resources = await self._feature_resources(schema_id)
        return SchemaManifest(
            id=schema_id.id.root,
            version=schema_id.version.root,
            srn=schema_id.to_srn(self.node_domain).render(),
            fields=field_specs,
            table_resources=[records_resource, *feature_resources],
        )

    async def _feature_resources(self, schema_id: SchemaId) -> list[TableResource]:
        """Build a TableResource for each feature table registered on the schema."""
        resources: list[TableResource] = []
        for hook_name, fschema in await self._features.feature_tables(schema_id):
            ft = build_feature_table(hook_name, fschema)
            count = await self._features.count_rows(ft, schema_id)
            resources.append(
                TableResource(
                    name=hook_name,
                    kind=TableKind.FEATURE,
                    # Implicit columns (id, record_srn, created_at) precede the
                    # hook's declared data columns — this is the CSV header order.
                    columns=[*IMPLICIT_FEATURE_COLUMN_SPECS, *self._feature_column_specs(fschema)],
                    row_count=count,
                    formats=list(_ALL_FORMATS),
                )
            )
        return resources

    async def get_latest_schema_id(self, schema_short_id: str) -> SchemaId | None:
        """Resolve a short schema id to its highest version by SemVer precedence.

        Returns:
            The ``SchemaId`` of the highest version, or ``None`` when no schema
            row has this short id.
        """
        stmt = select(schemas_table.c.version).where(schemas_table.c.id == schema_short_id)
        result = await self.session.execute(stmt)
        versions = [row[0] for row in result.all()]
        if not versions:
            return None
        # String sort is wrong for e.g. 1.10.0 vs 1.9.0, and comparing only the
        # numeric core would let a pre-release tie with its release.
        latest = max(versions, key=self._semver_sort_key)
        return SchemaId.parse(f"{schema_short_id}@{latest}")

    @staticmethod
    def _semver_sort_key(
        version: str,
    ) -> tuple[tuple[int, ...], int, tuple[tuple[int, int, str], ...]]:
        """Return a key that orders version strings by SemVer precedence.

        Build metadata (after ``+``) is ignored. A release outranks any
        pre-release of the same core. Pre-release identifiers compare
        left to right, numeric ones numerically and below alphanumeric ones.

        Raises:
            ValueError: If a dot-separated part of the core is not an integer.
        """
        without_build = version.split("+", 1)[0]
        core, _, prerelease = without_build.partition("-")
        numbers = tuple(int(part) for part in core.split("."))
        if not prerelease:
            return (numbers, 1, ())
        identifiers = tuple(
            (0, int(ident), "") if ident.isascii() and ident.isdigit() else (1, 0, ident)
            for ident in prerelease.split(".")
        )
        return (numbers, 0, identifiers)

    async def _records_count(self, schema_id: SchemaId) -> int:
        """Count the records stored for exactly this schema id and version."""
        t = records_table
        stmt = (
            select(func.count())
            .select_from(t)
            .where(
                t.c.schema_id == schema_id.id.root,
                t.c.schema_version == schema_id.version.root,
            )
        )
        return int((await self.session.execute(stmt)).scalar_one())

    @staticmethod
    def _feature_column_specs(fschema: FeatureSchema) -> list[ColumnSpec]:
        """Map a feature table's declared columns to manifest ColumnSpecs."""
        return [
            ColumnSpec(name=c.name, type=_JSON_TYPE_TO_FIELD_TYPE[c.json_type])
            for c in fschema.columns
        ]
class PostgresTableReadStore:
    """Streaming Postgres reads of the records table and feature tables.

    ``stream_rows`` is the entry point: it compiles a ``QueryPlan`` (filter,
    keyset sort, cursor) into a SELECT and yields rows one at a time from
    ``AsyncSession.stream()``.
    """

    def __init__(self, session: AsyncSession) -> None:
        self.session = session
        self._features = SchemaFeatureReader(session)

    @staticmethod
    def statement_timeout_sql(timeout: timedelta) -> str:
        """Render the caller's execution budget as a ``SET LOCAL`` statement.

        Integer milliseconds, so no operator- or caller-supplied string ever
        reaches raw SQL. ``SET LOCAL`` reverts on commit/rollback, so the value
        never leaks to a later request reusing the pooled connection.

        The value is clamped to at least 1 ms: PostgreSQL treats
        ``statement_timeout = 0`` as "no timeout", so a zero or sub-millisecond
        budget must not silently disable the limit, and a negative budget must
        not produce an invalid setting.
        """
        millis = max(1, int(timeout.total_seconds() * 1000))
        return f"SET LOCAL statement_timeout = '{millis}ms'"

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape the LIKE metacharacters (backslash, %, _) so ``value`` matches literally."""
        return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def _invalid_cursor(exc: Exception) -> ValidationError:
        """Map a cursor decode/coerce failure to a 400, not a 500.

        Decoding a corrupt cursor or coercing a non-conforming value raises an
        exception that is not an ``OSAError`` and would otherwise surface as a
        generic 500. The cursor is opaque client-supplied input, so a
        malformed one is a 400.
        """
        return ValidationError(
            f"Malformed pagination cursor: {exc}", field="cursor", code="invalid_cursor"
        )

    # ------------------------------------------------------------------ #
    # Streaming primitive
    # ------------------------------------------------------------------ #

    async def stream_rows(
        self, plan: QueryPlan, timeout: timedelta | None = None
    ) -> AsyncIterator[Mapping[str, Any]]:
        """Yield the plan's rows as flattened column→value mappings.

        Args:
            plan: The compiled query (table kind, filter, sort, cursor).
            timeout: Optional statement budget applied before the SELECT runs.
        """
        # The generator body runs on the route's pre-flight ``__anext__`` pull,
        # so the timeout is in place before the SELECT opens its cursor.
        if timeout is not None:
            await self.session.execute(sa.text(self.statement_timeout_sql(timeout)))
        if plan.table_kind == TableKind.RECORDS:
            async for row in self._stream_records(plan):
                yield row
        else:  # TableKind.FEATURE
            async for row in self._stream_features(plan):
                yield row

    async def _stream_records(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]:
        """Stream records joined with their metadata table for ``plan.schema_id``.

        Raises:
            NotFoundError: If no metadata catalog row exists for the schema.
        """
        catalog = await self._metadata_catalog_for(plan.schema_id)
        if catalog is None:
            raise NotFoundError(
                f"Schema not found: {plan.schema_id.render()}. See /api/v1/data for the catalog."
            )
        metadata_schema = MetadataSchema.model_validate(catalog["metadata_schema"])
        metadata_table = build_metadata_table(catalog["pg_table"], metadata_schema)

        t = records_table
        conditions: list[Any] = [
            t.c.schema_id == plan.schema_id.id.root,
            t.c.schema_version == plan.schema_id.version.root,
        ]
        if plan.filter is not None:
            conditions.append(self._compile_filter(plan.filter, metadata_t=metadata_table))

        order_keys, cursor_after = self._records_sort(plan, metadata_table)
        if cursor_after is not None:
            conditions.append(cursor_after)

        select_cols = [
            t.c.srn,
            t.c.schema_id,
            t.c.schema_version,
            t.c.published_at,
        ] + [metadata_table.c[c.name].label(c.name) for c in metadata_schema.columns]
        stmt = (
            select(*select_cols)
            .select_from(t.join(metadata_table, metadata_table.c.record_srn == t.c.srn))
            .where(and_(*conditions))
            .order_by(*order_keys)
        )

        col_names = [c.name for c in metadata_schema.columns]
        # ``stream()`` opens a server-side cursor. The try/finally closes it on
        # client disconnect (the generator is thrown a CancelledError), returning
        # the connection to the pool (research §2).
        result = await self.session.stream(stmt)
        try:
            async for row in result.mappings():
                yield self._records_row_to_mapping(row, col_names)
        finally:
            await result.close()

    def _records_row_to_mapping(self, row: RowMapping, col_names: list[str]) -> dict[str, Any]:
        """Convert one joined row into a flattened ``RecordSummary`` mapping."""
        srn = RecordSRN.parse(row["srn"])
        summary = RecordSummary(
            id=RecordId(srn.id.root),
            srn=srn,
            schema_id=SchemaId.parse(f"{row['schema_id']}@{row['schema_version']}"),
            version=int(srn.version.root),
            metadata={name: row[name] for name in col_names if name in row},
            created_at=row["published_at"],
        )
        return summary.flatten()

    def _records_sort(self, plan: QueryPlan, metadata_table: Any) -> tuple[list[Any], Any | None]:
        """Return ``(order_by keys, keyset-after condition or None)`` for records.

        Raises:
            ValidationError: If the plan's sort column is not a known column.
        """
        t = records_table
        # plan.keyset owns tiebreak choice and sort=id aliasing; this method
        # only maps its column names onto SQL expressions.
        keyset = plan.keyset
        is_desc = plan.sort[0].direction == SortDirection.DESC
        tiebreak_expr = t.c[keyset.tiebreak_column]
        if keyset.sort_column == keyset.tiebreak_column:
            sort_expr: sa.ColumnElement[Any] = tiebreak_expr
        elif keyset.sort_column in ("created_at", "published_at"):
            sort_expr = t.c.published_at
        elif keyset.sort_column in metadata_table.c:
            sort_expr = metadata_table.c[keyset.sort_column]
        else:
            raise ValidationError(
                f"Unknown sort column '{keyset.sort_column}'.",
                field="sort",
                code="unknown_sort_field",
            )
        page = KeysetPage(
            [
                SortKey(sort_expr, descending=is_desc, nulls_last=True),
                SortKey(tiebreak_expr, descending=is_desc),
            ]
        )
        return page.order_by(), self._cursor_after(plan, page, sort_expr, tiebreak_expr)

    def _cursor_after(
        self,
        plan: QueryPlan,
        page: KeysetPage,
        sort_expr: sa.ColumnElement[Any],
        id_expr: sa.ColumnElement[Any],
    ) -> Any | None:
        """Build the keyset ``after`` condition from the plan's opaque cursor.

        Shared by the records and feature sorts: both cursor components are
        coerced to their column's Python type. Any decode or coercion failure
        becomes a 400, because the cursor is client-supplied input.

        Returns:
            The ``after`` condition, or ``None`` when the plan has no cursor.

        Raises:
            ValidationError: If the cursor is malformed or carries a value of
                the wrong shape or type.
        """
        if plan.pagination.cursor is None:
            return None
        try:
            decoded = decode_cursor(str(plan.pagination.cursor))
            return page.after(
                (
                    self._coerce_cursor_value(decoded["s"], sort_expr),
                    self._coerce_cursor_value(decoded["id"], id_expr),
                )
            )
        # TypeError covers e.g. ``int([1])`` on an integer column; KeyError a
        # payload missing ``s``/``id``. Neither may escape as a generic 500.
        except (KeyError, TypeError, ValueError) as exc:
            raise self._invalid_cursor(exc) from exc

    @staticmethod
    def _coerce_cursor_value(value: Any, sort_expr: sa.ColumnElement[Any]) -> Any:
        """Coerce a decoded cursor value to the bound column's Python type.

        Cursors carry date/datetime values as ISO strings (``encode_cursor``
        renders with ``default=str``), but asyncpg rejects mistyped binds —
        str vs DATE / TIMESTAMP (including dynamic metadata columns of
        FieldType.DATE), str vs BIGINT — so dispatch on the column type, not
        the column name.
        """
        if value is None:
            return None
        col_type = sort_expr.type
        if isinstance(col_type, sa.DateTime) and isinstance(value, str):
            return datetime.fromisoformat(value)
        if isinstance(col_type, sa.Date) and isinstance(value, str):
            return date.fromisoformat(value)
        if isinstance(col_type, sa.Integer):
            return int(value)
        return value

    # ------------------------------------------------------------------ #
    # Feature-table streaming (US5)
    # ------------------------------------------------------------------ #

    async def _stream_features(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]:
        """Stream one feature table's rows, scoped to ``plan.schema_id``'s records."""
        if plan.feature_name is None:  # guarded by QueryPlan, narrowed for the type checker
            raise ValidationError("feature_name is required for a FEATURE plan", field="feature")
        ft, _ = await self._resolve_feature_table(plan.schema_id, plan.feature_name)

        conditions: list[Any] = []
        if plan.filter is not None:
            conditions.append(
                self._compile_feature_filter(plan.filter, ft=ft, feature_name=plan.feature_name)
            )

        order_keys, cursor_after = self._features_sort(plan, ft)
        if cursor_after is not None:
            conditions.append(cursor_after)

        # A feature table is shared by every convention that registers the
        # hook name, across schemas — scope to the requested schema's records
        # via the records join.
        conditions.extend(self._features.records_scope(plan.schema_id))

        # Implicit columns (id, record_srn, created_at) precede the hook's
        # declared data columns — this is the CSV header order.
        select_cols = [ft.c.id, ft.c.record_srn, ft.c.created_at, *data_columns(ft)]
        stmt = (
            select(*select_cols)
            .select_from(ft.join(records_table, records_table.c.srn == ft.c.record_srn))
            .where(and_(*conditions))
            .order_by(*order_keys)
        )

        result = await self.session.stream(stmt)
        try:
            async for row in result.mappings():
                yield dict(row)
        finally:
            await result.close()

    async def _resolve_feature_table(
        self, schema_id: SchemaId, feature_name: str
    ) -> tuple[sa.Table, FeatureSchema]:
        """Resolve a feature table that belongs to ``schema_id``.

        A feature (hook) belongs to a schema via the convention that registers
        it. Streaming a feature not registered on the schema is a 404, not a
        leak of another schema's table.
        """
        for hook_name, fschema in await self._features.feature_tables(schema_id):
            if hook_name == feature_name:
                return build_feature_table(hook_name, fschema), fschema
        raise NotFoundError(
            f"No feature table '{feature_name}' on schema {schema_id.render()}. "
            f"See /api/v1/data/{schema_id.render()} for its table resources.",
            code="feature_not_found",
        )

    def _features_sort(self, plan: QueryPlan, ft: sa.Table) -> tuple[list[Any], Any | None]:
        """Return ``(order_by keys, keyset-after condition or None)`` for a feature table.

        Raises:
            ValidationError: If the plan's sort column is not a column of ``ft``.
        """
        # plan.keyset owns tiebreak choice and sort=id aliasing; this method
        # only maps its column names onto SQL expressions.
        keyset = plan.keyset
        is_desc = plan.sort[0].direction == SortDirection.DESC
        tiebreak_expr = ft.c[keyset.tiebreak_column]
        if keyset.sort_column == keyset.tiebreak_column:
            sort_expr: sa.ColumnElement[Any] = tiebreak_expr
        elif keyset.sort_column in ft.c:
            sort_expr = ft.c[keyset.sort_column]
        else:
            raise ValidationError(
                f"Unknown sort column '{keyset.sort_column}'.",
                field="sort",
                code="unknown_sort_field",
            )
        page = KeysetPage(
            [
                SortKey(sort_expr, descending=is_desc, nulls_last=True),
                SortKey(tiebreak_expr, descending=is_desc),
            ]
        )
        return page.order_by(), self._cursor_after(plan, page, sort_expr, tiebreak_expr)

    def _compile_feature_filter(self, expr: FilterExpr, *, ft: sa.Table, feature_name: str) -> Any:
        """Recursively compile a filter tree into a SQL condition on ``ft``."""
        if isinstance(expr, Predicate):
            return self._compile_feature_predicate(expr, ft=ft, feature_name=feature_name)
        if isinstance(expr, And):
            return and_(
                *[
                    self._compile_feature_filter(op, ft=ft, feature_name=feature_name)
                    for op in expr.operands
                ]
            )
        if isinstance(expr, Or):
            return or_(
                *[
                    self._compile_feature_filter(op, ft=ft, feature_name=feature_name)
                    for op in expr.operands
                ]
            )
        if isinstance(expr, Not):
            inner = self._compile_feature_filter(expr.operand, ft=ft, feature_name=feature_name)
            # NULL → FALSE before negating so rows with NULL values survive NOT.
            return not_(func.coalesce(inner, false()))
        raise ValidationError(f"Unsupported filter node: {type(expr).__name__}")

    def _compile_feature_predicate(
        self, predicate: Predicate, *, ft: sa.Table, feature_name: str
    ) -> Any:
        """Compile one predicate against ``ft``; only same-feature fields are allowed.

        Raises:
            ValidationError: For a cross-feature reference, an unknown column,
                or a metadata-field predicate.
        """
        if isinstance(predicate.field, FeatureFieldRef):
            if predicate.field.hook != feature_name:
                raise ValidationError(
                    f"Cross-feature predicates are not supported on the "
                    f"'{feature_name}' stream (referenced '{predicate.field.hook}').",
                    field=predicate.field.dotted(),
                    code="cross_feature_predicate_unsupported",
                )
            if predicate.field.column not in ft.c:
                raise ValidationError(
                    f"Unknown feature column '{predicate.field.column}'.",
                    field=predicate.field.dotted(),
                    code="unknown_feature_column",
                )
            return self._apply_scalar_op(
                ft.c[predicate.field.column], predicate.op, predicate.value
            )
        if isinstance(predicate.field, MetadataFieldRef):
            raise ValidationError(
                "Metadata-field predicates are not supported on a feature stream.",
                field=predicate.field.dotted(),
                code="metadata_predicate_unsupported",
            )
        raise TypeError(f"Unexpected field ref type: {type(predicate.field).__name__}")

    # ------------------------------------------------------------------ #
    # Records filter compilation
    # ------------------------------------------------------------------ #

    async def _metadata_catalog_for(self, schema_id: SchemaId) -> dict[str, Any] | None:
        """Fetch the metadata-table catalog row for the schema id and major version."""
        stmt = select(metadata_tables_table).where(
            metadata_tables_table.c.schema_id == schema_id.id.root,
            metadata_tables_table.c.schema_major == schema_id.major,
        )
        result = await self.session.execute(stmt)
        row = result.mappings().first()
        return dict(row) if row is not None else None

    def _compile_filter(self, expr: FilterExpr, *, metadata_t: Any) -> Any:
        """Recursively compile a filter tree into a SQL condition on the metadata table."""
        if isinstance(expr, Predicate):
            return self._compile_predicate(expr, metadata_t=metadata_t)
        if isinstance(expr, And):
            return and_(*[self._compile_filter(op, metadata_t=metadata_t) for op in expr.operands])
        if isinstance(expr, Or):
            return or_(*[self._compile_filter(op, metadata_t=metadata_t) for op in expr.operands])
        if isinstance(expr, Not):
            inner = self._compile_filter(expr.operand, metadata_t=metadata_t)
            # NULL → FALSE before negating so records with NULL metadata survive NOT.
            return not_(func.coalesce(inner, false()))
        raise ValidationError(f"Unsupported filter node: {type(expr).__name__}")

    def _compile_predicate(self, predicate: Predicate, *, metadata_t: Any) -> Any:
        """Compile one predicate against the metadata table.

        Raises:
            ValidationError: For an unknown metadata field or a feature-field
                predicate.
        """
        if isinstance(predicate.field, MetadataFieldRef):
            if predicate.field.field not in metadata_t.c:
                raise ValidationError(
                    f"Unknown metadata field '{predicate.field.field}'. "
                    "Filterable fields are listed in the schema manifest.",
                    field=predicate.field.dotted(),
                    code="unknown_metadata_field",
                )
            col = metadata_t.c[predicate.field.field]
            return self._apply_scalar_op(col, predicate.op, predicate.value)
        if isinstance(predicate.field, FeatureFieldRef):
            # Feature predicates in the records stream are a US5 extension.
            raise ValidationError(
                "Feature-field predicates are not yet supported on the records stream.",
                field=predicate.field.dotted(),
                code="feature_predicate_unsupported",
            )
        raise TypeError(f"Unexpected field ref type: {type(predicate.field).__name__}")

    @staticmethod
    def _apply_scalar_op(col: Any, op: FilterOperator, value: Any) -> Any:
        """Translate one ``(column, operator, value)`` triple into a SQL expression.

        Raises:
            ValidationError: If ``in`` is given a non-list value or the
                operator is not supported.
        """
        if op == FilterOperator.EQ:
            return col == value
        if op == FilterOperator.NEQ:
            # Include NULLs: a missing value is "not equal" to any literal.
            return or_(col != value, col.is_(None))
        if op == FilterOperator.GT:
            return col > value
        if op == FilterOperator.GTE:
            return col >= value
        if op == FilterOperator.LT:
            return col < value
        if op == FilterOperator.LTE:
            return col <= value
        if op == FilterOperator.IN:
            if not isinstance(value, list):
                raise ValidationError(
                    "Operator 'in' requires a list value.",
                    field=col.key,
                    code="invalid_value_for_op",
                )
            return col.in_(value)
        if op == FilterOperator.CONTAINS:
            # Case-insensitive substring match on the text form of the column.
            return cast(col, String).ilike(
                f"%{PostgresTableReadStore._escape_like(str(value))}%", escape="\\"
            )
        if op == FilterOperator.IS_NULL:
            return col.is_(None)
        raise ValidationError(
            f"Unsupported operator: {op}", field="filter", code="unsupported_operator"
        )
class SchemaFeatureReader:
    """Resolve and count the feature tables a schema exposes via its conventions."""

    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    async def feature_tables(self, schema_id: SchemaId) -> list[tuple[str, FeatureSchema]]:
        """Return ``(hook_name, FeatureSchema)`` for each feature table on the schema.

        Only hook names that also have a row in the feature-tables catalog are
        returned; a schema with no registered hooks yields an empty list.
        """
        registered = await self._hook_names(schema_id)
        if not registered:
            return []
        catalog = feature_tables_table
        query = select(catalog.c.hook_name, catalog.c.feature_schema).where(
            catalog.c.hook_name.in_(registered)
        )
        outcome = await self.session.execute(query)
        tables: list[tuple[str, FeatureSchema]] = []
        for entry in outcome.mappings():
            parsed = FeatureSchema.model_validate(entry["feature_schema"])
            tables.append((entry["hook_name"], parsed))
        return tables

    async def count_rows(self, ft: sa.Table, schema_id: SchemaId) -> int:
        """Count the rows of ``ft`` whose record belongs to ``schema_id``."""
        joined = ft.join(records_table, records_table.c.srn == ft.c.record_srn)
        scope = and_(*self.records_scope(schema_id))
        query = select(func.count()).select_from(joined).where(scope)
        outcome = await self.session.execute(query)
        return int(outcome.scalar_one())

    @staticmethod
    def records_scope(schema_id: SchemaId) -> list[Any]:
        """Return the records-table conditions pinning a query to one schema version."""
        same_schema = records_table.c.schema_id == schema_id.id.root
        same_version = records_table.c.schema_version == schema_id.version.root
        return [same_schema, same_version]

    async def _hook_names(self, schema_id: SchemaId) -> set[str]:
        """Collect the hook names declared by the conventions of this schema version."""
        query = select(conventions_table.c.hooks).where(
            conventions_table.c.schema_id == schema_id.id.root,
            conventions_table.c.schema_version == schema_id.version.root,
        )
        outcome = await self.session.execute(query)
        # A convention's ``hooks`` value may be empty/None; treat it as no hooks.
        return {hook["name"] for (hooks,) in outcome.all() for hook in (hooks or [])}
a/server/osa/infrastructure/persistence/__init__.py +++ b/server/osa/infrastructure/persistence/__init__.py @@ -1,3 +1,8 @@ -from .di import PersistenceProvider +"""Persistence adapters package. -__all__ = ["PersistenceProvider"] +Intentionally does not re-export ``PersistenceProvider`` at the package root: +importing a leaf module (e.g. ``persistence.feature_table``) must not pull in +the whole DI graph, which imports adapters that in turn import persistence leaf +modules — a circular import. Import the provider from its module directly: +``from osa.infrastructure.persistence.di import PersistenceProvider``. +""" diff --git a/server/osa/infrastructure/persistence/adapter/discovery.py b/server/osa/infrastructure/persistence/adapter/discovery.py deleted file mode 100644 index be9d067..0000000 --- a/server/osa/infrastructure/persistence/adapter/discovery.py +++ /dev/null @@ -1,701 +0,0 @@ -"""Infrastructure adapters for the discovery domain — read-only SQL queries.""" - -from __future__ import annotations - -import logging -from collections.abc import Callable -from datetime import date, datetime -from typing import Any - -from sqlalchemy import ( - Date, - Float, - String, - and_, - cast, - false, - func, - literal, - not_, - or_, - select, - true, - union_all, -) -from sqlalchemy.ext.asyncio import AsyncSession - -from osa.domain.discovery.model.refs import FeatureFieldRef, MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - ColumnInfo, - FeatureCatalogEntry, - FeatureRow, - FilterExpr, - FilterOperator, - Not, - Or, - Predicate, - RecordSummary, - SortOrder, -) -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.hook import ColumnDef -from osa.domain.shared.model.srn import ConventionSRN, RecordSRN, SchemaId -from osa.infrastructure.persistence.feature_table import ( - FeatureSchema, - build_feature_table, - data_columns, -) -from 
osa.infrastructure.persistence.keyset import KeysetPage, SortKey -from osa.infrastructure.persistence.metadata_table import ( - MetadataSchema, - build_metadata_table, -) -from osa.infrastructure.persistence.tables import ( - feature_tables_table, - metadata_tables_table, - records_table, - schemas_table, -) - -logger = logging.getLogger(__name__) - - -def _escape_like(value: str) -> str: - """Escape LIKE metacharacters so user input is matched literally.""" - return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - - -# Cursor-value coercers — cursor payloads round-trip through base64 JSON as -# plain strings/numbers, but keyset predicates compare against typed columns. -# Without this, ``published_at < 'iso-string'::VARCHAR`` fails on Postgres. - -CursorCoercer = Callable[[Any], Any] - - -def _coerce_identity(value: Any) -> Any: - return value - - -def _coerce_datetime(value: Any) -> Any: - if isinstance(value, str): - return datetime.fromisoformat(value) - return value - - -def _coerce_date(value: Any) -> Any: - if isinstance(value, str): - return date.fromisoformat(value) - return value - - -def _coerce_float(value: Any) -> Any: - return None if value is None else float(value) - - -def _coerce_int(value: Any) -> Any: - return None if value is None else int(value) - - -def _coercer_for_column(col_def: ColumnDef) -> CursorCoercer: - """Pick a coercer matching the Postgres type chosen by ``column_mapper``.""" - if col_def.json_type == "number": - return _coerce_float - if col_def.json_type == "integer": - return _coerce_int - if col_def.json_type == "string": - if col_def.format == "date-time": - return _coerce_datetime - if col_def.format == "date": - return _coerce_date - return _coerce_identity - - -def _to_column_info(columns: list[Any]) -> list[ColumnInfo]: - return [ColumnInfo(name=c.name, type=c.json_type, required=c.required) for c in columns] - - -class PostgresFieldDefinitionReader: - """Builds field name → FieldType maps from 
registered schemas.""" - - def __init__(self, session: AsyncSession) -> None: - self.session = session - - async def get_all_field_types(self) -> dict[str, FieldType]: - stmt = select(schemas_table.c.fields) - result = await self.session.execute(stmt) - rows = result.mappings().all() - - field_map: dict[str, FieldType] = {} - for row in rows: - for field_def in row["fields"]: - name = field_def["name"] - field_type = FieldType(field_def["type"]) - if name in field_map and field_map[name] != field_type: - raise ValidationError( - f"Conflicting types for field '{name}': " - f"'{field_map[name]}' vs '{field_type}'", - field=name, - ) - field_map[name] = field_type - - return field_map - - async def get_fields_for_schema(self, schema_id: SchemaId) -> dict[str, FieldType]: - stmt = select(schemas_table.c.fields).where( - schemas_table.c.id == schema_id.id.root, - schemas_table.c.version == schema_id.version.root, - ) - result = await self.session.execute(stmt) - row = result.mappings().first() - if row is None: - return {} - return {f["name"]: FieldType(f["type"]) for f in row["fields"]} - - -class PostgresDiscoveryReadStore: - """Compiles FilterExpr trees into SQLAlchemy queries over records / metadata / features.""" - - def __init__(self, session: AsyncSession) -> None: - self.session = session - - async def search_records( - self, - filter_expr: FilterExpr | None, - schema_id: SchemaId | None, - convention_srn: ConventionSRN | None, - text_fields: list[str], - q: str | None, - sort: str, - order: SortOrder, - cursor: dict[str, Any] | None, - limit: int, - field_types: dict[str, FieldType] | None = None, - ) -> list[RecordSummary]: - t = records_table - ft_map = field_types or {} - - metadata_table = None - metadata_schema: MetadataSchema | None = None - if schema_id is not None: - catalog = await self._metadata_catalog_for(schema_id) - if catalog is not None: - metadata_schema = MetadataSchema.model_validate(catalog["metadata_schema"]) - metadata_table = 
build_metadata_table(catalog["pg_table"], metadata_schema) - - feature_joins = await self._collect_feature_joins(filter_expr) - - conditions: list[Any] = [] - - if convention_srn is not None: - conditions.append(t.c.convention_srn == str(convention_srn)) - - if filter_expr is not None: - conditions.append( - self._compile_filter_for_records( - filter_expr, - records_t=t, - metadata_t=metadata_table, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - ) - - if q and text_fields and metadata_table is not None and metadata_schema is not None: - pattern = f"%{_escape_like(q)}%" - text_col_names = {c.name for c in metadata_schema.columns if c.json_type == "string"} - text_clauses = [ - cast(metadata_table.c[name], String).ilike(pattern, escape="\\") - for name in text_fields - if name in text_col_names - ] - if text_clauses: - conditions.append(or_(*text_clauses)) - - # Sort expression + matching cursor-value coercer - if sort == "published_at": - sort_expr = t.c.published_at - coerce_cursor: CursorCoercer = _coerce_datetime - elif metadata_table is not None and sort in metadata_table.c: - col = metadata_table.c[sort] - if ft_map.get(sort) == FieldType.NUMBER: - sort_expr = cast(col, Float) - coerce_cursor = _coerce_float - elif ft_map.get(sort) == FieldType.DATE: - sort_expr = cast(col, Date) - coerce_cursor = _coerce_date - else: - sort_expr = col - coerce_cursor = _coerce_identity - else: - sort_expr = t.c.published_at - coerce_cursor = _coerce_datetime - - is_desc = order == SortOrder.DESC - page = KeysetPage( - [ - SortKey(sort_expr, descending=is_desc, nulls_last=True), - SortKey(t.c.srn, descending=is_desc), - ] - ) - order_clauses = page.order_by() - if cursor is not None: - sort_value = coerce_cursor(cursor["s"]) - conditions.append(page.after((sort_value, cursor["id"]))) - - where_clause = and_(*conditions) if conditions else true() - - if metadata_table is not None and metadata_schema is not None: - select_cols = [t.c.srn, 
t.c.published_at] + [ - metadata_table.c[c.name].label(c.name) for c in metadata_schema.columns - ] - stmt = select(*select_cols).select_from( - t.join(metadata_table, metadata_table.c.record_srn == t.c.srn) - ) - else: - # No schema pinned — project the canonical JSONB metadata column. - # Typed tables are a query-optimized projection; JSONB remains the - # authoritative source for presentation (and for cross-schema - # listings where no single typed table applies). - stmt = select(t.c.srn, t.c.published_at, t.c.metadata) - - for hook, ft in feature_joins.items(): - stmt = stmt.join(ft, ft.c.record_srn == t.c.srn, isouter=True) - - stmt = stmt.where(where_clause).order_by(*order_clauses).limit(limit) - - result = await self.session.execute(stmt) - summaries: list[RecordSummary] = [] - if metadata_table is not None and metadata_schema is not None: - for row in result.mappings(): - meta = {c.name: row[c.name] for c in metadata_schema.columns if c.name in row} - summaries.append( - RecordSummary( - srn=RecordSRN.parse(row["srn"]), - published_at=row["published_at"], - metadata=meta, - ) - ) - else: - for row in result.mappings(): - summaries.append( - RecordSummary( - srn=RecordSRN.parse(row["srn"]), - published_at=row["published_at"], - metadata=row.get("metadata") or {}, - ) - ) - return summaries - - async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: - stmt = select( - feature_tables_table.c.hook_name, - feature_tables_table.c.pg_table, - feature_tables_table.c.feature_schema, - ) - result = await self.session.execute(stmt) - catalog_rows = result.mappings().all() - - if not catalog_rows: - return [] - - parsed = [ - (row["hook_name"], FeatureSchema.model_validate(row["feature_schema"])) - for row in catalog_rows - ] - - count_parts = [] - for hook_name, schema in parsed: - ft = build_feature_table(hook_name, schema) - count_parts.append( - select( - literal(hook_name).label("hook_name"), - func.count(func.distinct(ft.c.record_srn)).label("cnt"), - 
).select_from(ft) - ) - counts_result = await self.session.execute(union_all(*count_parts)) - counts_by_hook = {r["hook_name"]: r["cnt"] for r in counts_result.mappings()} - - return [ - FeatureCatalogEntry( - hook_name=hook_name, - columns=_to_column_info(schema.columns), - record_count=counts_by_hook.get(hook_name, 0), - ) - for hook_name, schema in parsed - ] - - async def get_feature_table_schema(self, hook_name: str) -> FeatureCatalogEntry | None: - stmt = select( - feature_tables_table.c.hook_name, - feature_tables_table.c.feature_schema, - ).where(feature_tables_table.c.hook_name == hook_name) - result = await self.session.execute(stmt) - row = result.mappings().first() - if row is None: - return None - - schema = FeatureSchema.model_validate(row["feature_schema"]) - return FeatureCatalogEntry( - hook_name=row["hook_name"], - columns=_to_column_info(schema.columns), - record_count=0, - ) - - async def search_features( - self, - hook_name: str, - filter_expr: FilterExpr | None, - schema_id: SchemaId | None, - record_srn: RecordSRN | None, - sort: str, - order: SortOrder, - cursor: dict[str, Any] | None, - limit: int, - ) -> list[FeatureRow]: - pg_table_stmt = select( - feature_tables_table.c.feature_schema, - ).where(feature_tables_table.c.hook_name == hook_name) - pg_result = await self.session.execute(pg_table_stmt) - pg_row = pg_result.mappings().first() - if pg_row is None: - return [] - schema = FeatureSchema.model_validate(pg_row["feature_schema"]) - - ft = build_feature_table(hook_name, schema) - - metadata_table = None - metadata_schema: MetadataSchema | None = None - if schema_id is not None: - catalog = await self._metadata_catalog_for(schema_id) - if catalog is not None: - metadata_schema = MetadataSchema.model_validate(catalog["metadata_schema"]) - metadata_table = build_metadata_table(catalog["pg_table"], metadata_schema) - - feature_joins: dict[str, Any] = {} - if filter_expr is not None: - extra = await self._collect_feature_joins(filter_expr) 
- for hook, tbl in extra.items(): - if hook != hook_name: - feature_joins[hook] = tbl - - conditions: list[Any] = [] - - if record_srn is not None: - conditions.append(ft.c.record_srn == str(record_srn)) - - if filter_expr is not None: - conditions.append( - self._compile_filter_for_features( - filter_expr, - this_hook=hook_name, - this_ft=ft, - metadata_t=metadata_table, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - ) - - if sort == "id": - sort_expr = ft.c.id - coerce_cursor: CursorCoercer = _coerce_int - else: - sort_expr = ft.c[sort] - col_def = next((c for c in schema.columns if c.name == sort), None) - coerce_cursor = ( - _coercer_for_column(col_def) if col_def is not None else _coerce_identity - ) - - is_desc = order == SortOrder.DESC - page = KeysetPage( - [ - SortKey(sort_expr, descending=is_desc, nulls_last=True), - SortKey(ft.c.id, descending=is_desc), - ] - ) - order_clauses = page.order_by() - if cursor is not None: - sort_value = coerce_cursor(cursor["s"]) - conditions.append(page.after((sort_value, cursor["id"]))) - - where_clause = and_(*conditions) if conditions else true() - - stmt = select(ft.c.id, ft.c.record_srn, *data_columns(ft)) - select_from = ft - if metadata_table is not None: - select_from = select_from.join( - metadata_table, metadata_table.c.record_srn == ft.c.record_srn, isouter=True - ) - for hook, other_ft in feature_joins.items(): - select_from = select_from.join( - other_ft, other_ft.c.record_srn == ft.c.record_srn, isouter=True - ) - stmt = ( - stmt.select_from(select_from).where(where_clause).order_by(*order_clauses).limit(limit) - ) - - result = await self.session.execute(stmt) - feature_rows: list[FeatureRow] = [] - for row in result.mappings(): - row_dict = dict(row) - row_id = row_dict.pop("id") - rsrn = RecordSRN.parse(row_dict.pop("record_srn")) - feature_rows.append(FeatureRow(row_id=row_id, record_srn=rsrn, data=row_dict)) - - return feature_rows - - # ---------------- compilation helpers 
---------------- - - async def _metadata_catalog_for(self, schema_id: SchemaId) -> dict[str, Any] | None: - """Look up the metadata table catalog row for a SchemaId.""" - stmt = select(metadata_tables_table).where( - metadata_tables_table.c.schema_id == schema_id.id.root, - metadata_tables_table.c.schema_major == schema_id.major, - ) - result = await self.session.execute(stmt) - row = result.mappings().first() - return dict(row) if row is not None else None - - async def _collect_feature_joins(self, filter_expr: FilterExpr | None) -> dict[str, Any]: - """Build {hook_name: SQLA Table} for every distinct feature ref in the tree.""" - if filter_expr is None: - return {} - hooks: set[str] = set() - for p in _iter_predicates(filter_expr): - if isinstance(p.field, FeatureFieldRef): - hooks.add(p.field.hook) - if not hooks: - return {} - stmt = select( - feature_tables_table.c.hook_name, - feature_tables_table.c.feature_schema, - ).where(feature_tables_table.c.hook_name.in_(hooks)) - result = await self.session.execute(stmt) - joins: dict[str, Any] = {} - for row in result.mappings(): - schema = FeatureSchema.model_validate(row["feature_schema"]) - joins[row["hook_name"]] = build_feature_table(row["hook_name"], schema) - missing = hooks - joins.keys() - if missing: - raise ValidationError( - f"Unknown feature hook(s): {sorted(missing)}", - field="filter", - code="unknown_hook", - ) - return joins - - def _compile_filter_for_records( - self, - expr: FilterExpr, - *, - records_t: Any, - metadata_t: Any, - metadata_schema: MetadataSchema | None, - feature_joins: dict[str, Any], - ) -> Any: - if isinstance(expr, Predicate): - return self._compile_predicate( - expr, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - if isinstance(expr, And): - return and_( - *[ - self._compile_filter_for_records( - op, - records_t=records_t, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - for op in 
expr.operands - ] - ) - if isinstance(expr, Or): - return or_( - *[ - self._compile_filter_for_records( - op, - records_t=records_t, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - for op in expr.operands - ] - ) - if isinstance(expr, Not): - inner = self._compile_filter_for_records( - expr.operand, - records_t=records_t, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - # Coalesce NULL → FALSE before negating so records with NULL - # feature/metadata values (including rows missing from outer- - # joined feature tables) survive a NOT predicate. Without this, - # ``NOT (score = 5)`` reads NULL for records with no score and - # three-valued logic silently drops them. - return not_(func.coalesce(inner, false())) - raise ValidationError(f"Unsupported filter node: {type(expr).__name__}") - - def _compile_filter_for_features( - self, - expr: FilterExpr, - *, - this_hook: str, - this_ft: Any, - metadata_t: Any, - metadata_schema: MetadataSchema | None, - feature_joins: dict[str, Any], - ) -> Any: - if isinstance(expr, Predicate): - if isinstance(expr.field, MetadataFieldRef): - if metadata_t is None: - raise ValidationError( - f"Metadata ref {expr.field.dotted()!r} requires schema_id to be set.", - field=expr.field.dotted(), - code="metadata_ref_requires_schema", - ) - col = metadata_t.c[expr.field.field] - return _apply_scalar_op(col, expr.op, expr.value) - if not isinstance(expr.field, FeatureFieldRef): - raise TypeError(f"Unexpected field ref type: {type(expr.field).__name__}") - if expr.field.hook == this_hook: - col = this_ft.c[expr.field.column] - else: - tbl = feature_joins.get(expr.field.hook) - if tbl is None: - raise ValidationError( - f"Unknown feature hook '{expr.field.hook}'.", - field=expr.field.dotted(), - code="unknown_hook", - ) - col = tbl.c[expr.field.column] - return _apply_scalar_op(col, expr.op, expr.value) - if isinstance(expr, And): - return and_( - *[ - 
self._compile_filter_for_features( - op, - this_hook=this_hook, - this_ft=this_ft, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - for op in expr.operands - ] - ) - if isinstance(expr, Or): - return or_( - *[ - self._compile_filter_for_features( - op, - this_hook=this_hook, - this_ft=this_ft, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - for op in expr.operands - ] - ) - if isinstance(expr, Not): - inner = self._compile_filter_for_features( - expr.operand, - this_hook=this_hook, - this_ft=this_ft, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - # See ``_compile_filter_for_records`` — NULL → FALSE coalesce so - # NOT over outer-joined feature / optional metadata columns - # includes records with missing values. - return not_(func.coalesce(inner, false())) - raise ValidationError(f"Unsupported filter node: {type(expr).__name__}") - - def _compile_predicate( - self, - predicate: Predicate, - *, - metadata_t: Any, - metadata_schema: MetadataSchema | None, - feature_joins: dict[str, Any], - ) -> Any: - if isinstance(predicate.field, MetadataFieldRef): - if metadata_t is None or metadata_schema is None: - raise ValidationError( - f"Metadata predicate on {predicate.field.dotted()!r} requires " - "the request to pin a 'schema' ('@'). 
" - "Unscoped metadata filtering is not supported — the typed table " - "is the only filter path.", - field=predicate.field.dotted(), - code="schema_required_for_metadata_query", - ) - col = metadata_t.c[predicate.field.field] - return _apply_scalar_op(col, predicate.op, predicate.value) - - if not isinstance(predicate.field, FeatureFieldRef): - raise TypeError(f"Unexpected field ref type: {type(predicate.field).__name__}") - tbl = feature_joins.get(predicate.field.hook) - if tbl is None: - raise ValidationError( - f"Unknown feature hook '{predicate.field.hook}'.", - field=predicate.field.dotted(), - code="unknown_hook", - ) - col = tbl.c[predicate.field.column] - return _apply_scalar_op(col, predicate.op, predicate.value) - - -def _apply_scalar_op(col: Any, op: FilterOperator, value: Any) -> Any: - if op == FilterOperator.EQ: - return col == value - if op == FilterOperator.NEQ: - # Feature tables are outer-joined, so a missing feature row makes - # ``col`` NULL. Plain ``col != value`` yields NULL (falsy) and - # silently excludes those records from the result. Users reading - # ``!= X`` expect "anything except X, including missing", so treat - # NULL as non-equal explicitly. 
- return or_(col != value, col.is_(None)) - if op == FilterOperator.GT: - return col > value - if op == FilterOperator.GTE: - return col >= value - if op == FilterOperator.LT: - return col < value - if op == FilterOperator.LTE: - return col <= value - if op == FilterOperator.IN: - if not isinstance(value, list): - raise ValidationError( - "Operator 'in' requires a list value.", - field=col.key, - code="invalid_value_for_op", - ) - return col.in_(value) - if op == FilterOperator.CONTAINS: - return cast(col, String).ilike(f"%{_escape_like(str(value))}%", escape="\\") - if op == FilterOperator.IS_NULL: - return col.is_(None) - raise ValidationError( - f"Unsupported operator: {op}", field="filter", code="unsupported_operator" - ) - - -def _iter_predicates(expr: FilterExpr): - if isinstance(expr, Predicate): - yield expr - return - if isinstance(expr, Not): - yield from _iter_predicates(expr.operand) - return - if isinstance(expr, (And, Or)): - for op in expr.operands: - yield from _iter_predicates(op) diff --git a/server/osa/infrastructure/persistence/di.py b/server/osa/infrastructure/persistence/di.py index 12873e2..1c8eacd 100644 --- a/server/osa/infrastructure/persistence/di.py +++ b/server/osa/infrastructure/persistence/di.py @@ -16,6 +16,7 @@ from osa.domain.record.port.feature_reader import FeatureReader from osa.domain.record.port.repository import RecordRepository from osa.domain.record.query.get_record import GetRecordHandler +from osa.domain.record.query.get_stats import GetStatsHandler from osa.domain.record.service import RecordService from osa.infrastructure.persistence.adapter.feature_reader import PostgresFeatureReader from osa.domain.feature.port.storage import FeatureStoragePort @@ -27,12 +28,12 @@ from osa.domain.shared.port.event_repository import EventRepository from osa.domain.feature.port.feature_store import FeatureStore from osa.domain.validation.port.repository import ValidationRunRepository -from 
osa.domain.discovery.port.field_definition_reader import FieldDefinitionReader -from osa.domain.discovery.port.read_store import DiscoveryReadStore -from osa.infrastructure.persistence.adapter.discovery import ( - PostgresDiscoveryReadStore, - PostgresFieldDefinitionReader, +from osa.domain.data.port.data_read_store import ( + DataCatalogReadStore, + DataTableReadStore, ) +from osa.infrastructure.data.postgres_catalog_read_store import PostgresCatalogReadStore +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore from osa.infrastructure.persistence.adapter.readers import ( OntologyReaderAdapter, SchemaReaderAdapter, @@ -173,13 +174,17 @@ def get_record_service( feature_reader=feature_reader, ) - # Discovery adapters - discovery_read_store = provide( - PostgresDiscoveryReadStore, scope=Scope.UOW, provides=DiscoveryReadStore - ) - field_definition_reader = provide( - PostgresFieldDefinitionReader, scope=Scope.UOW, provides=FieldDefinitionReader - ) + # Data read surface adapter (unified /data/* engine) + @provide(scope=Scope.UOW, provides=DataTableReadStore) + def get_data_table_read_store(self, session: AsyncSession) -> PostgresTableReadStore: + return PostgresTableReadStore(session=session) + + @provide(scope=Scope.UOW, provides=DataCatalogReadStore) + def get_data_catalog_read_store( + self, session: AsyncSession, config: Config + ) -> PostgresCatalogReadStore: + return PostgresCatalogReadStore(session=session, node_domain=Domain(config.domain)) # Record query handlers get_record_handler = provide(GetRecordHandler, scope=Scope.UOW) + get_stats_handler = provide(GetStatsHandler, scope=Scope.UOW) diff --git a/server/osa/sdk/index/__init__.py b/server/osa/sdk/index/__init__.py deleted file mode 100644 index 8c862ca..0000000 --- a/server/osa/sdk/index/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Index SDK - Protocols and types for pluggable storage backends.""" - -from osa.sdk.index.backend import StorageBackend -from 
osa.sdk.index.config import BackendConfig -from osa.sdk.index.result import QueryResult, SearchHit - -__all__ = [ - "StorageBackend", - "BackendConfig", - "QueryResult", - "SearchHit", -] diff --git a/server/osa/sdk/index/backend.py b/server/osa/sdk/index/backend.py deleted file mode 100644 index 5b8e029..0000000 --- a/server/osa/sdk/index/backend.py +++ /dev/null @@ -1,91 +0,0 @@ -"""StorageBackend protocol for pluggable index backends.""" - -from typing import Any, Protocol - -from osa.sdk.index.result import QueryResult - - -class StorageBackend(Protocol): - """Protocol for pluggable index storage backends. - - Implement this protocol to create custom storage backends - (e.g., vector search, keyword search, graph databases). - """ - - @property - def name(self) -> str: - """Unique name for this index instance.""" - ... - - async def ingest(self, srn: str, record: dict[str, Any]) -> None: - """Store a record in the index. - - Args: - srn: Structured Resource Name identifying the record. - record: The record metadata to index. - """ - ... - - async def delete(self, srn: str) -> None: - """Remove a record from the index. - - Args: - srn: Structured Resource Name of the record to remove. - """ - ... - - async def query(self, q: str, limit: int = 20) -> QueryResult: - """Execute a query and return structured results. - - Args: - q: The query string. - limit: Maximum number of results to return. - - Returns: - QueryResult containing matching hits. - """ - ... - - async def health(self) -> bool: - """Check if the backend is operational. - - Returns: - True if the backend is healthy, False otherwise. - """ - ... - - async def count(self) -> int: - """Return the number of documents in the index. - - Returns: - Number of indexed documents. - """ - ... - - async def flush(self) -> None: - """Flush any buffered operations to storage. 
- - For backends that buffer writes for batch efficiency (e.g., batch embedding - generation), this method ensures all pending records are persisted. - - Backends without internal buffering can implement this as a no-op. - - Deprecated: With event-level batching, backends should be stateless. - This method will be removed in a future version. - """ - ... - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - """Index a batch of records atomically. - - All records in the batch must succeed or all must fail together. - This method should be optimized for batch operations (e.g., batch - embedding generation, bulk database inserts). - - Args: - records: List of (srn, metadata) tuples to index - - Raises: - Exception: If any record fails to index (none are committed) - """ - ... diff --git a/server/osa/sdk/index/config.py b/server/osa/sdk/index/config.py deleted file mode 100644 index 8a96628..0000000 --- a/server/osa/sdk/index/config.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Base configuration for storage backends.""" - -from pydantic import BaseModel - - -class BackendConfig(BaseModel): - """Base configuration for storage backends. - - Extend this class for vendor-specific configuration. 
- """ - - pass diff --git a/server/osa/sdk/index/result.py b/server/osa/sdk/index/result.py deleted file mode 100644 index d1bb063..0000000 --- a/server/osa/sdk/index/result.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Type-safe result types for index queries.""" - -from typing import Any - -from pydantic import BaseModel - - -class SearchHit(BaseModel, frozen=True): - """A single search result from an index query.""" - - srn: str - score: float - metadata: dict[str, Any] - - -class QueryResult(BaseModel, frozen=True): - """Result of an index query.""" - - hits: list[SearchHit] - total: int - query: str diff --git a/server/pyproject.toml b/server/pyproject.toml index c9b69e2..c63ca30 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "pyjwt>=2.11.0", "openpyxl>=3.1.5", "python-multipart>=0.0.22", + "slowapi>=0.1.9", ] [project.entry-points."osa.sources"] diff --git a/server/tests/contract/test_data_surface_contract.py b/server/tests/contract/test_data_surface_contract.py new file mode 100644 index 0000000..7b4982c --- /dev/null +++ b/server/tests/contract/test_data_surface_contract.py @@ -0,0 +1,75 @@ +"""DB-free contract tests for the unified ``/data/`` read surface (US1–US6). + +Runs in CI's ``server-contract`` job, which has **no** Postgres. Everything here +is asserted against routing + the OpenAPI surface without touching the database: + +* the legacy read routes (``/discovery``, ``/records/{srn}``, ``/search``) are + gone — Starlette resolves an unregistered path to 404 before any DI/DB runs; +* the new ``/data/`` operation IDs are registered with stable, unique names + (the SDK-codegen contract). + +DB-backed behaviour of the surface lives in ``tests/integration/test_data_*``. +""" + +import os + +import pytest +from httpx import ASGITransport, AsyncClient + +# create_app() reads Config() at call time; provide the env it needs so this +# test is self-sufficient regardless of how the contract job is configured. 
+os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") +os.environ.setdefault("OSA_AUTH__JWT__SECRET", "test-secret-for-contract-tests-minimum-32-chars") + + +def _app(): + from osa.application.api.rest.app import create_app + + return create_app() + + +@pytest.fixture +def client() -> AsyncClient: + # No lifespan → no worker pool, no DB connection; pure ASGI routing. + return AsyncClient(transport=ASGITransport(app=_app()), base_url="http://test") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "path", + [ + "/api/v1/discovery/records", + "/api/v1/records/urn:osa:localhost:rec:anything@1", + "/api/v1/search/records?q=anything", + ], +) +async def test_legacy_read_routes_are_gone(client: AsyncClient, path: str): + async with client: + resp = await client.get(path) + assert resp.status_code == 404 + + +def test_data_surface_operation_ids_registered(): + # The factory-minted table op IDs must exist and be unique (SDK codegen). + app = _app() + op_ids = [oid for route in app.routes if (oid := getattr(route, "operation_id", None))] + expected = { + "data_get_node_catalog", + "data_get_record_by_id", + "records_get_json", + "records_post_json", + "records_get_csv", + "records_post_csv", + "records_get_csv_gz", + "records_post_csv_gz", + "feature_get_json", + "feature_post_json", + "feature_get_csv", + "feature_post_csv", + "feature_get_csv_gz", + "feature_post_csv_gz", + } + assert expected <= set(op_ids) + # No duplicates among factory-minted IDs. + factory_ids = [o for o in op_ids if o in expected] + assert len(factory_ids) == len(set(factory_ids)) diff --git a/server/tests/contract/test_worker_lifecycle.py b/server/tests/contract/test_worker_lifecycle.py deleted file mode 100644 index 9032429..0000000 --- a/server/tests/contract/test_worker_lifecycle.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Contract tests for WorkerPool lifecycle in FastAPI lifespan. - -Tests for Phase 7: Migration. 
-""" - -import asyncio -from datetime import UTC, datetime -from typing import Any -from unittest.mock import AsyncMock, MagicMock -from uuid import uuid4 - -import pytest - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.handler.keyword_index_handler import KeywordIndexHandler -from osa.domain.index.handler.vector_index_handler import VectorIndexHandler -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.event import ClaimResult, EventId -from osa.domain.shared.model.srn import Domain, LocalId, RecordSRN, RecordVersion -from osa.infrastructure.event.worker import WorkerPool - - -class FakeBackend: - """Fake storage backend for testing.""" - - def __init__(self, name: str): - self._name = name - self.ingested: list[tuple[str, dict]] = [] - - @property - def name(self) -> str: - return self._name - - async def ingest_batch(self, records: list[tuple[str, dict]]) -> None: - self.ingested.extend(records) - - -def make_mock_container(handler_type: type, handler_instance: Any): - """Create a mock DI container that provides scoped dependencies. - - Creates a container mock that returns scoped Outbox and AsyncSession - when called as an async context manager with scope parameter. 
- """ - from osa.domain.shared.outbox import Outbox - - outbox = AsyncMock() - outbox.claim.return_value = ClaimResult(deliveries=[], claimed_at=datetime.now(UTC)) - outbox.reset_stale_claims.return_value = 0 - - session = AsyncMock() - session.commit = AsyncMock() - session.rollback = AsyncMock() - - async def get_dependency(cls: type) -> Any: - """Return the appropriate dependency based on the requested class.""" - if cls == Outbox: - return outbox - if cls == handler_type: - return handler_instance - return session - - # Create scope that returns dependencies - scope = AsyncMock() - scope.get = AsyncMock(side_effect=get_dependency) - - # Create async context manager - context = MagicMock() - context.__aenter__ = AsyncMock(return_value=scope) - context.__aexit__ = AsyncMock(return_value=None) - - # Container callable returns the context manager - container = MagicMock() - container.return_value = context - - return container, outbox, session - - -class TestWorkerPoolLifecycle: - """Tests for WorkerPool lifecycle management.""" - - @pytest.mark.asyncio - async def test_worker_pool_starts_and_stops_cleanly(self): - """WorkerPool should start and stop without errors.""" - vector_backend = FakeBackend("vector") - keyword_backend = FakeBackend("keyword") - registry = IndexRegistry({"vector": vector_backend, "keyword": keyword_backend}) - - # Create handler instance - vector_handler = VectorIndexHandler(indexes=registry) - - container, outbox, session = make_mock_container(VectorIndexHandler, vector_handler) - - pool = WorkerPool(container=container, stale_claim_interval=60.0) - pool.register(VectorIndexHandler) - pool.register(KeywordIndexHandler) - - # Start - await pool.start() - await asyncio.sleep(0.05) - - # Verify workers are running - assert len(pool.workers) == 2 - for worker in pool.workers: - assert worker._task is not None - assert not worker._task.done() - - # Stop - await pool.stop() - - # Verify workers are stopped - for worker in pool.workers: - assert 
worker._shutdown is True - - @pytest.mark.asyncio - async def test_worker_pool_as_context_manager(self): - """WorkerPool should work as async context manager.""" - vector_backend = FakeBackend("vector") - registry = IndexRegistry({"vector": vector_backend}) - vector_handler = VectorIndexHandler(indexes=registry) - - container, outbox, session = make_mock_container(VectorIndexHandler, vector_handler) - - pool = WorkerPool(container=container, stale_claim_interval=60.0) - pool.register(VectorIndexHandler) - - async with pool: - # Workers should be running - assert pool.workers[0]._task is not None - await asyncio.sleep(0.02) - - # After exit, workers should be stopped - assert pool.workers[0]._shutdown is True - - -class TestIndexHandlers: - """Tests for concrete index handlers.""" - - @pytest.mark.asyncio - async def test_vector_handler_processes_batch(self): - """VectorIndexHandler should process IndexRecord events in batches.""" - backend = FakeBackend("vector") - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Create test events - events = [ - IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(f"rec-{i}"), - version=RecordVersion(1), - ), - metadata={"title": f"Record {i}"}, - ) - for i in range(5) - ] - - # Process - await handler.handle_batch(events) - - # Verify backend received all records - assert len(backend.ingested) == 5 - - @pytest.mark.asyncio - async def test_keyword_handler_processes_individually(self): - """KeywordIndexHandler should process IndexRecord events one at a time.""" - backend = FakeBackend("keyword") - registry = IndexRegistry({"keyword": backend}) - handler = KeywordIndexHandler(indexes=registry) - - # batch_size should be 1 - assert KeywordIndexHandler.__batch_size__ == 1 - - # Create test event - event = IndexRecord( - id=EventId(uuid4()), - backend_name="keyword", - record_srn=RecordSRN( - 
domain=Domain("test.example.com"), - id=LocalId("rec-1"), - version=RecordVersion(1), - ), - metadata={"title": "Record 1"}, - ) - - # Process - await handler.handle(event) - - # Verify - assert len(backend.ingested) == 1 - - @pytest.mark.asyncio - async def test_handler_raises_on_backend_failure(self): - """Handlers should raise when backend fails (Worker handles retry).""" - # Backend that fails - failing_backend = AsyncMock() - failing_backend.name = "vector" - failing_backend.ingest_batch = AsyncMock(side_effect=Exception("Backend error")) - - registry = IndexRegistry({"vector": failing_backend}) - handler = VectorIndexHandler(indexes=registry) - - event = IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId("rec-1"), - version=RecordVersion(1), - ), - metadata={"title": "Record 1"}, - ) - - # Process - should raise (Worker handles retry) - with pytest.raises(Exception, match="Backend error"): - await handler.handle_batch([event]) diff --git a/server/tests/integration/persistence/test_discovery_pagination.py b/server/tests/integration/persistence/test_discovery_pagination.py deleted file mode 100644 index 49b2fb8..0000000 --- a/server/tests/integration/persistence/test_discovery_pagination.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Integration tests for discovery keyset pagination against real Postgres. - -Regression coverage for a production bug where paginating past page 1 raised -``operator does not exist: timestamp with time zone < character varying`` -because the cursor's sort value round-tripped through JSON as a plain string -and was bound as ``VARCHAR`` against the typed ``records.published_at`` column. 
-""" - -from __future__ import annotations - -from datetime import UTC, datetime - -import pytest -from sqlalchemy.ext.asyncio import AsyncSession - -from osa.domain.discovery.model.value import SortOrder, decode_cursor, encode_cursor -from osa.infrastructure.persistence.adapter.discovery import PostgresDiscoveryReadStore -from osa.infrastructure.persistence.tables import records_table - - -async def _insert_record(session: AsyncSession, srn: str, published_at: datetime) -> None: - await session.execute( - records_table.insert().values( - srn=srn, - convention_srn="urn:osa:localhost:conv:test@1.0.0", - schema_id="test", - schema_version="1.0.0", - source={"type": "test", "id": srn}, - metadata={}, - published_at=published_at, - ) - ) - await session.commit() - - -@pytest.mark.asyncio -class TestDiscoveryPaginationPublishedAt: - async def test_second_page_with_published_at_cursor(self, pg_session: AsyncSession) -> None: - """Fetching page 2 with a cursor must not trip the timestamptz/varchar - mismatch — the bug manifested only on requests that supplied a cursor.""" - store = PostgresDiscoveryReadStore(pg_session) - base = datetime(2026, 4, 7, 9, 0, 0, tzinfo=UTC) - records = [(f"urn:osa:localhost:rec:page-{i}@1", base.replace(second=i)) for i in range(3)] - for srn, ts in records: - await _insert_record(pg_session, srn, ts) - - first_page = await store.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=2, - ) - assert len(first_page) == 2 - - # Encode + decode the cursor the same way the service does — this is the - # round-trip that previously produced a VARCHAR bind. 
- last = first_page[-1] - cursor_str = encode_cursor(last.published_at.isoformat(), str(last.srn)) - decoded = decode_cursor(cursor_str) - - second_page = await store.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=decoded, - limit=2, - ) - - assert len(second_page) == 1 - returned = {str(r.srn) for r in first_page} | {str(r.srn) for r in second_page} - assert returned == {srn for srn, _ in records} diff --git a/server/osa/domain/export/event/__init__.py b/server/tests/integration/semantics/__init__.py similarity index 100% rename from server/osa/domain/export/event/__init__.py rename to server/tests/integration/semantics/__init__.py diff --git a/server/tests/integration/semantics/test_reserved_name_registration.py b/server/tests/integration/semantics/test_reserved_name_registration.py new file mode 100644 index 0000000..d1d307e --- /dev/null +++ b/server/tests/integration/semantics/test_reserved_name_registration.py @@ -0,0 +1,46 @@ +"""Integration test for reserved-name enforcement at schema registration (T066). + +Exercises the real schema-registration service path against Postgres: a schema +whose id is a reserved ``/data/`` path slot (``records`` / ``datasets``) is +rejected with ``ReservedNameError`` (``code="reserved_name"``), which the API +layer maps to HTTP 400. The aggregate invariant is the single source of truth; +this confirms it fires through the persistence-wired service, not just the unit. + +Skips automatically unless OSA_DATABASE__URL points at PostgreSQL. 
+""" + +import pytest +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.semantics.service.schema import SchemaService +from osa.domain.shared.error import ReservedNameError +from osa.domain.shared.model.srn import Domain, SchemaIdentifier +from osa.infrastructure.persistence.repository.ontology import PostgresOntologyRepository +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) + + +def _service(session: AsyncSession) -> SchemaService: + return SchemaService( + schema_repo=PostgresSemanticsSchemaRepository(session), + ontology_repo=PostgresOntologyRepository(session), + node_domain=Domain("localhost"), + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("reserved", ["records", "datasets"]) +async def test_registering_reserved_schema_id_rejected( + pg_engine: AsyncEngine, pg_session: AsyncSession, reserved: str +): + service = _service(pg_session) + with pytest.raises(ReservedNameError) as exc: + await service.create_schema( + id=SchemaIdentifier(reserved), + title="should be rejected", + version="1.0.0", + fields=[], + ) + assert exc.value.code == "reserved_name" + assert exc.value.name == reserved diff --git a/server/tests/integration/test_crash_recovery.py b/server/tests/integration/test_crash_recovery.py deleted file mode 100644 index 18ab740..0000000 --- a/server/tests/integration/test_crash_recovery.py +++ /dev/null @@ -1,211 +0,0 @@ -"""Integration tests for crash recovery in event processing. - -Tests that events are not lost when the worker is interrupted mid-processing. 
-""" - -from typing import Any -from uuid import uuid4 - -import pytest - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.handler.vector_index_handler import VectorIndexHandler -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.event import EventId -from osa.domain.shared.model.srn import Domain, LocalId, RecordSRN, RecordVersion - - -class TrackingBackend: - """Backend that tracks all operations for verification.""" - - def __init__(self, name: str, fail_at: int | None = None): - self._name = name - self._fail_at = fail_at - self._call_count = 0 - self.indexed_records: list[tuple[str, dict]] = [] - - @property - def name(self) -> str: - return self._name - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - """Track records and optionally fail at specific call.""" - self._call_count += 1 - if self._fail_at is not None and self._call_count >= self._fail_at: - raise Exception("Simulated crash") - self.indexed_records.extend(records) - - -def make_index_record(backend_name: str, record_id: str) -> IndexRecord: - """Create an IndexRecord for testing.""" - return IndexRecord( - id=EventId(uuid4()), - backend_name=backend_name, - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(record_id), - version=RecordVersion(1), - ), - metadata={"id": record_id, "title": f"Record {record_id}"}, - ) - - -class FakeOutbox: - """Simulates outbox with pending events.""" - - def __init__(self, events: list[IndexRecord]): - self.pending = list(events) - self.delivered: list[EventId] = [] - self.failed: dict[EventId, str] = {} - - async def fetch_pending(self, limit: int) -> list[IndexRecord]: - """Return pending events.""" - batch = self.pending[:limit] - return batch - - async def mark_delivered(self, event_id: EventId) -> None: - """Mark event as delivered.""" - self.delivered.append(event_id) - self.pending = [e for e in self.pending if e.id != event_id] - - 
async def mark_failed(self, event_id: EventId, error: str) -> None: - """Mark event as failed.""" - self.failed[event_id] = error - - -class TestCrashRecoveryScenarios: - """Tests for crash recovery scenarios.""" - - @pytest.mark.asyncio - async def test_events_remain_pending_on_crash(self): - """Events should remain pending if processing crashes before commit.""" - # Arrange - events = [make_index_record("vector", f"record-{i}") for i in range(10)] - outbox = FakeOutbox(events) - - # Simulate a backend that fails mid-batch - backend = TrackingBackend("vector", fail_at=1) # Fail on first call - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Act - Try to process, expecting failure - with pytest.raises(Exception, match="Simulated crash"): - await handler.handle_batch(events) - - # Assert - Events should remain pending (outbox unchanged) - assert len(outbox.pending) == 10 # All still pending - assert len(outbox.delivered) == 0 # None delivered - assert len(backend.indexed_records) == 0 # None indexed - - @pytest.mark.asyncio - async def test_recovery_processes_all_pending_events(self): - """After recovery, all pending events should be processed.""" - # Arrange - Create events and simulate they were fetched but not committed - events = [make_index_record("vector", f"record-{i}") for i in range(10)] - outbox = FakeOutbox(events) - - # First attempt: backend fails - failing_backend = TrackingBackend("vector", fail_at=1) - failing_registry = IndexRegistry({"vector": failing_backend}) - failing_handler = VectorIndexHandler(indexes=failing_registry) - - with pytest.raises(Exception): - await failing_handler.handle_batch(events) - - # Events still pending - assert len(outbox.pending) == 10 - - # Second attempt (recovery): backend works - working_backend = TrackingBackend("vector") - working_registry = IndexRegistry({"vector": working_backend}) - working_handler = VectorIndexHandler(indexes=working_registry) - - # Act - 
Retry processing - await working_handler.handle_batch(outbox.pending) - - # Assert - All events processed - assert len(working_backend.indexed_records) == 10 - - @pytest.mark.asyncio - async def test_partial_batch_failure_is_atomic(self): - """If batch fails partway, none of the events should be committed.""" - # Arrange - events = [make_index_record("vector", f"record-{i}") for i in range(5)] - - # Backend that succeeds on first record but throws exception - class PartialFailBackend: - def __init__(self): - self.name = "vector" - self.processed: list[tuple[str, dict]] = [] - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - # Process some, then fail - for i, (srn, meta) in enumerate(records): - if i >= 2: - raise Exception("Partial failure at record 2") - self.processed.append((srn, meta)) - - backend = PartialFailBackend() - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Act - with pytest.raises(Exception, match="Partial failure"): - await handler.handle_batch(events) - - # Assert - Some records were processed but batch should be atomic - # In production, the outbox marks all events as failed together - # The handler correctly propagates the error for retry - assert len(backend.processed) == 2 # Backend saw 2 before crash - - @pytest.mark.asyncio - async def test_idempotent_reprocessing(self): - """Reprocessing the same events should be safe (idempotent).""" - # Arrange - events = [make_index_record("vector", f"record-{i}") for i in range(5)] - - backend = TrackingBackend("vector") - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Act - Process twice - await handler.handle_batch(events) - await handler.handle_batch(events) # Reprocess same events - - # Assert - Records should be in backend (upsert semantics handle duplicates) - # The backend receives all records from both batches - assert len(backend.indexed_records) == 10 
# 5 + 5 - - # In production, ChromaDB's upsert handles this - same ID overwrites - # The important thing is no data corruption - - @pytest.mark.asyncio - async def test_crash_safe_no_in_memory_buffer(self): - """Verify there's no in-memory buffer that could lose data.""" - # Arrange - events = [make_index_record("vector", f"record-{i}") for i in range(3)] - - class StatelessBackend: - """Backend with no internal state.""" - - def __init__(self): - self.name = "vector" - self.batch_calls: list[list[tuple[str, dict]]] = [] - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - # Immediately persist (no buffering) - self.batch_calls.append(list(records)) - - backend = StatelessBackend() - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Act - await handler.handle_batch(events) - - # Assert - All records in single batch call, immediately persisted - assert len(backend.batch_calls) == 1 - assert len(backend.batch_calls[0]) == 3 - - # No flush needed - data is persisted immediately - # If we crashed right after ingest_batch, all 3 records are safe diff --git a/server/tests/integration/test_data_features_postgres.py b/server/tests/integration/test_data_features_postgres.py new file mode 100644 index 0000000..cc01784 --- /dev/null +++ b/server/tests/integration/test_data_features_postgres.py @@ -0,0 +1,327 @@ +"""Integration tests for PostgresTableReadStore feature-table streaming (US5). + +Exercises the FEATURE branch of the unified ``/data/`` engine against a real +Postgres: a schema with a registered hook produces a ``features.{hook}`` table; +the engine streams it, filters on its columns, and surfaces it in the catalog +and manifest scoped to the owning schema. + +Skips automatically unless OSA_DATABASE__URL points at PostgreSQL.
+""" + +from datetime import UTC, datetime + +import pytest +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.data.model.filter import FilterOperator, FeatureFieldRef, Predicate +from osa.domain.data.model.query_plan import ( + PaginationCursor, + PaginationParams, + QueryPlan, + SortDirection, + SortSpec, + TableKind, + encode_cursor, +) +from osa.domain.data.model.query_plan import TableKind as TK +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType +from osa.domain.shared.error import ConflictError, NotFoundError +from osa.domain.shared.model.hook import ColumnDef +from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId +from osa.infrastructure.data.postgres_catalog_read_store import PostgresCatalogReadStore +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore +from osa.infrastructure.persistence.feature_store import PostgresFeatureStore +from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) +from osa.infrastructure.persistence.tables import conventions_table + +from tests.integration.conftest import seed_record + +SCHEMA = SchemaId.parse("compound@1.0.0") +SCHEMA_B = SchemaId.parse("protein@1.0.0") +HOOK = "chem_features" + + +def _fields() -> list[FieldDefinition]: + return [ + FieldDefinition( + name="species", + type=FieldType.TEXT, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + + +def _feature_columns() -> list[ColumnDef]: + return [ + ColumnDef(name="score", json_type="number", required=True), + ColumnDef(name="label", json_type="string", required=False), + ] + + +async def _setup_schema( + engine: AsyncEngine, session: AsyncSession, schema: SchemaId = SCHEMA +) -> PostgresMetadataStore: + store = PostgresMetadataStore(engine, session) + await 
store.ensure_table(schema, _fields()) + await PostgresSemanticsSchemaRepository(session).save( + Schema(id=schema, title=schema.id.root, fields=_fields(), created_at=datetime.now(UTC)) + ) + return store + + +async def _register_hook( + engine: AsyncEngine, + session: AsyncSession, + hook_name: str = HOOK, + schema: SchemaId = SCHEMA, +) -> None: + """Link the schema → hook via a convention row, then create its feature table. + + The read store only reads ``hooks[*].name`` from the convention, so a + name-only hooks payload is sufficient to scope the feature to the schema. + A pre-existing feature table is reused, mirroring ``CreateFeatureTables`` + (which swallows the ConflictError when two conventions share a hook name). + """ + await session.execute( + conventions_table.insert().values( + srn=f"urn:osa:localhost:conv:{schema.id.root}-{hook_name}@1.0.0", + title=f"{schema.id.root} conv", + description=None, + schema_id=schema.id.root, + schema_version=schema.version.root, + file_requirements={}, + hooks=[{"name": hook_name}], + source=None, + created_at=datetime.now(UTC), + ) + ) + await session.commit() + feature_store = PostgresFeatureStore(engine, session) + try: + await feature_store.create_table(hook_name, _feature_columns()) + except ConflictError: + pass + + +async def _publish( + engine: AsyncEngine, store: PostgresMetadataStore, rid: str, schema: SchemaId = SCHEMA +) -> RecordSRN: + srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") + await seed_record( + engine, + srn=str(srn), + schema_id=schema.id.root, + schema_version=schema.version.root, + metadata={"species": "Homo sapiens"}, + published_at=datetime(2026, 1, 1, tzinfo=UTC), + ) + await store.insert(schema, srn, {"species": "Homo sapiens"}) + return srn + + +def _feature_plan(filter_expr=None, limit=50) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name=HOOK, + filter=filter_expr, + pagination=PaginationParams(limit=limit), + ) + + +async def 
_drain(store: PostgresTableReadStore, plan: QueryPlan) -> list[dict]: + return [dict(row) async for row in store.stream_rows(plan)] + + +@pytest.mark.asyncio +class TestStreamFeatures: + async def test_streams_feature_rows_with_data_columns( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + srn = await _publish(pg_engine, store, "rec1") + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + await feature_store.insert_features( + HOOK, str(srn), [{"score": 0.9, "label": "high"}, {"score": 0.1, "label": "low"}] + ) + + rs = PostgresTableReadStore(pg_session) + rows = await _drain(rs, _feature_plan()) + + assert len(rows) == 2 + # Default FEATURE sort is id asc → first inserted row first. + first = rows[0] + assert first["record_srn"] == str(srn) + assert first["score"] == 0.9 + assert first["label"] == "high" + assert "id" in first and "created_at" in first + + async def test_feature_filter_narrows_results( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + srn = await _publish(pg_engine, store, "rec1") + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + await feature_store.insert_features( + HOOK, str(srn), [{"score": 0.9, "label": "high"}, {"score": 0.1, "label": "low"}] + ) + + rs = PostgresTableReadStore(pg_session) + plan = _feature_plan( + filter_expr=Predicate( + field=FeatureFieldRef(hook=HOOK, column="score"), + op=FilterOperator.GTE, + value=0.5, + ) + ) + rows = await _drain(rs, plan) + assert [r["label"] for r in rows] == ["high"] + + async def test_unknown_feature_raises_not_found( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresTableReadStore(pg_session) + 
plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name="nonexistent", + ) + with pytest.raises(NotFoundError): + await _drain(rs, plan) + + +@pytest.mark.asyncio +class TestFeatureCreatedAtCursor: + """``?sort=created_at`` pagination on a feature table must round-trip. + + Feature rows are raw DB mappings, so the route layer encodes the next-page + cursor from a Python ``datetime`` (encoder must serialize it), and the + second page binds the decoded value against ``DateTime(timezone=True)`` + (decoder must coerce the ISO string back — created_at is an implicit + column with no ColumnDef in the hook's FeatureSchema).""" + + async def test_created_at_sort_cursor_round_trips( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + srn = await _publish(pg_engine, store, "rec1") + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + await feature_store.insert_features(HOOK, str(srn), [{"score": s} for s in (0.1, 0.2, 0.3)]) + + rs = PostgresTableReadStore(pg_session) + sort = [SortSpec(column="created_at", direction=SortDirection.ASC)] + all_rows = await _drain( + rs, + QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name=HOOK, sort=sort), + ) + assert len(all_rows) == 3 + + # Encode the cursor exactly as routes/data/_streaming.py does: straight + # from the row mapping, where created_at is a datetime. 
+ cursor = encode_cursor(all_rows[1]["created_at"], all_rows[1]["id"]) + page2 = await _drain( + rs, + QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name=HOOK, + sort=sort, + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + ), + ) + assert [r["id"] for r in page2] == [all_rows[2]["id"]] + + +@pytest.mark.asyncio +class TestFeatureSchemaScoping: + """Two schemas whose conventions register the same hook name share one + physical ``features.{hook}`` table (``CreateFeatureTables`` swallows the + ConflictError). Reads at ``/data/{schema}/{feature}`` must still be scoped + to records of the requested schema.""" + + async def _seed_shared_hook( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ) -> tuple[RecordSRN, RecordSRN]: + store_a = await _setup_schema(pg_engine, pg_session) + store_b = await _setup_schema(pg_engine, pg_session, schema=SCHEMA_B) + srn_a = await _publish(pg_engine, store_a, "reca") + srn_b = await _publish(pg_engine, store_b, "recb", schema=SCHEMA_B) + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + await _register_hook(pg_engine, pg_session, schema=SCHEMA_B) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + await feature_store.insert_features(HOOK, str(srn_a), [{"score": 0.9, "label": "a"}]) + await feature_store.insert_features(HOOK, str(srn_b), [{"score": 0.2, "label": "b"}]) + return srn_a, srn_b + + async def test_stream_excludes_rows_of_other_schemas( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + srn_a, _ = await self._seed_shared_hook(pg_engine, pg_session) + + rs = PostgresTableReadStore(pg_session) + rows = await _drain(rs, _feature_plan()) + + assert [r["record_srn"] for r in rows] == [str(srn_a)] + + async def test_manifest_row_count_scoped_to_schema( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await self._seed_shared_hook(pg_engine, pg_session) + + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) +
manifest = await rs.get_schema_manifest(SCHEMA) + assert manifest is not None + feature_res = next(t for t in manifest.table_resources if t.name == HOOK) + assert feature_res.row_count == 1 + + +@pytest.mark.asyncio +class TestFeatureCatalogAndManifest: + async def test_manifest_includes_feature_resource( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + srn = await _publish(pg_engine, store, "rec1") + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + await feature_store.insert_features(HOOK, str(srn), [{"score": 0.9, "label": "high"}]) + + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) + manifest = await rs.get_schema_manifest(SCHEMA) + assert manifest is not None + feature_res = next(t for t in manifest.table_resources if t.name == HOOK) + assert feature_res.kind == TK.FEATURE + assert feature_res.row_count == 1 + col_names = [c.name for c in feature_res.columns] + assert col_names[:3] == ["id", "record_srn", "created_at"] + assert "score" in col_names and "label" in col_names + assert feature_res.formats == ["", "csv", "csv.gz"] + + async def test_catalog_lists_feature_resource( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) + catalog = await rs.get_node_catalog() + entry = next(e for e in catalog.schemas if e.id == "compound") + names = {(t.name, t.kind) for t in entry.table_resources} + assert ("records", TK.RECORDS) in names + assert (HOOK, TK.FEATURE) in names diff --git a/server/tests/integration/test_data_read_store_postgres.py b/server/tests/integration/test_data_read_store_postgres.py new file mode 100644 index 0000000..24bc706 --- /dev/null +++ 
b/server/tests/integration/test_data_read_store_postgres.py @@ -0,0 +1,365 @@ +"""Integration tests for PostgresTableReadStore against a real Postgres. + +Exercises the unified /data/ engine end-to-end: records streaming via the +server-side cursor, filtered streaming, single-record-by-id, node catalog, and +schema manifest. Mirrors the discovery integration tests' seeding pattern +(schema row + dynamic metadata table + seeded records). + +Skips automatically unless OSA_DATABASE__URL points at PostgreSQL. +""" + +from datetime import UTC, datetime + +import pytest +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.data.model.filter import FilterOperator, MetadataFieldRef, Predicate +from osa.domain.data.model.query_plan import ( + PaginationCursor, + PaginationParams, + QueryPlan, + SortDirection, + SortSpec, + TableKind, + encode_cursor, +) +from osa.domain.data.model.query_plan import TableKind as TK +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType +from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId +from osa.infrastructure.data.postgres_catalog_read_store import PostgresCatalogReadStore +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore +from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) + +from tests.integration.conftest import seed_record + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def _fields() -> list[FieldDefinition]: + return [ + FieldDefinition( + name="species", + type=FieldType.TEXT, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + FieldDefinition( + name="mw", + type=FieldType.NUMBER, + required=False, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + + +async def 
_setup_schema(engine: AsyncEngine, session: AsyncSession) -> PostgresMetadataStore: + store = PostgresMetadataStore(engine, session) + await store.ensure_table(SCHEMA, _fields()) + repo = PostgresSemanticsSchemaRepository(session) + await repo.save( + Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + ) + return store + + +async def _publish( + engine: AsyncEngine, + store: PostgresMetadataStore, + rid: str, + species: str, + mw: float, + published_at: datetime, +) -> None: + srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") + await seed_record( + engine, + srn=str(srn), + schema_id=SCHEMA.id.root, + schema_version=SCHEMA.version.root, + metadata={"species": species, "mw": mw}, + published_at=published_at, + ) + await store.insert(SCHEMA, srn, {"species": species, "mw": mw}) + + +def _records_plan(filter_expr=None, limit=50, cursor=None) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + filter=filter_expr, + pagination=PaginationParams(limit=limit, cursor=cursor), + ) + + +async def _drain(store: PostgresTableReadStore, plan: QueryPlan) -> list[dict]: + return [dict(row) async for row in store.stream_rows(plan)] + + +@pytest.mark.asyncio +class TestStreamRecords: + async def test_streams_all_rows_flattened( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + await _publish( + pg_engine, store, "rec1", "Homo sapiens", 1.5, datetime(2026, 1, 3, tzinfo=UTC) + ) + await _publish( + pg_engine, store, "rec2", "Mus musculus", 2.0, datetime(2026, 1, 1, tzinfo=UTC) + ) + await pg_session.commit() + + rs = PostgresTableReadStore(pg_session) + rows = await _drain(rs, _records_plan()) + + assert len(rows) == 2 + # Default RECORDS sort is created_at desc → r1 first. 
+ first = rows[0] + assert first["id"] == "rec1" + assert first["srn"] == "urn:osa:localhost:rec:rec1@1" + assert first["schema_id"] == "compound@1.0.0" + assert first["version"] == 1 + assert first["species"] == "Homo sapiens" + assert first["mw"] == 1.5 + assert "created_at" in first + + async def test_metadata_filter_narrows_results( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + await _publish( + pg_engine, store, "rec1", "Homo sapiens", 1.5, datetime(2026, 1, 3, tzinfo=UTC) + ) + await _publish( + pg_engine, store, "rec2", "Mus musculus", 2.0, datetime(2026, 1, 1, tzinfo=UTC) + ) + await pg_session.commit() + + rs = PostgresTableReadStore(pg_session) + plan = _records_plan( + filter_expr=Predicate( + field=MetadataFieldRef(field="species"), + op=FilterOperator.EQ, + value="Homo sapiens", + ) + ) + rows = await _drain(rs, plan) + assert [r["id"] for r in rows] == ["rec1"] + + async def test_empty_result_streams_nothing( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresTableReadStore(pg_session) + assert await _drain(rs, _records_plan()) == [] + + +@pytest.mark.asyncio +class TestGetRecordById: + async def test_resolves_by_bare_id_with_srn( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + await _publish( + pg_engine, store, "abc", "Homo sapiens", 1.5, datetime(2026, 1, 1, tzinfo=UTC) + ) + await pg_session.commit() + + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) + rec = await rs.get_record_by_id(RecordId("abc"), None) + assert rec is not None + assert str(rec.id) == "abc" + assert str(rec.srn) == "urn:osa:localhost:rec:abc@1" + assert rec.schema_id.render() == "compound@1.0.0" + assert rec.version == 1 + assert rec.metadata["species"] == "Homo sapiens" + + async def test_unknown_id_returns_none(self, 
pg_engine: AsyncEngine, pg_session: AsyncSession): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) + assert await rs.get_record_by_id(RecordId("missing"), None) is None + + +@pytest.mark.asyncio +class TestCatalogAndManifest: + async def test_node_catalog_lists_schema( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) + catalog = await rs.get_node_catalog() + assert catalog.node_domain == "localhost" + ids = {(e.id, e.version) for e in catalog.schemas} + assert ("compound", "1.0.0") in ids + + async def test_manifest_shape(self, pg_engine: AsyncEngine, pg_session: AsyncSession): + store = await _setup_schema(pg_engine, pg_session) + await _publish( + pg_engine, store, "rec1", "Homo sapiens", 1.5, datetime(2026, 1, 1, tzinfo=UTC) + ) + await pg_session.commit() + + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) + manifest = await rs.get_schema_manifest(SCHEMA) + assert manifest is not None + assert manifest.id == "compound" + assert manifest.version == "1.0.0" + assert {f.name for f in manifest.fields} == {"species", "mw"} + records_res = next(t for t in manifest.table_resources if t.name == "records") + assert records_res.kind == TK.RECORDS + assert records_res.row_count == 1 + col_names = [c.name for c in records_res.columns] + # implicit columns precede the declared fields + assert col_names[:5] == ["id", "srn", "schema_id", "version", "created_at"] + assert "species" in col_names and "mw" in col_names + assert records_res.formats == ["", "csv", "csv.gz"] + + async def test_get_latest_schema_id(self, pg_engine: AsyncEngine, pg_session: AsyncSession): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) + resolved = await 
rs.get_latest_schema_id("compound") + assert resolved is not None + assert resolved.render() == "compound@1.0.0" + assert await rs.get_latest_schema_id("nonexistent") is None + + +@pytest.mark.asyncio +class TestStreamPaginationOrder: + async def test_cursor_advances_without_overlap( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + for i in range(5): + await _publish( + pg_engine, + store, + f"rec{i}", + "Homo sapiens", + float(i), + datetime(2026, 1, 1 + i, tzinfo=UTC), + ) + await pg_session.commit() + rs = PostgresTableReadStore(pg_session) + + # Sort by created_at desc (default). Page 1: take 2, derive cursor from row 2. + all_rows = await _drain(rs, _records_plan()) + assert [r["id"] for r in all_rows] == ["rec4", "rec3", "rec2", "rec1", "rec0"] + + # Cursor after the 2nd row (r3) → next page should start at r2. + cursor = encode_cursor(all_rows[1]["created_at"], all_rows[1]["srn"]) + plan2 = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + pagination=PaginationParams(limit=50, cursor=PaginationCursor(value=cursor)), + sort=[SortSpec(column="created_at", direction=SortDirection.DESC)], + ) + page2 = await _drain(rs, plan2) + assert [r["id"] for r in page2] == ["rec2", "rec1", "rec0"] + + async def test_id_sort_cursor_round_trips( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + # ``sort=id`` aliases to the srn column; the route layer encodes the + # srn as the cursor sort value (see _next_cursor). A bare-id cursor + # would sort before every SRN and pagination would never advance. 
+ store = await _setup_schema(pg_engine, pg_session) + for i in range(3): + await _publish( + pg_engine, + store, + f"rec{i}", + "Homo sapiens", + float(i), + datetime(2026, 1, 1, tzinfo=UTC), + ) + await pg_session.commit() + rs = PostgresTableReadStore(pg_session) + + sort = [SortSpec(column="id", direction=SortDirection.ASC)] + all_rows = await _drain( + rs, QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, sort=sort) + ) + assert [r["id"] for r in all_rows] == ["rec0", "rec1", "rec2"] + + cursor = encode_cursor(all_rows[1]["srn"], all_rows[1]["srn"]) + page2 = await _drain( + rs, + QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + sort=sort, + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + ), + ) + assert [r["id"] for r in page2] == ["rec2"] + + +ASSAY_SCHEMA = SchemaId.parse("assay@1.0.0") + + +@pytest.mark.asyncio +class TestMetadataDateCursor: + """Pagination sorted by a FieldType.DATE metadata column must round-trip. + + The cursor carries the sort value as an ISO string; the records sort must + coerce it back to a Python date before binding — asyncpg rejects a str + bound against the dynamic table's sa.Date column.""" + + async def test_date_sort_cursor_round_trips( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + fields = [ + FieldDefinition( + name="assay_date", + type=FieldType.DATE, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + store = PostgresMetadataStore(pg_engine, pg_session) + await store.ensure_table(ASSAY_SCHEMA, fields) + await PostgresSemanticsSchemaRepository(pg_session).save( + Schema(id=ASSAY_SCHEMA, title="assay", fields=fields, created_at=datetime.now(UTC)) + ) + for i in range(3): + srn = RecordSRN.parse(f"urn:osa:localhost:rec:arec{i}@1") + await seed_record( + pg_engine, + srn=str(srn), + schema_id=ASSAY_SCHEMA.id.root, + schema_version=ASSAY_SCHEMA.version.root, + metadata={"assay_date": f"2026-01-0{i + 1}"}, + published_at=datetime(2026, 1, 1, 
tzinfo=UTC), + ) + await store.insert(ASSAY_SCHEMA, srn, {"assay_date": f"2026-01-0{i + 1}"}) + await pg_session.commit() + + rs = PostgresTableReadStore(pg_session) + sort = [SortSpec(column="assay_date", direction=SortDirection.ASC)] + all_rows = await _drain( + rs, QueryPlan(schema_id=ASSAY_SCHEMA, table_kind=TableKind.RECORDS, sort=sort) + ) + assert [r["id"] for r in all_rows] == ["arec0", "arec1", "arec2"] + + # Encode the cursor as the route layer does — the date sort value + # serializes to an ISO string via default=str. + cursor = encode_cursor(all_rows[1]["assay_date"], all_rows[1]["srn"]) + page2 = await _drain( + rs, + QueryPlan( + schema_id=ASSAY_SCHEMA, + table_kind=TableKind.RECORDS, + sort=sort, + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + ), + ) + assert [r["id"] for r in page2] == ["arec2"] diff --git a/server/tests/integration/test_data_routes_e2e_postgres.py b/server/tests/integration/test_data_routes_e2e_postgres.py new file mode 100644 index 0000000..d2e14b3 --- /dev/null +++ b/server/tests/integration/test_data_routes_e2e_postgres.py @@ -0,0 +1,527 @@ +"""End-to-end HTTP tests for the /data/ routes against a real Postgres. + +Drives the full stack — FastAPI routing, DishkaRoute DI, the streaming +response, the serializers, and the SQL — through an ASGI client. Critically +validates that the request-scoped DB session stays alive through +``StreamingResponse`` iteration (the DishkaRoute + streaming interaction). + +Lifespan is intentionally NOT run (no worker pool needed): ``create_app`` +attaches the Dishka container to app state, so DI resolves without it. + +Skips unless OSA_DATABASE__URL points at PostgreSQL. 
+""" + +import gzip +import os +from datetime import UTC, datetime + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType +from osa.domain.shared.model.hook import ColumnDef +from osa.domain.shared.model.srn import RecordSRN, SchemaId +from osa.infrastructure.persistence.feature_store import PostgresFeatureStore +from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) +from osa.infrastructure.persistence.tables import conventions_table + +from tests.integration.conftest import seed_record + +# create_app() reads Config() at import/call time; localhost domain needs a base URL. +os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") +os.environ.setdefault("OSA_AUTH__JWT__SECRET", "test-secret-for-integration-tests-minimum-32-chars") + +if "postgresql" not in os.environ.get("OSA_DATABASE__URL", ""): + pytest.skip("OSA_DATABASE__URL not set to PostgreSQL", allow_module_level=True) + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def _fields() -> list[FieldDefinition]: + return [ + FieldDefinition( + name="species", + type=FieldType.TEXT, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + FieldDefinition( + name="mw", + type=FieldType.NUMBER, + required=False, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + + +async def _seed(engine: AsyncEngine, session: AsyncSession, n: int) -> None: + store = PostgresMetadataStore(engine, session) + await store.ensure_table(SCHEMA, _fields()) + await PostgresSemanticsSchemaRepository(session).save( + Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + ) + for i in range(n): + srn = RecordSRN.parse(f"urn:osa:localhost:rec:rec{i:03d}@1") + await 
seed_record( + engine, + srn=str(srn), + schema_id=SCHEMA.id.root, + schema_version=SCHEMA.version.root, + metadata={"species": "Homo sapiens", "mw": float(i)}, + published_at=datetime(2026, 1, 1 + i, tzinfo=UTC), + ) + await store.insert(SCHEMA, srn, {"species": "Homo sapiens", "mw": float(i)}) + await session.commit() + + +HOOK = "chem_features" + + +async def _seed_feature( + engine: AsyncEngine, session: AsyncSession, record_srn: RecordSRN, n: int +) -> None: + """Register a hook on the schema and populate ``features.chem_features``.""" + await session.execute( + conventions_table.insert().values( + srn=f"urn:osa:localhost:conv:{HOOK}@1.0.0", + title="compound conv", + description=None, + schema_id=SCHEMA.id.root, + schema_version=SCHEMA.version.root, + file_requirements={}, + hooks=[{"name": HOOK}], + source=None, + created_at=datetime.now(UTC), + ) + ) + await session.commit() + feature_store = PostgresFeatureStore(engine, session) + await feature_store.create_table( + HOOK, + [ + ColumnDef(name="score", json_type="number", required=True), + ColumnDef(name="label", json_type="string", required=False), + ], + ) + await feature_store.insert_features( + HOOK, + str(record_srn), + [{"score": float(i), "label": f"l{i}"} for i in range(n)], + ) + + +@pytest.fixture +def client() -> AsyncClient: + from osa.application.api.rest.app import create_app + + app = create_app() + return AsyncClient(transport=ASGITransport(app=app), base_url="http://test") + + +@pytest.mark.asyncio +class TestDataRoutesE2E: + async def test_records_csv_gz_streams_full_table( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 3) + async with client: + resp = await client.get("/api/v1/data/compound@1.0.0/records.csv.gz") + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("application/gzip") + text = gzip.decompress(resp.content).decode() + lines = [ln for ln in text.splitlines() if ln] + assert 
lines[0].split(",")[:5] == ["id", "srn", "schema_id", "version", "created_at"] + assert len(lines) == 1 + 3 # header + 3 rows + + async def test_records_json_default_page( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 3) + async with client: + resp = await client.get("/api/v1/data/compound@1.0.0/records") + assert resp.status_code == 200 + body = resp.json() + assert len(body["rows"]) == 3 + assert body["next_cursor"] is None + assert body["rows"][0]["species"] == "Homo sapiens" + + async def test_records_json_cursor_pagination( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 5) + async with client: + page1 = (await client.get("/api/v1/data/compound@1.0.0/records?limit=2")).json() + assert len(page1["rows"]) == 2 + assert page1["next_cursor"] is not None + page2 = ( + await client.get( + f"/api/v1/data/compound@1.0.0/records?limit=2&cursor={page1['next_cursor']}" + ) + ).json() + ids1 = {r["id"] for r in page1["rows"]} + ids2 = {r["id"] for r in page2["rows"]} + assert ids1.isdisjoint(ids2) # no overlap across pages + + async def test_post_csv_with_filter( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 4) + body = { + "filter": { + "kind": "predicate", + "field": "metadata.mw", + "op": "gte", + "value": 2.0, + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv", json=body) + assert resp.status_code == 200 + lines = [ln for ln in resp.text.splitlines() if ln] + assert len(lines) == 1 + 2 # mw in {2.0, 3.0} + + async def test_catalog_and_manifest( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + catalog = (await client.get("/api/v1/data")).json() + assert any(s["id"] == "compound" for s in catalog["schemas"]) + 
manifest = (await client.get("/api/v1/data/compound@1.0.0")).json() + assert manifest["id"] == "compound" + names = [t["name"] for t in manifest["table_resources"]] + assert "records" in names + + async def test_record_by_id_has_id_and_srn( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + resp = await client.get("/api/v1/data/records/rec000") + assert resp.status_code == 200 + body = resp.json() + assert body["id"] == "rec000" + assert body["srn"] == "urn:osa:localhost:rec:rec000@1" + + async def test_reserved_and_unknown_schema_404( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + assert (await client.get("/api/v1/data/records")).status_code == 404 + assert (await client.get("/api/v1/data/datasets")).status_code == 404 + assert (await client.get("/api/v1/data/nope@9.9.9")).status_code == 404 + + async def test_record_by_id_version_pin( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + store = PostgresMetadataStore(pg_engine, pg_session) + await store.ensure_table(SCHEMA, _fields()) + await PostgresSemanticsSchemaRepository(pg_session).save( + Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + ) + for version, mw in ((1, 10.0), (2, 20.0)): + srn = RecordSRN.parse(f"urn:osa:localhost:rec:dup@{version}") + await seed_record( + pg_engine, + srn=str(srn), + schema_id=SCHEMA.id.root, + schema_version=SCHEMA.version.root, + metadata={"species": "Homo sapiens", "mw": mw}, + published_at=datetime(2026, 1, version, tzinfo=UTC), + ) + await store.insert(SCHEMA, srn, {"species": "Homo sapiens", "mw": mw}) + await pg_session.commit() + + async with client: + v1 = (await client.get("/api/v1/data/records/dup@1")).json() + v2 = (await client.get("/api/v1/data/records/dup@2")).json() + latest = (await 
client.get("/api/v1/data/records/dup")).json() + assert v1["srn"] == "urn:osa:localhost:rec:dup@1" + assert v2["srn"] == "urn:osa:localhost:rec:dup@2" + # No version → latest published (@2, published 2026-01-02). + assert latest["srn"] == "urn:osa:localhost:rec:dup@2" + + +@pytest.mark.asyncio +class TestRecordsStreamingEdges: + async def test_post_csv_gz_happy_path( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 4) + body = { + "filter": { + "kind": "predicate", + "field": "metadata.mw", + "op": "gte", + "value": 2.0, + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv.gz", json=body) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("application/gzip") + lines = [ln for ln in gzip.decompress(resp.content).decode().splitlines() if ln] + assert lines[0].split(",")[:5] == ["id", "srn", "schema_id", "version", "created_at"] + assert len(lines) == 1 + 2 # header + mw in {2.0, 3.0} + + async def test_empty_result_returns_gzipped_header_only( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 2) + body = { + "filter": { + "kind": "predicate", + "field": "metadata.species", + "op": "eq", + "value": "Nonexistent species", + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv.gz", json=body) + assert resp.status_code == 200 + lines = [ln for ln in gzip.decompress(resp.content).decode().splitlines() if ln] + assert len(lines) == 1 # header only, no data rows + assert lines[0].split(",")[:5] == ["id", "srn", "schema_id", "version", "created_at"] + + async def test_invalid_filter_field_4xx_before_bytes( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 2) + body = { + "filter": { + "kind": "predicate", + "field": "metadata.nonexistent_field", + "op": 
"eq", + "value": "x", + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv.gz", json=body) + # The error must be surfaced before any streamed bytes (pre-flight). + assert 400 <= resp.status_code < 500 + assert not resp.headers.get("content-type", "").startswith("application/gzip") + + async def test_unknown_schema_404_before_bytes( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + resp = await client.get("/api/v1/data/nope@9.9.9/records.csv.gz") + assert resp.status_code == 404 + assert not resp.headers.get("content-type", "").startswith("application/gzip") + + +@pytest.mark.asyncio +class TestRecordsJsonEdges: + async def test_limit_over_max_is_clamped( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 3) + async with client: + resp = await client.get("/api/v1/data/compound@1.0.0/records?limit=5000") + assert resp.status_code == 200 # clamped, not rejected + assert len(resp.json()["rows"]) == 3 + + async def test_sort_param_orders_rows( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 3) + async with client: + resp = await client.get( + "/api/v1/data/compound@1.0.0/records?sort=created_at:asc,id:asc" + ) + rows = resp.json()["rows"] + # created_at ascending → rec000 (earliest) first. 
+ assert rows[0]["id"] == "rec000" + assert [r["id"] for r in rows] == sorted(r["id"] for r in rows) + + async def test_post_json_with_filter( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 4) + body = {"filter": {"kind": "predicate", "field": "metadata.mw", "op": "lt", "value": 2.0}} + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records", json=body) + assert resp.status_code == 200 + rows = resp.json()["rows"] + assert {r["id"] for r in rows} == {"rec000", "rec001"} # mw 0.0, 1.0 + + async def test_boolean_filter_composition( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + # Exercises AND / OR / NOT SQL compilation (incl. the NOT→coalesce path) + # that the relocated filter engine carries over from discovery. + await _seed(pg_engine, pg_session, 5) # mw ∈ {0,1,2,3,4} + body = { + "filter": { + "kind": "and", + "operands": [ + { + "kind": "or", + "operands": [ + { + "kind": "predicate", + "field": "metadata.mw", + "op": "gte", + "value": 3.0, + }, + {"kind": "predicate", "field": "metadata.mw", "op": "lt", "value": 1.0}, + ], + }, + { + "kind": "not", + "operand": { + "kind": "predicate", + "field": "metadata.mw", + "op": "eq", + "value": 4.0, + }, + }, + ], + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records", json=body) + assert resp.status_code == 200 + rows = resp.json()["rows"] + # OR(mw≥3, mw<1) = {0,3,4}; AND NOT(mw=4) → {0,3}. 
+ assert {r["id"] for r in rows} == {"rec000", "rec003"} + + +@pytest.mark.asyncio +class TestCatalogEdges: + async def test_empty_node_returns_200_empty_schemas(self, client: AsyncClient): + async with client: + resp = await client.get("/api/v1/data") + assert resp.status_code == 200 + body = resp.json() + assert body["schemas"] == [] + assert "node_domain" in body + + +@pytest.mark.asyncio +class TestLegacySurfaceRemoved: + """US6: the legacy read surface is gone — every old path 404s.""" + + async def test_legacy_discovery_route_404(self, client: AsyncClient): + async with client: + resp = await client.get("/api/v1/discovery/records") + assert resp.status_code == 404 + + async def test_legacy_records_srn_route_404(self, client: AsyncClient): + async with client: + resp = await client.get("/api/v1/records/urn:osa:localhost:rec:anything@1") + assert resp.status_code == 404 + + async def test_legacy_search_route_404(self, client: AsyncClient): + async with client: + resp = await client.get("/api/v1/search/records?q=anything") + assert resp.status_code == 404 + + async def test_stats_has_no_index_field( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + async with client: + resp = await client.get("/api/v1/stats") + assert resp.status_code == 200 + body = resp.json() + assert "indexes" not in body + assert body["data_url"] == "/api/v1/data" + + +@pytest.mark.asyncio +class TestPostRateLimit: + async def test_eleventh_post_within_a_minute_is_429( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + from osa.application.api.v1.routes.data._limiter import limiter + + # The limiter is a module singleton shared across tests — reset its + # window so this test's count starts clean and doesn't leak outward. 
+ limiter.reset() + await _seed(pg_engine, pg_session, 1) + statuses = [] + async with client: + for _ in range(11): # limit is 10/minute on POST routes + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv", json={}) + statuses.append(resp.status_code) + limiter.reset() + assert statuses[:10] == [200] * 10 + assert statuses[10] == 429 + + +@pytest.mark.asyncio +class TestFeatureRoutesE2E: + async def test_feature_csv_gz_streams_full_table( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + srn = RecordSRN.parse("urn:osa:localhost:rec:rec000@1") + await _seed_feature(pg_engine, pg_session, srn, 3) + async with client: + resp = await client.get(f"/api/v1/data/compound@1.0.0/{HOOK}.csv.gz") + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("application/gzip") + lines = [ln for ln in gzip.decompress(resp.content).decode().splitlines() if ln] + assert lines[0].split(",")[:3] == ["id", "record_srn", "created_at"] + assert "score" in lines[0] and "label" in lines[0] + assert len(lines) == 1 + 3 # header + 3 rows + + async def test_feature_post_csv_with_filter( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + srn = RecordSRN.parse("urn:osa:localhost:rec:rec000@1") + await _seed_feature(pg_engine, pg_session, srn, 4) + body = { + "filter": { + "kind": "predicate", + "field": f"features.{HOOK}.score", + "op": "gte", + "value": 2.0, + } + } + async with client: + resp = await client.post(f"/api/v1/data/compound@1.0.0/{HOOK}.csv", json=body) + assert resp.status_code == 200 + lines = [ln for ln in resp.text.splitlines() if ln] + assert len(lines) == 1 + 2 # score in {2.0, 3.0} + + async def test_feature_json_default_page( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + srn = 
RecordSRN.parse("urn:osa:localhost:rec:rec000@1") + await _seed_feature(pg_engine, pg_session, srn, 3) + async with client: + resp = await client.get(f"/api/v1/data/compound@1.0.0/{HOOK}") + assert resp.status_code == 200 + body = resp.json() + assert len(body["rows"]) == 3 + assert body["rows"][0]["label"] == "l0" + + async def test_unknown_feature_404( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + resp = await client.get("/api/v1/data/compound@1.0.0/no_such_feature.csv.gz") + assert resp.status_code == 404 + + async def test_manifest_lists_feature_table( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + srn = RecordSRN.parse("urn:osa:localhost:rec:rec000@1") + await _seed_feature(pg_engine, pg_session, srn, 2) + async with client: + manifest = (await client.get("/api/v1/data/compound@1.0.0")).json() + feature = next(t for t in manifest["table_resources"] if t["name"] == HOOK) + assert feature["kind"] == "feature" + assert feature["row_count"] == 2 diff --git a/server/tests/integration/test_data_streaming_guarantees_postgres.py b/server/tests/integration/test_data_streaming_guarantees_postgres.py new file mode 100644 index 0000000..5240562 --- /dev/null +++ b/server/tests/integration/test_data_streaming_guarantees_postgres.py @@ -0,0 +1,153 @@ +"""Integration tests for the streaming engine's resource guarantees (T041, T042). + +These exercise the two non-functional promises of the ``/data/`` streaming path +directly against the read store + a real Postgres server-side cursor: + +* **Bounded memory (T041)** — streaming a large table must not materialize all + rows in the Python heap. Asserted with ``tracemalloc`` peak while draining the + dump without accumulating, against a bulk-seeded table. 
(The literal spec's + absolute "<50 MB RSS" is not meaningful in-process — the test runner's RSS is + already well above that from imports — so we assert *Python-heap peak stays + small*, which is the guarantee the streaming primitive actually provides.) + +* **Cursor release on early termination (T042)** — closing the async generator + mid-stream (the client-disconnect path: a ``GeneratorExit`` thrown into the + body) must run the ``try/finally`` that closes the server-side cursor, leaving + the connection immediately reusable. (Simulating a real TCP disconnect + + ``pg_stat_activity`` poll is not reliable through the in-process ASGI client, + so we drive the generator's lifecycle directly, which is where the guarantee + lives.) + +Skips automatically unless OSA_DATABASE__URL points at PostgreSQL. +""" + +import tracemalloc +from datetime import UTC, datetime + +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.data.model.query_plan import PaginationParams, QueryPlan, TableKind +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType +from osa.domain.shared.model.srn import SchemaId +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore +from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def _fields() -> list[FieldDefinition]: + return [ + FieldDefinition( + name="species", + type=FieldType.TEXT, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + FieldDefinition( + name="mw", + type=FieldType.NUMBER, + required=False, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + + +async def _setup_schema(engine: AsyncEngine, session: AsyncSession) -> None: + store = PostgresMetadataStore(engine, 
session) + await store.ensure_table(SCHEMA, _fields()) + await PostgresSemanticsSchemaRepository(session).save( + Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + ) + await session.commit() + + +async def _bulk_seed(engine: AsyncEngine, n: int) -> None: + """Insert *n* records + metadata rows in two set-based statements. + + ``source.id`` is per-row unique to satisfy the records source uniqueness + index; otherwise a single bulk INSERT would collide on the second row. + """ + async with engine.begin() as conn: + await conn.execute( + text( + """ + INSERT INTO records + (srn, convention_srn, schema_id, schema_version, source, metadata, published_at) + SELECT 'urn:osa:localhost:rec:bulk' || g || '@1', + 'urn:osa:localhost:conv:test@1.0.0', 'compound', '1.0.0', + jsonb_build_object('type', 'deposition', 'id', 'bulk' || g), + jsonb_build_object('species', 'Homo sapiens', 'mw', g), + now() + FROM generate_series(1, :n) g + """ + ), + {"n": n}, + ) + await conn.execute( + text( + """ + INSERT INTO metadata.compound_v1 (record_srn, species, mw) + SELECT 'urn:osa:localhost:rec:bulk' || g || '@1', 'Homo sapiens', g + FROM generate_series(1, :n) g + """ + ), + {"n": n}, + ) + + +def _records_plan(limit: int = 1000) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + pagination=PaginationParams(limit=limit), + ) + + +@pytest.mark.asyncio +class TestStreamingGuarantees: + async def test_large_dump_does_not_materialize_in_python_heap( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + n = 20_000 + await _setup_schema(pg_engine, pg_session) + await _bulk_seed(pg_engine, n) + rs = PostgresTableReadStore(pg_session) + + tracemalloc.start() + count = 0 + # Drain WITHOUT accumulating — the engine must not hold all rows at once. 
+ async for _row in rs.stream_rows(_records_plan()): + count += 1 + _current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + assert count == n + # If the engine buffered all 20K flattened dicts, peak would be many MB. + # A streaming server-side cursor keeps the Python-heap peak small. + assert peak < 8 * 1024 * 1024, f"peak heap {peak} bytes — streaming may be buffering" + + async def test_early_generator_close_releases_server_cursor( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await _bulk_seed(pg_engine, 500) + rs = PostgresTableReadStore(pg_session) + + gen = rs.stream_rows(_records_plan()) + first = await gen.__anext__() # opens the server-side cursor + assert first["schema_id"] == "compound@1.0.0" + # Simulate client disconnect: closing the generator throws GeneratorExit + # into the body, which must run the finally that closes the cursor. + await gen.aclose() + + # If the cursor leaked (portal still open), a follow-up query on the same + # connection would raise. A clean full drain proves the connection was + # returned to a usable state. 
+ rows = [r async for r in rs.stream_rows(_records_plan())] + assert len(rows) == 500 diff --git a/server/tests/integration/test_discovery_compound_postgres.py b/server/tests/integration/test_discovery_compound_postgres.py deleted file mode 100644 index f752c0f..0000000 --- a/server/tests/integration/test_discovery_compound_postgres.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Integration tests for compound OR/NOT discovery filters against real PG.""" - -import pytest -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession - -from osa.domain.discovery.model.refs import MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - FilterOperator, - Not, - Or, - Predicate, - SortOrder, -) -from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType -from osa.domain.shared.model.srn import RecordSRN, SchemaId -from osa.infrastructure.persistence.adapter.discovery import PostgresDiscoveryReadStore -from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore - -from tests.integration.conftest import seed_record - -SCHEMA_V1 = SchemaId.parse("bio-sample@1.0.0") -FIELD_TYPES = {"species": FieldType.TEXT, "resolution": FieldType.NUMBER} - - -def _fields() -> list[FieldDefinition]: - return [ - FieldDefinition( - name="species", - type=FieldType.TEXT, - required=True, - cardinality=Cardinality.EXACTLY_ONE, - ), - FieldDefinition( - name="resolution", - type=FieldType.NUMBER, - required=False, - cardinality=Cardinality.EXACTLY_ONE, - ), - ] - - -@pytest.fixture -async def seeded_store(pg_engine: AsyncEngine, pg_session: AsyncSession) -> PostgresMetadataStore: - from datetime import UTC, datetime - - from osa.domain.semantics.model.schema import Schema - from osa.infrastructure.persistence.repository.schema import ( - PostgresSemanticsSchemaRepository, - ) - - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - - repo = PostgresSemanticsSchemaRepository(pg_session) - 
await repo.save( - Schema(id=SCHEMA_V1, title="bio_sample", fields=_fields(), created_at=datetime.now(UTC)) - ) - - rows = [ - ("rec-a1", "Homo sapiens", 3.5), - ("rec-b1", "Homo sapiens", 1.0), - ("rec-c1", "Mus musculus", 3.5), - ("rec-d1", "Drosophila", 0.5), - ] - for rid, sp, res in rows: - srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") - await seed_record( - pg_engine, - srn=str(srn), - schema_id=SCHEMA_V1.id.root, - schema_version=SCHEMA_V1.version.root, - ) - await store.insert(SCHEMA_V1, srn, {"species": sp, "resolution": res}) - - await pg_session.commit() - return store - - -def _pred(field: str, op: FilterOperator, value: object) -> Predicate: - return Predicate(field=MetadataFieldRef(field=field), op=op, value=value) - - -@pytest.mark.asyncio -class TestCompound: - async def test_or_tree(self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_store): - read_store = PostgresDiscoveryReadStore(pg_session) - # species = Homo sapiens OR resolution < 1.0 - tree = Or( - operands=[ - _pred("species", FilterOperator.EQ, "Homo sapiens"), - _pred("resolution", FilterOperator.LT, 1.0), - ] - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - # a, b (species match) + d (resolution 0.5) — not c - assert srns == { - "urn:osa:localhost:rec:rec-a1@1", - "urn:osa:localhost:rec:rec-b1@1", - "urn:osa:localhost:rec:rec-d1@1", - } - - async def test_not_tree(self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_store): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = Not(operand=_pred("species", FilterOperator.EQ, "Homo sapiens")) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - 
order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - assert srns == {"urn:osa:localhost:rec:rec-c1@1", "urn:osa:localhost:rec:rec-d1@1"} - - async def test_nested_and_or( - self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_store - ): - read_store = PostgresDiscoveryReadStore(pg_session) - # resolution >= 3.0 AND (species = Homo sapiens OR species = Mus musculus) - tree = And( - operands=[ - _pred("resolution", FilterOperator.GTE, 3.0), - Or( - operands=[ - _pred("species", FilterOperator.EQ, "Homo sapiens"), - _pred("species", FilterOperator.EQ, "Mus musculus"), - ] - ), - ] - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - assert srns == {"urn:osa:localhost:rec:rec-a1@1", "urn:osa:localhost:rec:rec-c1@1"} diff --git a/server/tests/integration/test_discovery_cross_join_postgres.py b/server/tests/integration/test_discovery_cross_join_postgres.py deleted file mode 100644 index c2ed9ec..0000000 --- a/server/tests/integration/test_discovery_cross_join_postgres.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Integration tests for cross-domain JOINs between records ⋈ features in discovery.""" - -import pytest -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession - -from osa.domain.discovery.model.refs import FeatureFieldRef, MetadataFieldRef -from osa.domain.discovery.model.value import And, FilterOperator, Not, Predicate, SortOrder -from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.hook import ColumnDef -from osa.domain.shared.model.srn import RecordSRN, SchemaId -from osa.infrastructure.persistence.adapter.discovery import 
PostgresDiscoveryReadStore -from osa.infrastructure.persistence.feature_store import PostgresFeatureStore -from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore - -from tests.integration.conftest import seed_record - -SCHEMA_V1 = SchemaId.parse("bio-sample@1.0.0") -FIELD_TYPES = {"species": FieldType.TEXT} - - -def _metadata_fields() -> list[FieldDefinition]: - return [ - FieldDefinition( - name="species", - type=FieldType.TEXT, - required=True, - cardinality=Cardinality.EXACTLY_ONE, - ), - ] - - -def _feature_columns() -> list[ColumnDef]: - return [ - ColumnDef(name="confidence", json_type="number", required=True), - ] - - -@pytest.fixture -async def seeded_both(pg_engine: AsyncEngine, pg_session: AsyncSession): - from datetime import UTC, datetime - - from osa.domain.semantics.model.schema import Schema - from osa.infrastructure.persistence.repository.schema import ( - PostgresSemanticsSchemaRepository, - ) - - mstore = PostgresMetadataStore(pg_engine, pg_session) - await mstore.ensure_table(SCHEMA_V1, _metadata_fields()) - - fstore = PostgresFeatureStore(pg_engine, pg_session) - await fstore.create_table("cell_classifier", _feature_columns()) - - repo = PostgresSemanticsSchemaRepository(pg_session) - await repo.save( - Schema( - id=SCHEMA_V1, - title="bio_sample", - fields=_metadata_fields(), - created_at=datetime.now(UTC), - ) - ) - - # r1: Homo sapiens + confidence 0.95 - # r2: Homo sapiens + confidence 0.5 - # r3: Mus musculus + confidence 0.95 - for rid, sp, conf in [ - ("rec-r1", "Homo sapiens", 0.95), - ("rec-r2", "Homo sapiens", 0.5), - ("rec-r3", "Mus musculus", 0.95), - ]: - srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") - await seed_record( - pg_engine, - srn=str(srn), - schema_id=SCHEMA_V1.id.root, - schema_version=SCHEMA_V1.version.root, - ) - await mstore.insert(SCHEMA_V1, srn, {"species": sp}) - await fstore.insert_features("cell_classifier", str(srn), [{"confidence": conf}]) - - await pg_session.commit() - return 
mstore, fstore - - -@pytest.mark.asyncio -class TestCrossDomainJoin: - async def test_joined_intersection( - self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_both - ): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = And( - operands=[ - Predicate( - field=MetadataFieldRef(field="species"), - op=FilterOperator.EQ, - value="Homo sapiens", - ), - Predicate( - field=FeatureFieldRef(hook="cell_classifier", column="confidence"), - op=FilterOperator.GT, - value=0.9, - ), - ] - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - assert srns == {"urn:osa:localhost:rec:rec-r1@1"} - - async def test_unknown_hook_raises( - self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_both - ): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = Predicate( - field=FeatureFieldRef(hook="does_not_exist", column="anything"), - op=FilterOperator.EQ, - value=1, - ) - with pytest.raises(ValidationError, match="Unknown feature hook"): - await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - - -@pytest.fixture -async def seeded_with_missing_feature_row(pg_engine: AsyncEngine, pg_session: AsyncSession): - """Seed a record with a metadata row but NO feature row, so the outer - join produces NULL feature columns for that record.""" - from datetime import UTC, datetime - - from osa.domain.semantics.model.schema import Schema - from osa.infrastructure.persistence.repository.schema import ( - PostgresSemanticsSchemaRepository, - ) - - mstore = PostgresMetadataStore(pg_engine, pg_session) - await mstore.ensure_table(SCHEMA_V1, _metadata_fields()) 
- - fstore = PostgresFeatureStore(pg_engine, pg_session) - await fstore.create_table("cell_classifier", _feature_columns()) - - repo = PostgresSemanticsSchemaRepository(pg_session) - await repo.save( - Schema( - id=SCHEMA_V1, - title="bio_sample", - fields=_metadata_fields(), - created_at=datetime.now(UTC), - ) - ) - - # rec-has-feature: has a feature row with confidence 0.95. - # rec-no-feature: no feature row at all (outer join will produce NULLs). - for rid, sp in [("rec-has-feature", "Homo sapiens"), ("rec-no-feature", "Mus musculus")]: - srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") - await seed_record( - pg_engine, - srn=str(srn), - schema_id=SCHEMA_V1.id.root, - schema_version=SCHEMA_V1.version.root, - ) - await mstore.insert(SCHEMA_V1, srn, {"species": sp}) - - has_feature_srn = "urn:osa:localhost:rec:rec-has-feature@1" - await fstore.insert_features("cell_classifier", has_feature_srn, [{"confidence": 0.95}]) - - await pg_session.commit() - - -@pytest.mark.asyncio -class TestOuterJoinNullHandling: - """Records without a feature row must not be silently dropped from NEQ/NOT - predicates on feature columns — the outer join produces NULL, and naive - SQL three-valued logic would exclude them.""" - - async def test_neq_on_feature_column_includes_missing_rows( - self, - pg_engine: AsyncEngine, - pg_session: AsyncSession, - seeded_with_missing_feature_row, - ): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = Predicate( - field=FeatureFieldRef(hook="cell_classifier", column="confidence"), - op=FilterOperator.NEQ, - value=0.95, - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - # rec-no-feature has no feature row → confidence is NULL → "!= 0.95" - # should include it. 
rec-has-feature has confidence 0.95 → excluded. - assert srns == {"urn:osa:localhost:rec:rec-no-feature@1"} - - async def test_not_on_feature_column_includes_missing_rows( - self, - pg_engine: AsyncEngine, - pg_session: AsyncSession, - seeded_with_missing_feature_row, - ): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = Not( - operand=Predicate( - field=FeatureFieldRef(hook="cell_classifier", column="confidence"), - op=FilterOperator.EQ, - value=0.95, - ) - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - # Same invariant as NEQ: NOT(confidence = 0.95) must surface the - # record with a missing feature row. - assert srns == {"urn:osa:localhost:rec:rec-no-feature@1"} diff --git a/server/tests/integration/test_discovery_records_typed_and.py b/server/tests/integration/test_discovery_records_typed_and.py deleted file mode 100644 index b9487c5..0000000 --- a/server/tests/integration/test_discovery_records_typed_and.py +++ /dev/null @@ -1,284 +0,0 @@ -"""Integration tests for /discovery/records with typed-table AND filters.""" - -import pytest -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession - -from osa.domain.discovery.model.refs import MetadataFieldRef -from osa.domain.discovery.model.value import And, FilterOperator, Predicate, SortOrder -from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType -from osa.domain.shared.model.srn import RecordSRN, SchemaId -from osa.infrastructure.persistence.adapter.discovery import ( - PostgresDiscoveryReadStore, - PostgresFieldDefinitionReader, -) -from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore - -from tests.integration.conftest import seed_record - -SCHEMA_V1 = SchemaId.parse("bio-sample@1.0.0") - - -def 
_fields() -> list[FieldDefinition]: - return [ - FieldDefinition( - name="species", - type=FieldType.TEXT, - required=True, - cardinality=Cardinality.EXACTLY_ONE, - ), - FieldDefinition( - name="resolution", - type=FieldType.NUMBER, - required=False, - cardinality=Cardinality.EXACTLY_ONE, - ), - FieldDefinition( - name="method", - type=FieldType.TEXT, - required=False, - cardinality=Cardinality.EXACTLY_ONE, - ), - ] - - -async def _seed_schema_row(session: AsyncSession) -> None: - """Seed the `schemas` row so the discovery field reader can resolve types.""" - from datetime import UTC, datetime - - from osa.domain.semantics.model.schema import Schema - from osa.infrastructure.persistence.repository.schema import ( - PostgresSemanticsSchemaRepository, - ) - - repo = PostgresSemanticsSchemaRepository(session) - await repo.save( - Schema(id=SCHEMA_V1, title="bio_sample", fields=_fields(), created_at=datetime.now(UTC)) - ) - - -async def _publish( - engine: AsyncEngine, - session: AsyncSession, - store: PostgresMetadataStore, - record_srn: RecordSRN, - species: str, - resolution: float, - method: str, -) -> None: - await seed_record( - engine, - srn=str(record_srn), - schema_id=SCHEMA_V1.id.root, - schema_version=SCHEMA_V1.version.root, - metadata={"species": species, "resolution": resolution, "method": method}, - ) - await store.insert( - SCHEMA_V1, - record_srn, - {"species": species, "resolution": resolution, "method": method}, - ) - - -@pytest.mark.asyncio -class TestDiscoveryTypedAnd: - async def test_and_filter_returns_matching_records( - self, pg_engine: AsyncEngine, pg_session: AsyncSession - ): - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - await _seed_schema_row(pg_session) - - rows = [ - ("rec-r1", "Homo sapiens", 3.5, "cryo-EM"), - ("rec-r2", "Homo sapiens", 1.8, "X-ray"), - ("rec-r3", "Mus musculus", 3.0, "cryo-EM"), - ] - for rid, sp, res, meth in rows: - await _publish( - pg_engine, - pg_session, 
- store, - RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1"), - sp, - res, - meth, - ) - await pg_session.commit() - - read_store = PostgresDiscoveryReadStore(pg_session) - tree = And( - operands=[ - Predicate( - field=MetadataFieldRef(field="species"), - op=FilterOperator.EQ, - value="Homo sapiens", - ), - Predicate( - field=MetadataFieldRef(field="resolution"), - op=FilterOperator.GTE, - value=2.0, - ), - ] - ) - - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types={ - "species": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": FieldType.TEXT, - }, - ) - - srns = {str(r.srn) for r in results} - assert srns == {"urn:osa:localhost:rec:rec-r1@1"} - - async def test_scalar_op_succeeds_on_unindexed_column( - self, pg_engine: AsyncEngine, pg_session: AsyncSession - ): - """FR-020: scalar ops must NOT be rejected for lack of index.""" - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - await _seed_schema_row(pg_session) - - await _publish( - pg_engine, - pg_session, - store, - RecordSRN.parse("urn:osa:localhost:rec:rec-ra@1"), - "Homo sapiens", - 3.5, - "cryo-EM", - ) - await pg_session.commit() - - read_store = PostgresDiscoveryReadStore(pg_session) - results = await read_store.search_records( - filter_expr=Predicate( - field=MetadataFieldRef(field="method"), - op=FilterOperator.CONTAINS, - value="cryo", - ), - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types={ - "species": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": FieldType.TEXT, - }, - ) - assert len(results) == 1 - - -@pytest.mark.asyncio -class TestUnscopedListing: - """Plain listings without a filter return canonical JSONB metadata. 
- Metadata-filtered queries require a pinned schema — the typed table is - the only filter path.""" - - async def test_unscoped_predicate_filter_raises_without_schema( - self, pg_engine: AsyncEngine, pg_session: AsyncSession - ): - """Filtering by a metadata field without schema_id must raise — - the JSONB fallback compile path was removed.""" - from osa.domain.discovery.model.refs import MetadataFieldRef - from osa.domain.discovery.model.value import FilterOperator, Predicate - from osa.domain.shared.error import ValidationError - - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - await _seed_schema_row(pg_session) - await _publish( - pg_engine, - pg_session, - store, - RecordSRN.parse("urn:osa:localhost:rec:rec-9x1w@1"), - "Homo sapiens", - 3.5, - "cryo-EM", - ) - await pg_session.commit() - - read_store = PostgresDiscoveryReadStore(pg_session) - with pytest.raises(ValidationError) as exc: - await read_store.search_records( - filter_expr=Predicate( - field=MetadataFieldRef(field="species"), - op=FilterOperator.EQ, - value="Homo sapiens", - ), - schema_id=None, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types={"species": FieldType.TEXT}, - ) - assert exc.value.code == "schema_required_for_metadata_query" - - async def test_unscoped_listing_returns_jsonb_metadata( - self, pg_engine: AsyncEngine, pg_session: AsyncSession - ): - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - await _seed_schema_row(pg_session) - - await _publish( - pg_engine, - pg_session, - store, - RecordSRN.parse("urn:osa:localhost:rec:rec-unscoped@1"), - "Homo sapiens", - 3.5, - "cryo-EM", - ) - await pg_session.commit() - - read_store = PostgresDiscoveryReadStore(pg_session) - results = await read_store.search_records( - filter_expr=None, - schema_id=None, # deliberately unscoped — exercises the 
JSONB path - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types={}, - ) - assert len(results) == 1 - assert results[0].metadata == { - "species": "Homo sapiens", - "resolution": 3.5, - "method": "cryo-EM", - } - - -@pytest.mark.asyncio -class TestFieldDefinitionReader: - async def test_get_fields_for_schema(self, pg_engine: AsyncEngine, pg_session: AsyncSession): - await _seed_schema_row(pg_session) - await pg_session.commit() - - reader = PostgresFieldDefinitionReader(pg_session) - fields = await reader.get_fields_for_schema(SCHEMA_V1) - assert fields["species"] == FieldType.TEXT - assert fields["resolution"] == FieldType.NUMBER diff --git a/server/tests/integration/test_event_batch_processing.py b/server/tests/integration/test_event_batch_processing.py deleted file mode 100644 index e1f6754..0000000 --- a/server/tests/integration/test_event_batch_processing.py +++ /dev/null @@ -1,229 +0,0 @@ -"""Integration tests for batch event processing flow. - -Tests the end-to-end flow from RecordPublished -> IndexRecord fan-out -> batch processing. 
-""" - -from typing import Any -from unittest.mock import AsyncMock, MagicMock -from uuid import uuid4 - -import pytest - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.handler.fanout_to_index_backends import FanOutToIndexBackends -from osa.domain.index.handler.vector_index_handler import VectorIndexHandler -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.record.event.record_published import RecordPublished -from osa.domain.shared.event import EventId -from osa.domain.shared.model.srn import ( - ConventionSRN, - DepositionSRN, - Domain, - LocalId, - RecordSRN, - RecordVersion, -) - - -class FakeBackend: - """Fake storage backend that tracks batch calls.""" - - def __init__(self, name: str): - self._name = name - self.batch_calls: list[list[tuple[str, dict]]] = [] - self.total_records: int = 0 - - @property - def name(self) -> str: - return self._name - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - """Track batch calls for verification.""" - self.batch_calls.append(list(records)) - self.total_records += len(records) - - -class FakeOutbox: - """Fake outbox that collects emitted events.""" - - def __init__(self): - self.events: list[Any] = [] - - async def append(self, event: Any) -> None: - self.events.append(event) - - -def make_record_published( - record_id: str | None = None, - metadata: dict | None = None, -) -> RecordPublished: - """Create a RecordPublished event for testing.""" - from osa.domain.shared.model.source import DepositionSource - - dep_srn = DepositionSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - ) - from osa.domain.shared.model.srn import SchemaId - - return RecordPublished( - id=EventId(uuid4()), - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(record_id or str(uuid4())), - version=RecordVersion(1), - ), - source=DepositionSource(id=str(dep_srn)), - 
convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0"), - schema_id=SchemaId.parse("test@1.0.0"), - metadata=metadata or {"title": "Test Record"}, - ) - - -class TestBatchEventProcessingFlow: - """Integration tests for the batch event processing flow.""" - - @pytest.mark.asyncio - async def test_fanout_creates_index_records_per_backend(self): - """FanOutToIndexBackends should create one IndexRecord per backend.""" - # Arrange - backends = { - "vector": FakeBackend("vector"), - "keyword": FakeBackend("keyword"), - } - registry = IndexRegistry(backends) - outbox = FakeOutbox() - - fanout = FanOutToIndexBackends(indexes=registry, outbox=outbox) - event = make_record_published() - - # Act - await fanout.handle(event) - - # Assert - assert len(outbox.events) == 2 - backend_names = {e.backend_name for e in outbox.events} - assert backend_names == {"vector", "keyword"} - - @pytest.mark.asyncio - async def test_handler_processes_batch(self): - """VectorIndexHandler should process events in batches.""" - # Arrange - vector_backend = FakeBackend("vector") - registry = IndexRegistry({"vector": vector_backend}) - handler = VectorIndexHandler(indexes=registry) - - # Create events for vector backend - events = [ - IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ), - metadata={"id": i}, - ) - for i in range(5) - ] - - # Act - await handler.handle_batch(events) - - # Assert - vector backend received 5 records in single batch call - assert len(vector_backend.batch_calls) == 1 - assert len(vector_backend.batch_calls[0]) == 5 - assert vector_backend.total_records == 5 - - @pytest.mark.asyncio - async def test_end_to_end_fanout_to_batch(self): - """End-to-end: RecordPublished -> FanOut -> BatchProcess.""" - # Arrange - vector_backend = FakeBackend("vector") - registry = IndexRegistry({"vector": vector_backend}) - outbox = FakeOutbox() - - 
fanout = FanOutToIndexBackends(indexes=registry, outbox=outbox) - handler = VectorIndexHandler(indexes=registry) - - # Create multiple RecordPublished events - num_records = 10 - published_events = [make_record_published() for _ in range(num_records)] - - # Act - Step 1: FanOut each RecordPublished - for event in published_events: - await fanout.handle(event) - - # Verify IndexRecord events were created - assert len(outbox.events) == num_records - assert all(isinstance(e, IndexRecord) for e in outbox.events) - assert all(e.backend_name == "vector" for e in outbox.events) - - # Act - Step 2: Batch process all IndexRecord events - await handler.handle_batch(outbox.events) - - # Assert - All records indexed in single batch call - assert len(vector_backend.batch_calls) == 1 - assert vector_backend.total_records == num_records - - @pytest.mark.asyncio - async def test_batch_efficiency_large_batch(self): - """Verify batch processing is efficient with large batches.""" - # Arrange - backend = FakeBackend("vector") - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Create 1000+ events - num_events = 1000 - events = [ - IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ), - metadata={"index": i, "title": f"Record {i}"}, - ) - for i in range(num_events) - ] - - # Act - await handler.handle_batch(events) - - # Assert - All records in single batch call (not 1000 individual calls) - assert len(backend.batch_calls) == 1, "Should use single batch call, not individual calls" - assert backend.total_records == num_events - assert len(backend.batch_calls[0]) == num_events - - @pytest.mark.asyncio - async def test_failure_propagates_from_handler(self): - """Verify failures propagate from handler (Worker handles retry).""" - # Arrange - VectorIndexHandler looks up "vector" backend specifically - 
vector_backend = MagicMock() - vector_backend.name = "vector" - vector_backend.ingest_batch = AsyncMock(side_effect=Exception("Backend failure")) - - registry = IndexRegistry({"vector": vector_backend}) - handler = VectorIndexHandler(indexes=registry) - - events = [ - IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ), - metadata={"id": i}, - ) - for i in range(2) - ] - - # Process events - should raise (Worker handles retry) - with pytest.raises(Exception, match="Backend failure"): - await handler.handle_batch(events) diff --git a/server/osa/domain/export/model/__init__.py b/server/tests/unit/application/api/v1/routes/__init__.py similarity index 100% rename from server/osa/domain/export/model/__init__.py rename to server/tests/unit/application/api/v1/routes/__init__.py diff --git a/server/osa/domain/export/port/__init__.py b/server/tests/unit/application/api/v1/routes/data/__init__.py similarity index 100% rename from server/osa/domain/export/port/__init__.py rename to server/tests/unit/application/api/v1/routes/data/__init__.py diff --git a/server/osa/domain/export/query/__init__.py b/server/tests/unit/application/api/v1/routes/data/serializers/__init__.py similarity index 100% rename from server/osa/domain/export/query/__init__.py rename to server/tests/unit/application/api/v1/routes/data/serializers/__init__.py diff --git a/server/tests/unit/application/api/v1/routes/data/serializers/test_csv.py b/server/tests/unit/application/api/v1/routes/data/serializers/test_csv.py new file mode 100644 index 0000000..f0bdb9d --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/data/serializers/test_csv.py @@ -0,0 +1,64 @@ +"""T031 — CsvSerializer header row, quoting, empty result.""" + +import csv +import io +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + +from 
osa.domain.data.model.manifest import ColumnSpec +from osa.application.api.v1.routes.data.serializers.csv import CsvSerializer +from osa.domain.semantics.model.value import FieldType + +COLUMNS = [ + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="name", type=FieldType.TEXT), +] + + +async def _aiter(rows: list[Mapping[str, Any]]) -> AsyncIterator[Mapping[str, Any]]: + for r in rows: + yield r + + +async def _collect(gen: AsyncIterator[bytes]) -> str: + out = b"" + async for chunk in gen: + out += chunk + return out.decode() + + +@pytest.mark.asyncio +async def test_csv_header_and_rows() -> None: + rows = [{"id": "a", "name": "alpha"}, {"id": "b", "name": "beta"}] + text = await _collect(CsvSerializer().stream(_aiter(rows), COLUMNS)) + parsed = list(csv.reader(io.StringIO(text))) + assert parsed[0] == ["id", "name"] + assert parsed[1] == ["a", "alpha"] + assert parsed[2] == ["b", "beta"] + + +@pytest.mark.asyncio +async def test_csv_empty_result_is_header_only() -> None: + text = await _collect(CsvSerializer().stream(_aiter([]), COLUMNS)) + parsed = list(csv.reader(io.StringIO(text))) + assert parsed == [["id", "name"]] + + +@pytest.mark.asyncio +async def test_csv_quotes_values_with_commas() -> None: + rows = [{"id": "a", "name": "beta, gamma"}] + text = await _collect(CsvSerializer().stream(_aiter(rows), COLUMNS)) + # QUOTE_MINIMAL wraps the comma-bearing value + assert '"beta, gamma"' in text + parsed = list(csv.reader(io.StringIO(text))) + assert parsed[1] == ["a", "beta, gamma"] + + +@pytest.mark.asyncio +async def test_csv_missing_column_is_empty() -> None: + rows = [{"id": "a"}] # no "name" + text = await _collect(CsvSerializer().stream(_aiter(rows), COLUMNS)) + parsed = list(csv.reader(io.StringIO(text))) + assert parsed[1] == ["a", ""] diff --git a/server/tests/unit/application/api/v1/routes/data/serializers/test_csv_gzip.py b/server/tests/unit/application/api/v1/routes/data/serializers/test_csv_gzip.py new file mode 100644 index 
0000000..1d81eee --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/data/serializers/test_csv_gzip.py @@ -0,0 +1,56 @@ +"""T032 — CsvGzipSerializer gzip-stream round-trip and magic bytes.""" + +import csv +import gzip +import io +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + +from osa.domain.data.model.manifest import ColumnSpec +from osa.application.api.v1.routes.data.serializers.csv_gzip import CsvGzipSerializer +from osa.domain.semantics.model.value import FieldType + +COLUMNS = [ + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="name", type=FieldType.TEXT), +] + + +async def _aiter(rows: list[Mapping[str, Any]]) -> AsyncIterator[Mapping[str, Any]]: + for r in rows: + yield r + + +async def _collect(gen: AsyncIterator[bytes]) -> bytes: + out = b"" + async for chunk in gen: + out += chunk + return out + + +@pytest.mark.asyncio +async def test_gzip_magic_bytes() -> None: + body = await _collect(CsvGzipSerializer().stream(_aiter([{"id": "a", "name": "x"}]), COLUMNS)) + # gzip stream header: 0x1f 0x8b 0x08 + assert body[:3] == b"\x1f\x8b\x08" + + +@pytest.mark.asyncio +async def test_gzip_roundtrip_through_gunzip() -> None: + rows = [{"id": "a", "name": "alpha"}, {"id": "b", "name": "beta"}] + body = await _collect(CsvGzipSerializer().stream(_aiter(rows), COLUMNS)) + text = gzip.decompress(body).decode() + parsed = list(csv.reader(io.StringIO(text))) + assert parsed[0] == ["id", "name"] + assert parsed[1] == ["a", "alpha"] + assert parsed[2] == ["b", "beta"] + + +@pytest.mark.asyncio +async def test_gzip_empty_result_roundtrips_to_header() -> None: + body = await _collect(CsvGzipSerializer().stream(_aiter([]), COLUMNS)) + text = gzip.decompress(body).decode() + parsed = list(csv.reader(io.StringIO(text))) + assert parsed == [["id", "name"]] diff --git a/server/tests/unit/application/api/v1/routes/data/serializers/test_json.py 
b/server/tests/unit/application/api/v1/routes/data/serializers/test_json.py new file mode 100644 index 0000000..084122e --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/data/serializers/test_json.py @@ -0,0 +1,53 @@ +"""T030 — JsonSerializer paginated envelope, empty result, cursor embedding.""" + +import json +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + +from osa.domain.data.model.manifest import ColumnSpec +from osa.application.api.v1.routes.data.serializers.json import JsonSerializer +from osa.domain.semantics.model.value import FieldType + +COLUMNS = [ + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="mw", type=FieldType.NUMBER), +] + + +async def _aiter(rows: list[Mapping[str, Any]]) -> AsyncIterator[Mapping[str, Any]]: + for r in rows: + yield r + + +async def _collect(gen: AsyncIterator[bytes]) -> bytes: + out = b"" + async for chunk in gen: + out += chunk + return out + + +@pytest.mark.asyncio +async def test_json_envelope_with_rows_and_cursor() -> None: + rows = [{"id": "a", "mw": 1.5}, {"id": "b", "mw": 2.0}] + body = await _collect(JsonSerializer().stream(_aiter(rows), COLUMNS, next_cursor="CURSOR")) + parsed = json.loads(body) + assert parsed["next_cursor"] == "CURSOR" + assert parsed["rows"] == [{"id": "a", "mw": 1.5}, {"id": "b", "mw": 2.0}] + + +@pytest.mark.asyncio +async def test_json_empty_result_is_valid() -> None: + body = await _collect(JsonSerializer().stream(_aiter([]), COLUMNS, next_cursor=None)) + parsed = json.loads(body) + assert parsed == {"rows": [], "next_cursor": None} + + +@pytest.mark.asyncio +async def test_json_projects_only_declared_columns() -> None: + rows = [{"id": "a", "mw": 1.5, "secret": "x"}] + body = await _collect(JsonSerializer().stream(_aiter(rows), COLUMNS)) + parsed = json.loads(body) + assert parsed["rows"] == [{"id": "a", "mw": 1.5}] + assert parsed["next_cursor"] is None diff --git 
a/server/tests/unit/application/api/v1/routes/data/test_params.py b/server/tests/unit/application/api/v1/routes/data/test_params.py new file mode 100644 index 0000000..16c5907 --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/data/test_params.py @@ -0,0 +1,32 @@ +"""Unit tests for table-route request parsing (sort spec). Plan construction +lives on the table-read handlers (tests/unit/domain/data/test_table_read_handlers.py).""" + +import pytest + +from osa.application.api.v1.routes.data._params import parse_sort +from osa.domain.data.model.query_plan import SortDirection +from osa.domain.shared.error import ValidationError + + +def test_parse_sort_empty_is_empty_list() -> None: + assert parse_sort(None) == [] + assert parse_sort("") == [] + + +def test_parse_sort_multi_with_directions() -> None: + specs = parse_sort("created_at:desc,id:desc") + assert [(s.column, s.direction) for s in specs] == [ + ("created_at", SortDirection.DESC), + ("id", SortDirection.DESC), + ] + + +def test_parse_sort_bare_column_defaults_asc() -> None: + specs = parse_sort("mw") + assert specs[0].column == "mw" + assert specs[0].direction == SortDirection.ASC + + +def test_parse_sort_invalid_direction_raises() -> None: + with pytest.raises(ValidationError): + parse_sort("mw:sideways") diff --git a/server/tests/unit/application/api/v1/routes/data/test_streaming.py b/server/tests/unit/application/api/v1/routes/data/test_streaming.py new file mode 100644 index 0000000..5d5c255 --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/data/test_streaming.py @@ -0,0 +1,180 @@ +"""Unit tests for build_table_response — pre-flight, streaming, pagination, cursor.""" + +import base64 +import csv +import gzip +import io +import json +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + +from osa.application.api.v1.routes.data._streaming import build_table_response +from osa.application.api.v1.routes.data.formats import FORMATS +from 
osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.model.query_plan import ( + QueryPlan, + SortDirection, + SortSpec, + TableKind, +) +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.model.srn import SchemaId + +SCHEMA = SchemaId.parse("compound@1.0.0") +COLUMNS = [ColumnSpec(name="id", type=FieldType.TEXT), ColumnSpec(name="srn", type=FieldType.TEXT)] +JSON_FMT = next(f for f in FORMATS if f.suffix == "") +CSV_FMT = next(f for f in FORMATS if f.suffix == "csv") +GZ_FMT = next(f for f in FORMATS if f.suffix == "csv.gz") + + +async def _aiter(rows: list[Mapping[str, Any]]) -> AsyncIterator[Mapping[str, Any]]: + for r in rows: + yield r + + +async def _raising() -> AsyncIterator[Mapping[str, Any]]: + raise ValueError("boom before first row") + yield # pragma: no cover + + +async def _body(resp) -> bytes: + out = b"" + async for chunk in resp.body_iterator: + out += chunk.encode() if isinstance(chunk, str) else chunk + return out + + +def _plan(limit: int = 50) -> QueryPlan: + return QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, pagination={"limit": limit}) + + +@pytest.mark.asyncio +async def test_streaming_preflight_raises_before_bytes() -> None: + # An error on the first __anext__ must propagate (→ 4xx), not corrupt a stream. 
+ with pytest.raises(ValueError, match="boom"): + await build_table_response(_raising(), CSV_FMT, COLUMNS, _plan()) + + +@pytest.mark.asyncio +async def test_streaming_csv_full_table() -> None: + rows = [{"id": "a", "srn": "s1"}, {"id": "b", "srn": "s2"}] + resp = await build_table_response(_aiter(rows), CSV_FMT, COLUMNS, _plan()) + parsed = list(csv.reader(io.StringIO((await _body(resp)).decode()))) + assert parsed[0] == ["id", "srn"] + assert parsed[1:] == [["a", "s1"], ["b", "s2"]] + + +@pytest.mark.asyncio +async def test_streaming_gzip_roundtrip() -> None: + rows = [{"id": "a", "srn": "s1"}] + resp = await build_table_response(_aiter(rows), GZ_FMT, COLUMNS, _plan()) + text = gzip.decompress(await _body(resp)).decode() + assert "id,srn" in text and "a,s1" in text + + +@pytest.mark.asyncio +async def test_streaming_empty_is_header_only() -> None: + resp = await build_table_response(_aiter([]), CSV_FMT, COLUMNS, _plan()) + parsed = list(csv.reader(io.StringIO((await _body(resp)).decode()))) + assert parsed == [["id", "srn"]] + + +@pytest.mark.asyncio +async def test_paginated_no_more_rows_null_cursor() -> None: + rows = [{"id": "a", "srn": "s1"}, {"id": "b", "srn": "s2"}] + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, _plan(limit=50)) + parsed = json.loads(await _body(resp)) + assert parsed["next_cursor"] is None + assert parsed["rows"] == [{"id": "a", "srn": "s1"}, {"id": "b", "srn": "s2"}] + + +@pytest.mark.asyncio +async def test_paginated_has_more_emits_cursor() -> None: + # limit=2, but 3 rows available → has_more, cursor derived from 2nd row's srn. 
+ rows = [ + {"id": "a", "srn": "s1", "created_at": "2026-01-03"}, + {"id": "b", "srn": "s2", "created_at": "2026-01-02"}, + {"id": "c", "srn": "s3", "created_at": "2026-01-01"}, + ] + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, _plan(limit=2)) + parsed = json.loads(await _body(resp)) + assert len(parsed["rows"]) == 2 + assert parsed["next_cursor"] is not None + decoded = json.loads(base64.urlsafe_b64decode(parsed["next_cursor"])) + # default RECORDS sort is created_at desc, tiebreaker srn + assert decoded["s"] == "2026-01-02" + assert decoded["id"] == "s2" + + +@pytest.mark.asyncio +async def test_paginated_records_id_sort_encodes_srn() -> None: + # ``sort=id`` on records aliases to the srn column in the store, so the + # cursor's sort value must be the srn too. Encoding the bare id would + # compare it against the srn column — every SRN sorts after a bare id, so + # the keyset matches all rows and pagination never advances. + rows = [ + {"id": "a", "srn": "urn:osa:localhost:rec:a@1"}, + {"id": "b", "srn": "urn:osa:localhost:rec:b@1"}, + {"id": "c", "srn": "urn:osa:localhost:rec:c@1"}, + ] + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + pagination={"limit": 2}, + sort=[SortSpec(column="id", direction=SortDirection.ASC)], + ) + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, plan) + parsed = json.loads(await _body(resp)) + decoded = json.loads(base64.urlsafe_b64decode(parsed["next_cursor"])) + assert decoded["s"] == "urn:osa:localhost:rec:b@1" + assert decoded["id"] == "urn:osa:localhost:rec:b@1" + + +@pytest.mark.asyncio +async def test_paginated_features_id_sort_encodes_row_id() -> None: + # Feature rows have a real integer id column (no srn key), so ``sort=id`` + # keeps encoding the row's own id. 
+ rows = [ + {"id": 1, "record_srn": "urn:osa:localhost:rec:a@1"}, + {"id": 2, "record_srn": "urn:osa:localhost:rec:a@1"}, + {"id": 3, "record_srn": "urn:osa:localhost:rec:a@1"}, + ] + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name="chem_features", + pagination={"limit": 2}, + # default FEATURE sort is id asc + ) + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, plan) + parsed = json.loads(await _body(resp)) + decoded = json.loads(base64.urlsafe_b64decode(parsed["next_cursor"])) + assert decoded["s"] == 2 + assert decoded["id"] == 2 + + +@pytest.mark.asyncio +async def test_paginated_feature_hook_column_named_srn_does_not_hijack_tiebreak() -> None: + # A hook may legally declare a data column named "srn" (it's not a feature + # auto column). The cursor tiebreaker must still be the integer row id — + # picking the hook column's string would 400 on the next page when it's + # coerced against the BIGINT id column. + rows = [ + {"id": 1, "record_srn": "urn:osa:localhost:rec:a@1", "srn": "hook-value-1"}, + {"id": 2, "record_srn": "urn:osa:localhost:rec:a@1", "srn": "hook-value-2"}, + {"id": 3, "record_srn": "urn:osa:localhost:rec:a@1", "srn": "hook-value-3"}, + ] + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name="chem_features", + pagination={"limit": 2}, + # default FEATURE sort is id asc + ) + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, plan) + parsed = json.loads(await _body(resp)) + decoded = json.loads(base64.urlsafe_b64decode(parsed["next_cursor"])) + assert decoded["s"] == 2 + assert decoded["id"] == 2 diff --git a/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py b/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py new file mode 100644 index 0000000..b7ef790 --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py @@ -0,0 +1,87 @@ +"""T033 — register_table_routes 
registers exactly 6 routes with stable op IDs.""" + +import pytest +from fastapi import APIRouter + +from osa.application.api.v1.routes.data.tables import ( + format_key, + path_for, + register_table_routes, +) +from osa.application.api.v1.routes.data.formats import FORMATS + + +def _noop_builder(_fmt): + async def endpoint() -> dict: # pragma: no cover - never called in this test + return {} + + return endpoint + + +def _routes(router: APIRouter): + return [r for r in router.routes if getattr(r, "operation_id", None)] + + +def test_registers_six_routes() -> None: + router = APIRouter() + register_table_routes(router, "/{schema}/records", _noop_builder, _noop_builder, "records") + assert len(_routes(router)) == 6 + + +def test_stable_operation_ids_for_records() -> None: + router = APIRouter() + register_table_routes(router, "/{schema}/records", _noop_builder, _noop_builder, "records") + ids = {r.operation_id for r in _routes(router)} + assert ids == { + "records_get_json", + "records_post_json", + "records_get_csv", + "records_post_csv", + "records_get_csv_gz", + "records_post_csv_gz", + } + + +def test_stable_operation_ids_for_feature() -> None: + router = APIRouter() + register_table_routes(router, "/{schema}/{feature}", _noop_builder, _noop_builder, "feature") + ids = {r.operation_id for r in _routes(router)} + assert ids == { + "feature_get_json", + "feature_post_json", + "feature_get_csv", + "feature_post_csv", + "feature_get_csv_gz", + "feature_post_csv_gz", + } + + +def test_duplicate_resource_on_same_router_raises() -> None: + # T114: a second registration with the same resource_name on the same router + # would mint colliding operation IDs — caught loudly, not silently shipped. 
+ router = APIRouter() + register_table_routes(router, "/{schema}/records", _noop_builder, _noop_builder, "records") + with pytest.raises(ValueError, match="Duplicate operation_id"): + register_table_routes(router, "/{schema}/other", _noop_builder, _noop_builder, "records") + + +def test_distinct_resources_on_same_router_coexist() -> None: + # records + feature share one router in the real app — no false positive. + router = APIRouter() + register_table_routes(router, "/{schema}/records", _noop_builder, _noop_builder, "records") + register_table_routes(router, "/{schema}/{feature}", _noop_builder, _noop_builder, "feature") + assert len(_routes(router)) == 12 + + +def test_paths_use_suffix() -> None: + json_fmt = next(f for f in FORMATS if f.suffix == "") + csv_fmt = next(f for f in FORMATS if f.suffix == "csv") + gz_fmt = next(f for f in FORMATS if f.suffix == "csv.gz") + assert path_for("/{schema}/records", json_fmt) == "/{schema}/records" + assert path_for("/{schema}/records", csv_fmt) == "/{schema}/records.csv" + assert path_for("/{schema}/records", gz_fmt) == "/{schema}/records.csv.gz" + + +def test_format_key_mapping() -> None: + keys = {format_key(f) for f in FORMATS} + assert keys == {"json", "csv", "csv_gz"} diff --git a/server/tests/unit/config/test_config.py b/server/tests/unit/config/test_config.py index e5fb1cb..de537b2 100644 --- a/server/tests/unit/config/test_config.py +++ b/server/tests/unit/config/test_config.py @@ -168,3 +168,14 @@ def test_empty_providers_default(self): """auth.providers defaults to empty OrcidConfig.""" cfg = config_from_yaml({}) assert cfg.auth.providers.orcid.client_id == "" + + +class TestDataConfig: + """The /data/ page-size ceiling is operator configuration, not a code constant.""" + + def test_max_page_limit_default(self): + assert config_from_yaml({}).data.max_page_limit == 1000 + + def test_max_page_limit_yaml_override(self): + cfg = config_from_yaml({"data": {"max_page_limit": 250}}) + assert cfg.data.max_page_limit == 
250 diff --git a/server/osa/domain/export/service/__init__.py b/server/tests/unit/domain/data/__init__.py similarity index 100% rename from server/osa/domain/export/service/__init__.py rename to server/tests/unit/domain/data/__init__.py diff --git a/server/tests/unit/domain/data/test_data_catalog_service.py b/server/tests/unit/domain/data/test_data_catalog_service.py new file mode 100644 index 0000000..65dcfc8 --- /dev/null +++ b/server/tests/unit/domain/data/test_data_catalog_service.py @@ -0,0 +1,115 @@ +"""DataCatalogService.resolve_table — one owner for "which columns does this table have". + +The records/feature route files previously each held a private helper that +re-derived manifest structure (records resource is named "records"; a feature +resource has kind FEATURE). That knowledge belongs to the catalog service. +""" + +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + +from osa.domain.data.model.catalog import NodeCatalog +from osa.domain.data.model.manifest import ( + ColumnSpec, + ResolvedTable, + SchemaManifest, + TableResource, +) +from osa.domain.data.model.query_plan import QueryPlan, TableKind +from osa.domain.data.model.record_summary import RecordSummary +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.srn import SchemaId + +RECORDS_COLUMNS = [ + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="srn", type=FieldType.TEXT), +] +FEATURE_COLUMNS = [ + ColumnSpec(name="id", type=FieldType.NUMBER), + ColumnSpec(name="mw", type=FieldType.NUMBER), +] + +MANIFEST = SchemaManifest( + id="compound", + version="1.0.0", + srn="urn:osa:localhost:schema:compound@1.0.0", + fields=[], + table_resources=[ + TableResource( + name="records", + kind=TableKind.RECORDS, + columns=RECORDS_COLUMNS, + 
row_count=2, + formats=["", "csv", "csv.gz"], + ), + TableResource( + name="chem_features", + kind=TableKind.FEATURE, + columns=FEATURE_COLUMNS, + row_count=2, + formats=["", "csv", "csv.gz"], + ), + ], +) + + +class FakeReadStore: + """Minimal DataReadStore fake for catalog reads.""" + + def __init__(self, manifest: SchemaManifest | None = MANIFEST) -> None: + self.manifest = manifest + + def stream_rows(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + raise NotImplementedError + + async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary | None: + return None + + async def get_node_catalog(self) -> NodeCatalog: + raise NotImplementedError + + async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest | None: + return self.manifest + + async def get_latest_schema_id(self, schema_short_id: str) -> SchemaId | None: + return SchemaId.parse(f"{schema_short_id}@1.0.0") + + +def _service(store: FakeReadStore | None = None) -> DataCatalogService: + return DataCatalogService(read_store=store or FakeReadStore()) + + +@pytest.mark.asyncio +async def test_resolve_table_records() -> None: + resolved = await _service().resolve_table("compound@1.0.0", TableKind.RECORDS) + assert isinstance(resolved, ResolvedTable) + assert resolved.schema_id == SchemaId.parse("compound@1.0.0") + assert resolved.columns == RECORDS_COLUMNS + + +@pytest.mark.asyncio +async def test_resolve_table_feature() -> None: + resolved = await _service().resolve_table( + "compound@1.0.0", TableKind.FEATURE, feature_name="chem_features" + ) + assert resolved.columns == FEATURE_COLUMNS + + +@pytest.mark.asyncio +async def test_resolve_table_unknown_feature_404s() -> None: + with pytest.raises(NotFoundError) as exc: + await _service().resolve_table("compound@1.0.0", TableKind.FEATURE, feature_name="nope") + assert exc.value.code == "table_not_found" + + +@pytest.mark.asyncio +async def test_resolve_table_kind_must_match() -> None: + # A resource named 
"records" exists, but with kind RECORDS — asking for a + # FEATURE of that name must not match it. + with pytest.raises(NotFoundError): + await _service().resolve_table("compound@1.0.0", TableKind.FEATURE, feature_name="records") diff --git a/server/tests/unit/domain/data/test_data_query_service.py b/server/tests/unit/domain/data/test_data_query_service.py new file mode 100644 index 0000000..5966285 --- /dev/null +++ b/server/tests/unit/domain/data/test_data_query_service.py @@ -0,0 +1,143 @@ +"""Unit tests for DataQueryService — filter-bounds validation + delegation. + +Uses a fake DataReadStore (no DB), mirroring the discovery service tests. +""" + +from collections.abc import AsyncIterator, Mapping +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any + +import pytest + +from osa.domain.data.model.filter import And, FilterOperator, Predicate +from osa.domain.data.model.query_plan import QueryPlan, TableKind +from osa.domain.data.service.data_query import DataQueryService +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.srn import SchemaId + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +@dataclass +class FakeDataConfig: + """The /data/ filter-bound knobs DataQueryService reads (config.data.*).""" + + max_filter_depth: int = 10 + max_predicates: int = 200 + max_feature_joins: int = 10 + + +@dataclass +class FakeConfig: + """Minimal stand-in exposing only ``config.data`` (avoids full env setup).""" + + data: FakeDataConfig = field(default_factory=FakeDataConfig) + + +class FakeReadStore: + def __init__(self, rows: list[Mapping[str, Any]]) -> None: + self._rows = rows + self.received_plan: QueryPlan | None = None + self.received_timeout: timedelta | None = None + + async def stream_rows( + self, plan: QueryPlan, timeout: timedelta | None = None + ) -> AsyncIterator[Mapping[str, Any]]: + self.received_plan = plan + self.received_timeout = timeout + for row in self._rows: + yield row + + 
async def get_record_by_id(self, id, version): # pragma: no cover + return None + + async def get_node_catalog(self): # pragma: no cover + ... + + async def get_schema_manifest(self, schema_id): # pragma: no cover + return None + + async def get_latest_schema_id(self, schema_short_id): # pragma: no cover + return None + + +def _service(rows=None) -> tuple[DataQueryService, FakeReadStore]: + store = FakeReadStore(rows or []) + return DataQueryService(read_store=store, config=FakeConfig()), store + + +def _pred(field: str = "metadata.mw") -> Predicate: + return Predicate(field=field, op=FilterOperator.EQ, value=1) + + +async def _drain(gen: AsyncIterator[Mapping[str, Any]]) -> list[Mapping[str, Any]]: + return [row async for row in gen] + + +@pytest.mark.asyncio +async def test_stream_records_delegates_to_store() -> None: + service, store = _service([{"id": "a"}, {"id": "b"}]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + rows = await _drain(service.stream_records(plan)) + assert rows == [{"id": "a"}, {"id": "b"}] + assert store.received_plan is plan + + +@pytest.mark.asyncio +async def test_stream_records_rejects_feature_plan() -> None: + service, _ = _service() + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="f") + with pytest.raises(ValidationError): + await _drain(service.stream_records(plan)) + + +@pytest.mark.asyncio +async def test_filter_depth_bound_enforced() -> None: + service, _ = _service() + service.config.data.max_filter_depth = 1 + deep = And(operands=[_pred(), And(operands=[_pred(), _pred()])]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, filter=deep) + with pytest.raises(ValidationError) as exc: + await _drain(service.stream_records(plan)) + assert exc.value.code == "filter_depth_exceeded" + + +@pytest.mark.asyncio +async def test_filter_predicate_count_bound_enforced() -> None: + service, _ = _service() + service.config.data.max_predicates = 1 + plan = QueryPlan( + 
schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + filter=And(operands=[_pred(), _pred()]), + ) + with pytest.raises(ValidationError) as exc: + await _drain(service.stream_records(plan)) + assert exc.value.code == "filter_predicates_exceeded" + + +@pytest.mark.asyncio +async def test_no_filter_streams_cleanly() -> None: + service, _ = _service([{"id": "x"}]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + assert await _drain(service.stream_records(plan)) == [{"id": "x"}] + + +@pytest.mark.asyncio +async def test_stream_records_passes_timeout_to_store() -> None: + # The statement timeout is an execution budget owned by the caller (the + # route's format registry); the service forwards it to the port so the + # adapter — the only layer holding a SQL session — can apply it. + service, store = _service([{"id": "a"}]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + await _drain(service.stream_records(plan, timeout=timedelta(seconds=30))) + assert store.received_timeout == timedelta(seconds=30) + + +@pytest.mark.asyncio +async def test_stream_features_passes_timeout_to_store() -> None: + service, store = _service([{"id": 1}]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="f") + await _drain(service.stream_features(plan, timeout=timedelta(minutes=30))) + assert store.received_timeout == timedelta(minutes=30) diff --git a/server/tests/unit/domain/data/test_query_plan.py b/server/tests/unit/domain/data/test_query_plan.py new file mode 100644 index 0000000..963abf8 --- /dev/null +++ b/server/tests/unit/domain/data/test_query_plan.py @@ -0,0 +1,144 @@ +"""T029 — QueryPlan validation rules and per-kind default sorts.""" + +from datetime import UTC, datetime + +import pytest +from pydantic import ValidationError as PydanticValidationError + +from osa.domain.data.model.query_plan import ( + PaginationCursor, + PaginationParams, + QueryPlan, + SortDirection, + SortSpec, + TableKind, + decode_cursor, + 
encode_cursor, +) +from osa.domain.shared.model.srn import SchemaId + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def test_records_plan_defaults_to_created_at_id_desc() -> None: + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + assert [(s.column, s.direction) for s in plan.sort] == [ + ("created_at", SortDirection.DESC), + ("id", SortDirection.DESC), + ] + + +def test_feature_plan_defaults_to_id_asc() -> None: + plan = QueryPlan( + schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="chemical_features" + ) + assert [(s.column, s.direction) for s in plan.sort] == [("id", SortDirection.ASC)] + + +def test_feature_kind_requires_feature_name() -> None: + with pytest.raises(PydanticValidationError): + QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE) + + +def test_records_kind_rejects_feature_name() -> None: + with pytest.raises(PydanticValidationError): + QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, feature_name="x") + + +def test_limit_above_max_clamps_to_max() -> None: + # Clamp, don't reject: a consumer asking for "everything" with a big + # number gets the max page, not a 422. The clamp *policy* lives HERE, on + # the canonical model; the *bound* comes from config (DataConfig). 
+ assert PaginationParams.clamped(limit=5000, max_limit=1000).limit == 1000 + assert PaginationParams.clamped(limit=5000, max_limit=200).limit == 200 + + +def test_limit_below_one_clamps_to_one() -> None: + assert PaginationParams.clamped(limit=0, max_limit=1000).limit == 1 + assert PaginationParams.clamped(limit=-5, max_limit=1000).limit == 1 + + +def test_clamped_preserves_cursor() -> None: + p = PaginationParams.clamped(cursor=PaginationCursor(value="CUR"), limit=10, max_limit=1000) + assert str(p.cursor) == "CUR" + assert p.limit == 10 + + +def test_limit_default_is_50() -> None: + assert PaginationParams().limit == 50 + + +def test_cursor_roundtrip() -> None: + enc = encode_cursor("2026-01-01T00:00:00", "abc") + decoded = decode_cursor(enc) + assert decoded == {"s": "2026-01-01T00:00:00", "id": "abc"} + + +def test_decode_malformed_cursor_raises() -> None: + with pytest.raises(ValueError): + decode_cursor("not-valid-base64-json!!!") + + +def test_encode_cursor_accepts_datetime_sort_value() -> None: + # Feature rows reach the cursor encoder as raw DB mappings, so a + # created_at sort value is a datetime, not a pre-rendered string. 
+ ts = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + decoded = decode_cursor(encode_cursor(ts, 7)) + assert datetime.fromisoformat(decoded["s"]) == ts + assert decoded["id"] == 7 + + +# --------------------------------------------------------------------------- # +# Keyset — single owner of the pagination contract (tiebreak + sort aliasing) +# --------------------------------------------------------------------------- # + + +def test_records_keyset_defaults() -> None: + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + assert plan.keyset.tiebreak_column == "srn" + assert plan.keyset.sort_column == "created_at" + + +def test_records_keyset_id_sort_aliases_to_tiebreak() -> None: + # ``sort=id`` on records sorts by the srn column (the PK; a bare id is not + # unique across versions), so the keyset's sort column IS the tiebreaker. + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + sort=[SortSpec(column="id", direction=SortDirection.ASC)], + ) + assert plan.keyset.sort_column == "srn" + assert plan.keyset.tiebreak_column == "srn" + + +def test_feature_keyset_defaults() -> None: + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="f") + assert plan.keyset.tiebreak_column == "id" + assert plan.keyset.sort_column == "id" + + +def test_keyset_cursor_from_row_records_default_sort() -> None: + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + cursor = plan.keyset.cursor_from_row({"id": "a", "srn": "s1", "created_at": "2026-01-02"}) + assert decode_cursor(cursor) == {"s": "2026-01-02", "id": "s1"} + + +def test_keyset_cursor_from_row_records_id_sort_encodes_srn() -> None: + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + sort=[SortSpec(column="id", direction=SortDirection.ASC)], + ) + cursor = plan.keyset.cursor_from_row({"id": "a", "srn": "urn:osa:localhost:rec:a@1"}) + assert decode_cursor(cursor) == { + "s": "urn:osa:localhost:rec:a@1", + "id": 
"urn:osa:localhost:rec:a@1", + } + + +def test_keyset_cursor_from_row_feature_hook_column_named_srn_is_inert() -> None: + # A hook may declare a data column named "srn"; the feature tiebreaker is + # still the integer row id. + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="f") + cursor = plan.keyset.cursor_from_row({"id": 2, "srn": "hook-value-2"}) + assert decode_cursor(cursor) == {"s": 2, "id": 2} diff --git a/server/tests/unit/domain/data/test_table_read_handlers.py b/server/tests/unit/domain/data/test_table_read_handlers.py new file mode 100644 index 0000000..9a9da74 --- /dev/null +++ b/server/tests/unit/domain/data/test_table_read_handlers.py @@ -0,0 +1,124 @@ +"""Data-domain query handlers — the route's single entry point per request. + +Restores the documented Router → QueryHandler → Service layering on the /data/ +surface (and with it the __auth__ seam the deferred private-schema auth model +will need). Handlers own plan construction: table resolution, limit clamping +against config, and default sorts — routes only parse HTTP. 
+""" + +from collections.abc import AsyncIterator, Mapping +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any + +import pytest + +from osa.domain.data.model.manifest import ColumnSpec, ResolvedTable +from osa.domain.data.model.query_plan import ( + QueryPlan, + SortDirection, + SortSpec, + TableKind, +) +from osa.domain.data.query.read_table import ( + ReadFeatureTable, + ReadFeatureTableHandler, + ReadRecordsTable, + ReadRecordsTableHandler, +) +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.model.srn import SchemaId + +SCHEMA_ID = SchemaId.parse("compound@1.0.0") +COLUMNS = [ColumnSpec(name="id", type=FieldType.TEXT)] + + +@dataclass +class FakeDataConfig: + max_filter_depth: int = 10 + max_predicates: int = 200 + max_feature_joins: int = 10 + max_page_limit: int = 1000 + + +@dataclass +class FakeConfig: + data: FakeDataConfig = field(default_factory=FakeDataConfig) + + +class FakeCatalogService: + def __init__(self) -> None: + self.resolved_with: tuple[Any, ...] 
| None = None + + async def resolve_table(self, schema, table_kind, feature_name=None) -> ResolvedTable: + self.resolved_with = (schema, table_kind, feature_name) + return ResolvedTable(schema_id=SCHEMA_ID, columns=COLUMNS) + + +class FakeQueryService: + def __init__(self) -> None: + self.received: tuple[QueryPlan, timedelta | None] | None = None + + async def _rows(self) -> AsyncIterator[Mapping[str, Any]]: + yield {"id": "a"} + + def stream_records(self, plan: QueryPlan, timeout: timedelta | None = None): + self.received = (plan, timeout) + return self._rows() + + def stream_features(self, plan: QueryPlan, timeout: timedelta | None = None): + self.received = (plan, timeout) + return self._rows() + + +def _records_handler() -> tuple[ReadRecordsTableHandler, FakeCatalogService, FakeQueryService]: + catalog, query = FakeCatalogService(), FakeQueryService() + handler = ReadRecordsTableHandler( + catalog_service=catalog, query_service=query, config=FakeConfig() + ) + return handler, catalog, query + + +@pytest.mark.asyncio +async def test_records_handler_resolves_and_streams() -> None: + handler, catalog, query = _records_handler() + result = await handler.run(ReadRecordsTable(schema="compound@1.0.0")) + + assert catalog.resolved_with == ("compound@1.0.0", TableKind.RECORDS, None) + assert result.columns == COLUMNS + assert result.plan.schema_id == SCHEMA_ID + assert result.plan.table_kind == TableKind.RECORDS + assert [r async for r in result.rows] == [{"id": "a"}] + + +@pytest.mark.asyncio +async def test_records_handler_clamps_limit_to_config() -> None: + handler, _, _ = _records_handler() + handler.config.data.max_page_limit = 100 + result = await handler.run(ReadRecordsTable(schema="compound@1.0.0", limit=5000)) + assert result.plan.pagination.limit == 100 + + +@pytest.mark.asyncio +async def test_records_handler_forwards_timeout_and_sort() -> None: + handler, _, query = _records_handler() + sort = [SortSpec(column="mw", direction=SortDirection.DESC)] + result = 
await handler.run( + ReadRecordsTable(schema="compound@1.0.0", sort=sort, timeout=timedelta(seconds=30)) + ) + plan, timeout = query.received + assert plan is result.plan + assert timeout == timedelta(seconds=30) + assert plan.sort[0].column == "mw" + + +@pytest.mark.asyncio +async def test_feature_handler_resolves_feature_and_builds_feature_plan() -> None: + catalog, query = FakeCatalogService(), FakeQueryService() + handler = ReadFeatureTableHandler( + catalog_service=catalog, query_service=query, config=FakeConfig() + ) + result = await handler.run(ReadFeatureTable(schema="compound@1.0.0", feature="chem_features")) + assert catalog.resolved_with == ("compound@1.0.0", TableKind.FEATURE, "chem_features") + assert result.plan.table_kind == TableKind.FEATURE + assert result.plan.feature_name == "chem_features" diff --git a/server/tests/unit/domain/deposition/test_hook_reserved_name.py b/server/tests/unit/domain/deposition/test_hook_reserved_name.py new file mode 100644 index 0000000..b20650c --- /dev/null +++ b/server/tests/unit/domain/deposition/test_hook_reserved_name.py @@ -0,0 +1,36 @@ +"""T065 — HookDefinition rejects reserved hook names (records, datasets).""" + +import pytest + +from osa.domain.shared.error import ReservedNameError +from osa.domain.shared.model.hook import ( + ColumnDef, + HookDefinition, + OciConfig, + TableFeatureSpec, +) + + +def _hook(name: str) -> HookDefinition: + return HookDefinition( + name=name, + runtime=OciConfig(image="ghcr.io/example/hook", digest="sha256:abc123"), + feature=TableFeatureSpec( + cardinality="one", + columns=[ColumnDef(name="score", json_type="number", required=True)], + ), + ) + + +@pytest.mark.parametrize("reserved", ["records", "datasets"]) +def test_hook_rejects_reserved_name(reserved: str) -> None: + with pytest.raises(ReservedNameError) as exc: + _hook(reserved) + assert exc.value.code == "reserved_name" + assert exc.value.kind == "hook" + assert exc.value.name == reserved + + +def 
test_hook_allows_non_reserved_name() -> None: + hook = _hook("chemical_features") + assert hook.name == "chemical_features" diff --git a/server/tests/unit/domain/discovery/__init__.py b/server/tests/unit/domain/discovery/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/tests/unit/domain/discovery/test_discovery_service.py b/server/tests/unit/domain/discovery/test_discovery_service.py deleted file mode 100644 index b68dbba..0000000 --- a/server/tests/unit/domain/discovery/test_discovery_service.py +++ /dev/null @@ -1,476 +0,0 @@ -"""Tests for DiscoveryService — FilterExpr validation, operator validation, delegation.""" - -from datetime import UTC, datetime -from unittest.mock import AsyncMock - -import pytest - -from osa.config import Config -from osa.domain.discovery.model.refs import MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - ColumnInfo, - FeatureCatalogEntry, - FeatureRow, - FilterOperator, - Predicate, - RecordSummary, - SortOrder, - decode_cursor, - encode_cursor, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import NotFoundError, ValidationError -from osa.domain.shared.model.srn import RecordSRN, SchemaId - - -SCHEMA_SRN = SchemaId.parse("bio-sample@1.0.0") - - -def _config() -> Config: - # Build a Config with minimal auth — tests don't hit JWT paths - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) # Test-only secret - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - return Config() # type: ignore[call-arg] - - -@pytest.fixture -def mock_read_store() -> AsyncMock: - store = AsyncMock() - store.search_records.return_value = [] - return store - - -@pytest.fixture -def mock_field_reader() -> AsyncMock: - reader = AsyncMock() - reader.get_all_field_types.return_value = { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": 
FieldType.TERM, - "published_date": FieldType.DATE, - "is_public": FieldType.BOOLEAN, - "homepage": FieldType.URL, - } - reader.get_fields_for_schema.return_value = { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": FieldType.TERM, - "published_date": FieldType.DATE, - "is_public": FieldType.BOOLEAN, - "homepage": FieldType.URL, - } - return reader - - -@pytest.fixture -def service(mock_read_store: AsyncMock, mock_field_reader: AsyncMock) -> DiscoveryService: - return DiscoveryService( - read_store=mock_read_store, - field_reader=mock_field_reader, - config=_config(), - ) - - -def _eq(field: str, value: object) -> Predicate: - return Predicate(field=MetadataFieldRef(field=field), op=FilterOperator.EQ, value=value) - - -class TestSearchRecordsValidation: - async def test_rejects_unknown_filter_field(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="Unknown metadata field 'bogus'"): - await service.search_records( - filter_expr=_eq("bogus", "x"), - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_rejects_invalid_operator_for_type(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="not valid"): - await service.search_records( - filter_expr=Predicate( - field=MetadataFieldRef(field="resolution"), - op=FilterOperator.CONTAINS, - value="x", - ), - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_rejects_unknown_sort_field(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="Unknown sort field"): - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="nonexistent", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_accepts_published_at_sort(self, service: 
DiscoveryService) -> None: - result = await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - assert result.results == [] - - async def test_accepts_metadata_field_sort(self, service: DiscoveryService) -> None: - result = await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="resolution", - order=SortOrder.ASC, - cursor=None, - limit=20, - ) - assert result.results == [] - - async def test_rejects_limit_too_low(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="limit"): - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=0, - ) - - async def test_rejects_limit_too_high(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="limit"): - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=101, - ) - - async def test_raises_not_found_for_unknown_schema(self, mock_read_store: AsyncMock) -> None: - """Pinning an unregistered schema must raise NotFoundError, not silently - fall through to an unscoped query that returns cross-schema records.""" - empty_reader = AsyncMock() - empty_reader.get_fields_for_schema.return_value = {} - svc = DiscoveryService( - read_store=mock_read_store, - field_reader=empty_reader, - config=_config(), - ) - - with pytest.raises(NotFoundError, match="Schema not found"): - await svc.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - mock_read_store.search_records.assert_not_called() - - async def 
test_search_features_raises_not_found_for_unknown_schema( - self, mock_read_store: AsyncMock - ) -> None: - """search_features must also guard against unknown schema pins.""" - mock_read_store.get_feature_table_schema.return_value = FeatureCatalogEntry( - hook_name="detect_pockets", - columns=[ColumnInfo(name="score", type="number", required=False)], - record_count=0, - ) - empty_reader = AsyncMock() - empty_reader.get_fields_for_schema.return_value = {} - svc = DiscoveryService( - read_store=mock_read_store, - field_reader=empty_reader, - config=_config(), - ) - - with pytest.raises(NotFoundError, match="Schema not found"): - await svc.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=SCHEMA_SRN, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - mock_read_store.search_features.assert_not_called() - - async def test_rejects_q_when_no_text_fields(self, mock_read_store: AsyncMock) -> None: - no_text_reader = AsyncMock() - no_text_reader.get_all_field_types.return_value = {"resolution": FieldType.NUMBER} - no_text_reader.get_fields_for_schema.return_value = {"resolution": FieldType.NUMBER} - svc = DiscoveryService( - read_store=mock_read_store, - field_reader=no_text_reader, - config=_config(), - ) - - with pytest.raises(ValidationError, match="Free-text search is unavailable"): - await svc.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q="kinase", - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - -class TestSearchRecordsDelegation: - async def test_delegates_to_read_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - await service.search_records( - filter_expr=_eq("method", "X-ray"), - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - mock_read_store.search_records.assert_called_once() - call_kwargs = 
mock_read_store.search_records.call_args - assert call_kwargs.kwargs["filter_expr"] is not None - assert call_kwargs.kwargs["q"] is None - assert call_kwargs.kwargs["sort"] == "published_at" - assert call_kwargs.kwargs["limit"] == 21 # N+1 - - async def test_decodes_cursor( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - cursor = encode_cursor("2026-01-01", "urn:osa:localhost:rec:abc@1") - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=cursor, - limit=20, - ) - - call_kwargs = mock_read_store.search_records.call_args - decoded = call_kwargs.kwargs["cursor"] - assert decoded["s"] == "2026-01-01" - assert decoded["id"] == "urn:osa:localhost:rec:abc@1" - - async def test_invalid_cursor_raises(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="cursor"): - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor="not-a-cursor!!!", - limit=20, - ) - - async def test_encodes_next_cursor_from_results( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - ts = datetime(2026, 1, 1, tzinfo=UTC) - mock_read_store.search_records.return_value = [ - RecordSummary(srn=srn, published_at=ts, metadata={"title": f"r{i}"}) for i in range(2) - ] - - result = await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=1, - ) - - assert result.has_more is True - assert result.cursor is not None - assert len(result.results) == 1 - decoded = decode_cursor(result.cursor) - assert decoded["id"] == str(srn) - - async def test_no_cursor_when_no_more_results( - self, service: DiscoveryService, mock_read_store: AsyncMock - 
) -> None: - mock_read_store.search_records.return_value = [] - - result = await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - assert result.cursor is None - assert result.has_more is False - - -class TestSchemaRequiredGuards: - """With the JSONB filter fallback removed, any query that resolves against - metadata fields must pin a schema.""" - - async def test_metadata_predicate_without_schema_raises( - self, service: DiscoveryService - ) -> None: - with pytest.raises(ValidationError) as exc: - await service.search_records( - filter_expr=_eq("title", "x"), - schema_id=None, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - assert exc.value.code == "schema_required_for_metadata_query" - - async def test_non_default_sort_without_schema_raises(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError) as exc: - await service.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - q=None, - sort="resolution", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - assert exc.value.code == "schema_required_for_metadata_sort" - - async def test_q_without_schema_raises(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError) as exc: - await service.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - q="kinase", - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - assert exc.value.code == "schema_required_for_free_text_search" - - async def test_plain_listing_without_schema_succeeds(self, service: DiscoveryService) -> None: - """No filter, default sort, no q → unscoped listing is allowed.""" - result = await service.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - 
cursor=None, - limit=20, - ) - assert result.results == [] - - -class TestFilterBounds: - async def test_depth_exceeded_raises(self, service: DiscoveryService) -> None: - # Build a nest of AND that exceeds the default depth (10) - leaf = _eq("title", "r") - tree = leaf - for _ in range(11): - tree = And(operands=[tree, leaf]) - - with pytest.raises(ValidationError, match="filter_depth_exceeded|depth"): - await service.search_records( - filter_expr=tree, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - -class TestFeatureCursorEncoding: - async def test_cursor_encodes_row_id( - self, mock_read_store: AsyncMock, mock_field_reader: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - mock_read_store.get_feature_table_schema.return_value = FeatureCatalogEntry( - hook_name="detect_pockets", - columns=[ColumnInfo(name="score", type="number", required=True)], - record_count=0, - ) - mock_read_store.search_features.return_value = [ - FeatureRow(row_id=42, record_srn=srn, data={"score": 7.66}), - FeatureRow(row_id=43, record_srn=srn, data={"score": 6.0}), - ] - - service = DiscoveryService( - read_store=mock_read_store, - field_reader=mock_field_reader, - config=_config(), - ) - result = await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=1, - ) - - assert result.has_more is True - assert result.cursor is not None - assert len(result.rows) == 1 - decoded = decode_cursor(result.cursor) - assert decoded["id"] == 42 - assert decoded["s"] == 7.66 diff --git a/server/tests/unit/domain/discovery/test_get_feature_catalog.py b/server/tests/unit/domain/discovery/test_get_feature_catalog.py deleted file mode 100644 index f6197e3..0000000 --- a/server/tests/unit/domain/discovery/test_get_feature_catalog.py +++ /dev/null @@ -1,114 +0,0 @@ 
-"""Tests for GetFeatureCatalogHandler and DiscoveryService.get_feature_catalog().""" - -from unittest.mock import AsyncMock - -import pytest - -from osa.domain.discovery.model.value import ColumnInfo, FeatureCatalogEntry -from osa.domain.discovery.query.get_feature_catalog import ( - GetFeatureCatalog, - GetFeatureCatalogHandler, - GetFeatureCatalogResult, -) -from osa.config import Config -from osa.domain.discovery.service.discovery import DiscoveryService - - -def _config() -> Config: - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - return Config() # type: ignore[call-arg] - - -@pytest.fixture -def mock_read_store() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def mock_field_reader() -> AsyncMock: - reader = AsyncMock() - reader.get_all_field_types.return_value = {} - reader.get_fields_for_schema.return_value = {} - return reader - - -@pytest.fixture -def service(mock_read_store: AsyncMock, mock_field_reader: AsyncMock) -> DiscoveryService: - return DiscoveryService( - read_store=mock_read_store, - field_reader=mock_field_reader, - config=_config(), - ) - - -class TestGetFeatureCatalogHandler: - async def test_public_auth_gate(self) -> None: - from osa.domain.shared.authorization.gate import Public - - assert isinstance(GetFeatureCatalogHandler.__auth__, Public) - - async def test_delegates_to_service(self) -> None: - mock_service = AsyncMock() - from osa.domain.discovery.model.value import FeatureCatalog - - mock_service.get_feature_catalog.return_value = FeatureCatalog(tables=[]) - - handler = GetFeatureCatalogHandler(discovery_service=mock_service) - result: GetFeatureCatalogResult = await handler.run(GetFeatureCatalog()) - - assert result.tables == [] - mock_service.get_feature_catalog.assert_called_once() - - async def test_returns_correct_structure(self) -> None: - mock_service = AsyncMock() - from osa.domain.discovery.model.value import FeatureCatalog - 
- mock_service.get_feature_catalog.return_value = FeatureCatalog( - tables=[ - FeatureCatalogEntry( - hook_name="detect_pockets", - columns=[ - ColumnInfo(name="score", type="number", required=True), - ColumnInfo(name="volume", type="number", required=True), - ], - record_count=142, - ) - ] - ) - - handler = GetFeatureCatalogHandler(discovery_service=mock_service) - result: GetFeatureCatalogResult = await handler.run(GetFeatureCatalog()) - - assert len(result.tables) == 1 - assert result.tables[0]["hook_name"] == "detect_pockets" - assert result.tables[0]["record_count"] == 142 - assert len(result.tables[0]["columns"]) == 2 - - -class TestDiscoveryServiceGetFeatureCatalog: - async def test_delegates_to_read_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - mock_read_store.get_feature_catalog.return_value = [] - result = await service.get_feature_catalog() - assert result.tables == [] - mock_read_store.get_feature_catalog.assert_called_once() - - async def test_returns_entries_from_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - entries = [ - FeatureCatalogEntry( - hook_name="test_hook", - columns=[ColumnInfo(name="x", type="number", required=True)], - record_count=10, - ) - ] - mock_read_store.get_feature_catalog.return_value = entries - - result = await service.get_feature_catalog() - assert len(result.tables) == 1 - assert result.tables[0].hook_name == "test_hook" diff --git a/server/tests/unit/domain/discovery/test_search_features.py b/server/tests/unit/domain/discovery/test_search_features.py deleted file mode 100644 index b89494b..0000000 --- a/server/tests/unit/domain/discovery/test_search_features.py +++ /dev/null @@ -1,275 +0,0 @@ -"""Tests for SearchFeaturesHandler and DiscoveryService.search_features().""" - -from unittest.mock import AsyncMock - -import pytest - -from osa.config import Config -from osa.domain.discovery.model.refs import FeatureFieldRef -from 
osa.domain.discovery.model.value import ( - ColumnInfo, - FeatureCatalogEntry, - FeatureRow, - FeatureSearchResult, - FilterOperator, - Predicate, - SortOrder, -) -from osa.domain.discovery.query.search_features import ( - SearchFeatures, - SearchFeaturesHandler, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.shared.error import NotFoundError, ValidationError -from osa.domain.shared.model.srn import RecordSRN - - -def _make_catalog_entry() -> FeatureCatalogEntry: - return FeatureCatalogEntry( - hook_name="detect_pockets", - columns=[ - ColumnInfo(name="score", type="number", required=True), - ColumnInfo(name="volume", type="number", required=True), - ColumnInfo(name="label", type="string", required=False), - ColumnInfo(name="is_active", type="boolean", required=False), - ], - record_count=10, - ) - - -def _config() -> Config: - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - return Config() # type: ignore[call-arg] - - -def _predicate(hook: str, column: str, op: FilterOperator, value: object) -> Predicate: - return Predicate(field=FeatureFieldRef(hook=hook, column=column), op=op, value=value) - - -@pytest.fixture -def mock_read_store() -> AsyncMock: - store = AsyncMock() - store.get_feature_table_schema.return_value = _make_catalog_entry() - store.search_features.return_value = [] - return store - - -@pytest.fixture -def mock_field_reader() -> AsyncMock: - reader = AsyncMock() - reader.get_all_field_types.return_value = {} - reader.get_fields_for_schema.return_value = {} - return reader - - -@pytest.fixture -def service(mock_read_store: AsyncMock, mock_field_reader: AsyncMock) -> DiscoveryService: - return DiscoveryService( - read_store=mock_read_store, - field_reader=mock_field_reader, - config=_config(), - ) - - -class TestSearchFeaturesHandler: - async def test_public_auth_gate(self) -> None: - from 
osa.domain.shared.authorization.gate import Public - - assert isinstance(SearchFeaturesHandler.__auth__, Public) - - async def test_delegates_to_service(self) -> None: - mock_service = AsyncMock() - mock_service.search_features.return_value = FeatureSearchResult( - rows=[], cursor=None, has_more=False - ) - - handler = SearchFeaturesHandler(discovery_service=mock_service) - await handler.run(SearchFeatures(hook_name="detect_pockets")) - mock_service.search_features.assert_called_once() - - async def test_invalid_record_srn_raises_validation_error(self) -> None: - mock_service = AsyncMock() - handler = SearchFeaturesHandler(discovery_service=mock_service) - - with pytest.raises(ValidationError, match="not an OSA SRN") as exc_info: - await handler.run(SearchFeatures(hook_name="detect_pockets", record_srn="not-a-srn")) - - assert exc_info.value.field == "record_srn" - mock_service.search_features.assert_not_called() - - async def test_maps_rows_with_record_srn(self) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - mock_service = AsyncMock() - mock_service.search_features.return_value = FeatureSearchResult( - rows=[FeatureRow(row_id=1, record_srn=srn, data={"score": 7.66})], - cursor=None, - has_more=False, - ) - - handler = SearchFeaturesHandler(discovery_service=mock_service) - result = await handler.run(SearchFeatures(hook_name="detect_pockets")) - - assert result.rows[0]["record_srn"] == str(srn) - assert result.rows[0]["score"] == 7.66 - - -class TestDiscoveryServiceSearchFeatures: - async def test_raises_not_found_for_unknown_hook( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - mock_read_store.get_feature_table_schema.return_value = None - - with pytest.raises(NotFoundError): - await service.search_features( - hook_name="unknown_hook", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - async def test_rejects_unknown_column(self, service: 
DiscoveryService) -> None: - with pytest.raises(ValidationError, match="bogus"): - await service.search_features( - hook_name="detect_pockets", - filter_expr=_predicate("detect_pockets", "bogus", FilterOperator.EQ, 1), - schema_id=None, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - async def test_validates_operator_for_number_column(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="contains"): - await service.search_features( - hook_name="detect_pockets", - filter_expr=_predicate("detect_pockets", "score", FilterOperator.CONTAINS, "x"), - schema_id=None, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - async def test_accepts_string_contains_operator(self, service: DiscoveryService) -> None: - await service.search_features( - hook_name="detect_pockets", - filter_expr=_predicate("detect_pockets", "label", FilterOperator.CONTAINS, "test"), - schema_id=None, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - async def test_passes_record_srn_filter( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=srn, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - call_kwargs = mock_read_store.search_features.call_args - assert call_kwargs.kwargs["record_srn"] == srn - - async def test_delegates_to_read_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - await service.search_features( - hook_name="detect_pockets", - filter_expr=_predicate("detect_pockets", "score", FilterOperator.GTE, 6.0), - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - mock_read_store.search_features.assert_called_once() - call_kwargs = 
mock_read_store.search_features.call_args - assert call_kwargs.kwargs["hook_name"] == "detect_pockets" - assert call_kwargs.kwargs["filter_expr"] is not None - - -class TestSearchFeaturesPagination: - async def test_has_more_false_when_exactly_limit_rows( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - mock_read_store.search_features.return_value = [ - FeatureRow(row_id=i, record_srn=srn, data={"score": float(i)}) for i in range(3) - ] - - result = await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=3, - ) - - assert result.has_more is False - assert result.cursor is None - assert len(result.rows) == 3 - - async def test_has_more_true_when_more_than_limit_rows( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - mock_read_store.search_features.return_value = [ - FeatureRow(row_id=i, record_srn=srn, data={"score": float(i)}) for i in range(4) - ] - - result = await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=3, - ) - - assert result.has_more is True - assert result.cursor is not None - assert len(result.rows) == 3 - - async def test_passes_limit_plus_one_to_read_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - call_kwargs = mock_read_store.search_features.call_args - assert call_kwargs.kwargs["limit"] == 51 diff --git a/server/tests/unit/domain/discovery/test_search_records.py 
b/server/tests/unit/domain/discovery/test_search_records.py deleted file mode 100644 index b9a7fbd..0000000 --- a/server/tests/unit/domain/discovery/test_search_records.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Tests for SearchRecordsHandler — auth gate, delegation, result mapping.""" - -from datetime import UTC, datetime -from unittest.mock import AsyncMock - -import pytest - -from osa.domain.discovery.model.value import RecordSearchResult, RecordSummary, SortOrder -from osa.domain.discovery.query.search_records import ( - SearchRecords, - SearchRecordsHandler, - SearchRecordsResult, -) -from osa.domain.shared.model.srn import RecordSRN - - -@pytest.fixture -def mock_service() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def handler(mock_service: AsyncMock) -> SearchRecordsHandler: - return SearchRecordsHandler(discovery_service=mock_service) - - -class TestSearchRecordsHandler: - async def test_public_auth_gate(self, handler: SearchRecordsHandler) -> None: - from osa.domain.shared.authorization.gate import Public - - assert isinstance(handler.__auth__, Public) - - async def test_delegates_to_service( - self, handler: SearchRecordsHandler, mock_service: AsyncMock - ) -> None: - mock_service.search_records.return_value = RecordSearchResult( - results=[], cursor=None, has_more=False - ) - cmd = SearchRecords() - await handler.run(cmd) - - mock_service.search_records.assert_called_once_with( - filter_expr=None, - schema_id=None, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_maps_results( - self, handler: SearchRecordsHandler, mock_service: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - ts = datetime(2026, 1, 1, tzinfo=UTC) - mock_service.search_records.return_value = RecordSearchResult( - results=[RecordSummary(srn=srn, published_at=ts, metadata={"title": "Test"})], - cursor="abc123", - has_more=False, - ) - - result: SearchRecordsResult = 
await handler.run(SearchRecords()) - - assert result.cursor == "abc123" - assert result.has_more is False - assert result.results[0]["srn"] == str(srn) - assert result.results[0]["metadata"] == {"title": "Test"} diff --git a/server/tests/unit/domain/discovery/test_value.py b/server/tests/unit/domain/discovery/test_value.py deleted file mode 100644 index 18518f5..0000000 --- a/server/tests/unit/domain/discovery/test_value.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Tests for discovery domain value objects — cursor helpers and VALID_OPERATORS.""" - -import pytest - -from osa.domain.discovery.model.value import ( - VALID_OPERATORS, - FilterOperator, - decode_cursor, - encode_cursor, -) -from osa.domain.semantics.model.value import FieldType - - -class TestCursorRoundTrip: - def test_round_trip_string_values(self) -> None: - cursor = encode_cursor("2026-01-01", "urn:osa:localhost:rec:abc@1") - decoded = decode_cursor(cursor) - assert decoded["s"] == "2026-01-01" - assert decoded["id"] == "urn:osa:localhost:rec:abc@1" - - def test_round_trip_numeric_values(self) -> None: - cursor = encode_cursor(7.66, 123) - decoded = decode_cursor(cursor) - assert decoded["s"] == 7.66 - assert decoded["id"] == 123 - - def test_round_trip_none_sort_value(self) -> None: - cursor = encode_cursor(None, "urn:osa:localhost:rec:abc@1") - decoded = decode_cursor(cursor) - assert decoded["s"] is None - assert decoded["id"] == "urn:osa:localhost:rec:abc@1" - - -class TestDecodeCursorErrors: - def test_malformed_base64(self) -> None: - with pytest.raises(ValueError, match="Malformed cursor"): - decode_cursor("not-valid-base64!!!") - - def test_invalid_json(self) -> None: - import base64 - - bad = base64.urlsafe_b64encode(b"not json").decode() - with pytest.raises(ValueError, match="Malformed cursor"): - decode_cursor(bad) - - def test_missing_s_key(self) -> None: - import base64 - import json - - bad = base64.urlsafe_b64encode(json.dumps({"id": "x"}).encode()).decode() - with pytest.raises(ValueError, 
match="'s' and 'id'"): - decode_cursor(bad) - - def test_missing_id_key(self) -> None: - import base64 - import json - - bad = base64.urlsafe_b64encode(json.dumps({"s": 1}).encode()).decode() - with pytest.raises(ValueError, match="'s' and 'id'"): - decode_cursor(bad) - - def test_non_dict_payload(self) -> None: - import base64 - import json - - bad = base64.urlsafe_b64encode(json.dumps([1, 2]).encode()).decode() - with pytest.raises(ValueError, match="'s' and 'id'"): - decode_cursor(bad) - - -class TestValidOperators: - def test_text_operators_include_basics(self) -> None: - ops = VALID_OPERATORS[FieldType.TEXT] - assert FilterOperator.EQ in ops - assert FilterOperator.CONTAINS in ops - assert FilterOperator.IN in ops - assert FilterOperator.NEQ in ops - - def test_url_operators_include_basics(self) -> None: - ops = VALID_OPERATORS[FieldType.URL] - assert FilterOperator.EQ in ops - assert FilterOperator.CONTAINS in ops - assert FilterOperator.IN in ops - - def test_number_operators_support_ordering(self) -> None: - ops = VALID_OPERATORS[FieldType.NUMBER] - assert FilterOperator.EQ in ops - assert FilterOperator.GT in ops - assert FilterOperator.GTE in ops - assert FilterOperator.LT in ops - assert FilterOperator.LTE in ops - - def test_date_operators_support_ordering(self) -> None: - ops = VALID_OPERATORS[FieldType.DATE] - assert FilterOperator.GTE in ops - assert FilterOperator.LTE in ops - - def test_boolean_operators(self) -> None: - assert FilterOperator.EQ in VALID_OPERATORS[FieldType.BOOLEAN] - assert FilterOperator.IS_NULL in VALID_OPERATORS[FieldType.BOOLEAN] - - def test_term_operators(self) -> None: - assert FilterOperator.EQ in VALID_OPERATORS[FieldType.TERM] - - def test_all_field_types_have_operators(self) -> None: - for ft in FieldType: - assert ft in VALID_OPERATORS, f"Missing operators for {ft}" diff --git a/server/tests/unit/domain/index/__init__.py b/server/tests/unit/domain/index/__init__.py deleted file mode 100644 index e69de29..0000000 diff 
--git a/server/tests/unit/domain/index/test_fanout_listener.py b/server/tests/unit/domain/index/test_fanout_listener.py deleted file mode 100644 index 2aa0988..0000000 --- a/server/tests/unit/domain/index/test_fanout_listener.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Unit tests for FanOutToIndexBackends handler.""" - -from typing import Any -from unittest.mock import AsyncMock -from uuid import uuid4 - -import pytest - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.handler.fanout_to_index_backends import FanOutToIndexBackends -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.record.event.record_published import RecordPublished -from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import DepositionSource -from osa.domain.shared.model.srn import ( - ConventionSRN, - DepositionSRN, - Domain, - LocalId, - RecordSRN, - RecordVersion, -) - - -class FakeBackend: - """Fake storage backend for testing.""" - - def __init__(self, name: str): - self._name = name - - @property - def name(self) -> str: - return self._name - - -class FakeOutbox: - """Fake outbox for testing event emission.""" - - def __init__(self): - self.events: list[Any] = [] - self.append = AsyncMock(side_effect=self._append) - - async def _append(self, event: Any) -> None: - self.events.append(event) - - -@pytest.fixture -def sample_record_srn() -> RecordSRN: - """Create a sample record SRN.""" - return RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ) - - -@pytest.fixture -def sample_deposition_srn() -> DepositionSRN: - """Create a sample deposition SRN.""" - return DepositionSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - ) - - -@pytest.fixture -def sample_metadata() -> dict: - """Create sample metadata for testing.""" - return { - "title": "Test Record", - "organism": "human", - "platform": "GPL570", - } - - -class 
TestFanOutToIndexBackends: - """Tests for FanOutToIndexBackends handler.""" - - @pytest.mark.asyncio - async def test_creates_index_record_per_backend( - self, - sample_record_srn: RecordSRN, - sample_deposition_srn: DepositionSRN, - sample_metadata: dict, - ): - """Handler should create one IndexRecord event per registered backend.""" - # Arrange - backend1 = FakeBackend("vector") - backend2 = FakeBackend("keyword") - registry = IndexRegistry({"vector": backend1, "keyword": backend2}) - outbox = FakeOutbox() - - handler = FanOutToIndexBackends(indexes=registry, outbox=outbox) - - event = RecordPublished( - id=EventId(uuid4()), - record_srn=sample_record_srn, - source=DepositionSource(id=str(sample_deposition_srn)), - convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0"), - schema_id=__import__( - "osa.domain.shared.model.srn", fromlist=["SchemaId"] - ).SchemaId.parse("test@1.0.0"), - metadata=sample_metadata, - ) - - # Act - await handler.handle(event) - - # Assert - assert len(outbox.events) == 2 - backend_names = {e.backend_name for e in outbox.events} - assert backend_names == {"vector", "keyword"} - - for index_event in outbox.events: - assert isinstance(index_event, IndexRecord) - assert index_event.record_srn == sample_record_srn - assert index_event.metadata == sample_metadata - - @pytest.mark.asyncio - async def test_creates_unique_event_ids( - self, - sample_record_srn: RecordSRN, - sample_deposition_srn: DepositionSRN, - sample_metadata: dict, - ): - """Each IndexRecord should have a unique event ID.""" - # Arrange - registry = IndexRegistry( - { - "backend1": FakeBackend("backend1"), - "backend2": FakeBackend("backend2"), - } - ) - outbox = FakeOutbox() - - handler = FanOutToIndexBackends(indexes=registry, outbox=outbox) - - event = RecordPublished( - id=EventId(uuid4()), - record_srn=sample_record_srn, - source=DepositionSource(id=str(sample_deposition_srn)), - convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0"), - 
schema_id=__import__( - "osa.domain.shared.model.srn", fromlist=["SchemaId"] - ).SchemaId.parse("test@1.0.0"), - metadata=sample_metadata, - ) - - # Act - await handler.handle(event) - - # Assert - event_ids = [e.id for e in outbox.events] - assert len(event_ids) == len(set(event_ids)), "Event IDs should be unique" - - @pytest.mark.asyncio - async def test_empty_registry_creates_no_events( - self, - sample_record_srn: RecordSRN, - sample_deposition_srn: DepositionSRN, - sample_metadata: dict, - ): - """No IndexRecord events should be created if registry is empty.""" - # Arrange - registry = IndexRegistry({}) - outbox = FakeOutbox() - - handler = FanOutToIndexBackends(indexes=registry, outbox=outbox) - - event = RecordPublished( - id=EventId(uuid4()), - record_srn=sample_record_srn, - source=DepositionSource(id=str(sample_deposition_srn)), - convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0"), - schema_id=__import__( - "osa.domain.shared.model.srn", fromlist=["SchemaId"] - ).SchemaId.parse("test@1.0.0"), - metadata=sample_metadata, - ) - - # Act - await handler.handle(event) - - # Assert - assert len(outbox.events) == 0 diff --git a/server/tests/unit/domain/index/test_index_service.py b/server/tests/unit/domain/index/test_index_service.py deleted file mode 100644 index 9769f97..0000000 --- a/server/tests/unit/domain/index/test_index_service.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Unit tests for IndexService.""" - -from unittest.mock import AsyncMock - -import pytest - -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.index.service.index import IndexService - - -class FakeBackend: - """Fake storage backend for testing.""" - - def __init__(self, name: str, count: int = 0, healthy: bool = True): - self._name = name - self._count = count - self._healthy = healthy - self.count = AsyncMock(return_value=count) - self.health = AsyncMock(return_value=healthy) - - @property - def name(self) -> str: - return self._name - - -class 
TestIndexService: - """Tests for IndexService.""" - - @pytest.mark.asyncio - async def test_get_count_returns_backend_count(self): - """get_count should return the backend's document count.""" - # Arrange - backend = FakeBackend("vector", count=42) - registry = IndexRegistry({"vector": backend}) - service = IndexService(indexes=registry) - - # Act - result = await service.get_count("vector") - - # Assert - assert result == 42 - backend.count.assert_called_once() - - @pytest.mark.asyncio - async def test_get_count_returns_none_for_unknown_backend(self): - """get_count should return None for unknown backend.""" - # Arrange - registry = IndexRegistry({}) - service = IndexService(indexes=registry) - - # Act - result = await service.get_count("unknown") - - # Assert - assert result is None - - @pytest.mark.asyncio - async def test_check_health_returns_backend_health(self): - """check_health should return the backend's health status.""" - # Arrange - backend = FakeBackend("vector", healthy=True) - registry = IndexRegistry({"vector": backend}) - service = IndexService(indexes=registry) - - # Act - result = await service.check_health("vector") - - # Assert - assert result is True - backend.health.assert_called_once() - - @pytest.mark.asyncio - async def test_check_health_returns_none_for_unknown_backend(self): - """check_health should return None for unknown backend.""" - # Arrange - registry = IndexRegistry({}) - service = IndexService(indexes=registry) - - # Act - result = await service.check_health("unknown") - - # Assert - assert result is None - - @pytest.mark.asyncio - async def test_check_health_returns_false_for_unhealthy_backend(self): - """check_health should return False for unhealthy backend.""" - # Arrange - backend = FakeBackend("vector", healthy=False) - registry = IndexRegistry({"vector": backend}) - service = IndexService(indexes=registry) - - # Act - result = await service.check_health("vector") - - # Assert - assert result is False diff --git 
a/server/tests/unit/domain/record/test_get_stats_handler.py b/server/tests/unit/domain/record/test_get_stats_handler.py new file mode 100644 index 0000000..985c83d --- /dev/null +++ b/server/tests/unit/domain/record/test_get_stats_handler.py @@ -0,0 +1,24 @@ +"""GetStats query handler — the stats route must not touch RecordRepository. + +Layering regression guard: Router → QueryHandler → Service → Repository. The +/stats route previously injected RecordRepository directly. +""" + +from unittest.mock import AsyncMock + +import pytest + +from osa.domain.record.query.get_stats import GetStats, GetStatsHandler + + +class TestGetStatsHandler: + @pytest.mark.asyncio + async def test_returns_record_count_from_service(self): + service = AsyncMock() + service.count.return_value = 42 + + handler = GetStatsHandler(record_service=service) + result = await handler.run(GetStats()) + + assert result.records == 42 + service.count.assert_awaited_once() diff --git a/server/tests/unit/test_record_schema_srn_immutable.py b/server/tests/unit/domain/record/test_record_schema_srn_immutable.py similarity index 100% rename from server/tests/unit/test_record_schema_srn_immutable.py rename to server/tests/unit/domain/record/test_record_schema_srn_immutable.py diff --git a/server/tests/unit/domain/semantics/test_schema_reserved_name.py b/server/tests/unit/domain/semantics/test_schema_reserved_name.py new file mode 100644 index 0000000..335c47f --- /dev/null +++ b/server/tests/unit/domain/semantics/test_schema_reserved_name.py @@ -0,0 +1,40 @@ +"""T064 — Schema.__init__ rejects reserved schema IDs (records, datasets).""" + +from datetime import UTC, datetime + +import pytest + +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType +from osa.domain.shared.error import ReservedNameError +from osa.domain.shared.model.srn import SchemaId + + +def _text_field(name: str = "title") -> FieldDefinition: + return 
FieldDefinition( + name=name, type=FieldType.TEXT, required=True, cardinality=Cardinality.EXACTLY_ONE + ) + + +@pytest.mark.parametrize("reserved", ["records", "datasets"]) +def test_schema_rejects_reserved_id(reserved: str) -> None: + with pytest.raises(ReservedNameError) as exc: + Schema( + id=SchemaId.parse(f"{reserved}@1.0.0"), + title="x", + fields=[_text_field()], + created_at=datetime.now(UTC), + ) + assert exc.value.code == "reserved_name" + assert exc.value.kind == "schema" + assert exc.value.name == reserved + + +def test_schema_allows_non_reserved_id() -> None: + schema = Schema( + id=SchemaId.parse("compound@1.0.0"), + title="Compound", + fields=[_text_field()], + created_at=datetime.now(UTC), + ) + assert schema.id.id.root == "compound" diff --git a/server/osa/domain/search/adapter/__init__.py b/server/tests/unit/domain/shared/__init__.py similarity index 100% rename from server/osa/domain/search/adapter/__init__.py rename to server/tests/unit/domain/shared/__init__.py diff --git a/server/tests/unit/domain/shared/test_record_ref.py b/server/tests/unit/domain/shared/test_record_ref.py new file mode 100644 index 0000000..40769dd --- /dev/null +++ b/server/tests/unit/domain/shared/test_record_ref.py @@ -0,0 +1,28 @@ +"""RecordRef — parse of ``{id}`` / ``{id}@{version}`` URL segments.""" + +import pytest + +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.ids import RecordRef + + +def test_parse_bare_id_has_no_version() -> None: + ref = RecordRef.parse("0190a1b2") + assert ref.id == "0190a1b2" + assert ref.version is None + + +def test_parse_id_with_version() -> None: + ref = RecordRef.parse("0190a1b2@3") + assert ref.id == "0190a1b2" + assert ref.version == 3 + + +def test_parse_non_integer_version_raises_validation_error() -> None: + with pytest.raises(ValidationError): + RecordRef.parse("0190a1b2@latest") + + +def test_render_round_trips() -> None: + assert RecordRef.parse("abc@7").render() == "abc@7" + assert 
RecordRef.parse("abc").render() == "abc" diff --git a/server/tests/unit/domain/shared/test_reserved.py b/server/tests/unit/domain/shared/test_reserved.py new file mode 100644 index 0000000..7d7ef21 --- /dev/null +++ b/server/tests/unit/domain/shared/test_reserved.py @@ -0,0 +1,33 @@ +"""T028 — RESERVED_NAMES constant and ReservedNameError.""" + +from osa.domain.shared.error import DomainError, ReservedNameError +from osa.domain.shared.model.reserved import RESERVED_NAMES + + +def test_reserved_names_contains_records_and_datasets() -> None: + assert RESERVED_NAMES == frozenset({"records", "datasets"}) + + +def test_reserved_names_is_frozen() -> None: + assert isinstance(RESERVED_NAMES, frozenset) + + +def test_reserved_name_error_is_domain_error() -> None: + err = ReservedNameError("records", "schema") + assert isinstance(err, DomainError) + + +def test_reserved_name_error_code_and_message() -> None: + err = ReservedNameError("records", "schema") + assert err.code == "reserved_name" + assert err.name == "records" + assert err.kind == "schema" + assert "records" in err.message + assert "/data/" in err.message + + +def test_reserved_name_error_lists_reserved_names() -> None: + err = ReservedNameError("datasets", "hook") + # sorted reserved set is rendered in the message + assert "datasets" in err.message + assert "records" in err.message diff --git a/server/osa/domain/search/command/__init__.py b/server/tests/unit/infrastructure/data/__init__.py similarity index 100% rename from server/osa/domain/search/command/__init__.py rename to server/tests/unit/infrastructure/data/__init__.py diff --git a/server/tests/unit/infrastructure/data/test_cursor_validation.py b/server/tests/unit/infrastructure/data/test_cursor_validation.py new file mode 100644 index 0000000..5f2f2cd --- /dev/null +++ b/server/tests/unit/infrastructure/data/test_cursor_validation.py @@ -0,0 +1,127 @@ +"""A malformed pagination cursor must surface as a 400, not a 500. 
+ +``decode_cursor`` raises bare ``ValueError`` on corrupt base64 / missing keys, +and the feature sort coerces ``id`` with ``int(...)`` which also raises +``ValueError`` on a non-integer. Both run inside ``stream_rows``' pre-flight +pull; if they escape as raw ``ValueError`` (not an ``OSAError``) the global +handler returns 500. These tests pin them to ``ValidationError(field="cursor")`` +so the route maps them to 400 — exercised DB-free via the sort helpers, which +decode the cursor before any DB access. +""" + +import base64 +import json +from datetime import date, datetime + +import pytest +import sqlalchemy as sa + +from osa.domain.data.model.query_plan import ( + PaginationCursor, + PaginationParams, + QueryPlan, + SortDirection, + SortSpec, + TableKind, + encode_cursor, +) +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.hook import ColumnDef +from osa.domain.shared.model.srn import SchemaId +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore +from osa.infrastructure.persistence.feature_table import ( + FeatureSchema, + build_feature_table, +) +from osa.infrastructure.persistence.tables import records_table + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def _store() -> PostgresTableReadStore: + # The sort helpers never touch self.session; decoding happens before any DB + # access, so a placeholder session is sufficient for these unit tests. 
+ return PostgresTableReadStore(None) # type: ignore[arg-type] + + +def _records_plan(cursor: str) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + ) + + +def _feature_plan(cursor: str) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name="chem_features", + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + sort=[SortSpec(column="id", direction=SortDirection.ASC)], + ) + + +def _feature_table(): + fschema = FeatureSchema(columns=[ColumnDef(name="score", json_type="number", required=True)]) + return build_feature_table("chem_features", fschema) + + +class TestRecordsCursorValidation: + def test_corrupt_base64_raises_validation_error(self) -> None: + with pytest.raises(ValidationError) as exc: + _store()._records_sort(_records_plan("!!not-base64!!"), records_table) + assert exc.value.field == "cursor" + + def test_missing_keys_raises_validation_error(self) -> None: + # A well-formed base64 JSON object that lacks the required s/id keys. + bad = base64.urlsafe_b64encode(json.dumps({"x": 1}).encode()).decode() + with pytest.raises(ValidationError) as exc: + _store()._records_sort(_records_plan(bad), records_table) + assert exc.value.field == "cursor" + + def test_date_metadata_cursor_value_coerced_to_date(self) -> None: + # A FieldType.DATE metadata column is a sa.Date in the dynamic table; + # the cursor carries its value as an ISO string — asyncpg rejects a + # str bound against DATE, so the decoder must coerce by column type. 
+ value = PostgresTableReadStore._coerce_cursor_value( + "2026-01-02", sa.Column("assay_date", sa.Date()) + ) + assert isinstance(value, date) + + def test_datetime_metadata_cursor_value_coerced_to_datetime(self) -> None: + value = PostgresTableReadStore._coerce_cursor_value( + "2026-01-02T03:04:05+00:00", sa.Column("measured_at", sa.DateTime(timezone=True)) + ) + assert isinstance(value, datetime) + + +class TestFeatureCursorValidation: + def test_corrupt_base64_raises_validation_error(self) -> None: + ft = _feature_table() + with pytest.raises(ValidationError) as exc: + _store()._features_sort(_feature_plan("!!not-base64!!"), ft) + assert exc.value.field == "cursor" + + def test_non_integer_id_raises_validation_error(self) -> None: + ft = _feature_table() + # Well-formed cursor whose id component is not an int → int(...) ValueError. + cursor = encode_cursor(5, "not-an-int") + with pytest.raises(ValidationError) as exc: + _store()._features_sort(_feature_plan(cursor), ft) + assert exc.value.field == "cursor" + + def test_created_at_cursor_value_coerced_to_datetime(self) -> None: + # created_at is an implicit feature column; the type-driven coercion + # must turn its ISO string back into a datetime before binding. + ft = _feature_table() + value = PostgresTableReadStore._coerce_cursor_value( + "2026-01-02T03:04:05+00:00", ft.c.created_at + ) + assert isinstance(value, datetime) + + def test_integer_id_sort_value_coerced_to_int(self) -> None: + # Sorting by id, the cursor sort value binds against BIGINT — a str + # is coerced (and garbage raises ValueError → 400 via the sort path). 
+ ft = _feature_table() + assert PostgresTableReadStore._coerce_cursor_value("7", ft.c.id) == 7 diff --git a/server/tests/unit/infrastructure/data/test_statement_timeout.py b/server/tests/unit/infrastructure/data/test_statement_timeout.py new file mode 100644 index 0000000..de46ab2 --- /dev/null +++ b/server/tests/unit/infrastructure/data/test_statement_timeout.py @@ -0,0 +1,17 @@ +"""The adapter owns SET LOCAL statement_timeout (routes hold no SQL session).""" + +from datetime import timedelta + +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore + +_sql = PostgresTableReadStore.statement_timeout_sql + + +def test_statement_timeout_sql_renders_integer_milliseconds() -> None: + assert _sql(timedelta(seconds=30)) == "SET LOCAL statement_timeout = '30000ms'" + assert _sql(timedelta(minutes=30)) == "SET LOCAL statement_timeout = '1800000ms'" + + +def test_statement_timeout_sql_truncates_subsecond() -> None: + # int milliseconds — no operator-supplied string ever reaches raw SQL. 
+ assert _sql(timedelta(milliseconds=1500.7)) == "SET LOCAL statement_timeout = '1500ms'" diff --git a/server/tests/unit/infrastructure/persistence/test_field_definition_reader.py b/server/tests/unit/infrastructure/persistence/test_field_definition_reader.py deleted file mode 100644 index 25974c2..0000000 --- a/server/tests/unit/infrastructure/persistence/test_field_definition_reader.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Tests for PostgresFieldDefinitionReader — field type map construction.""" - -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import ValidationError -from osa.infrastructure.persistence.adapter.discovery import PostgresFieldDefinitionReader - - -def _make_schema_row(srn: str, fields: list[dict]) -> dict: - return {"srn": srn, "fields": fields} - - -@pytest.fixture -def mock_session() -> AsyncMock: - return AsyncMock() - - -def _setup_session_result(session: AsyncMock, rows: list[dict]) -> None: - """Configure mock session to return rows from a SELECT on schemas_table.""" - result_mock = MagicMock() - result_mock.mappings.return_value.all.return_value = rows - session.execute.return_value = result_mock - - -class TestGetAllFieldTypes: - async def test_builds_field_map_from_multiple_schemas(self, mock_session: AsyncMock) -> None: - rows = [ - _make_schema_row( - "urn:osa:localhost:schema:a@1", - [ - { - "name": "title", - "type": "text", - "required": True, - "cardinality": "exactly_one", - }, - { - "name": "resolution", - "type": "number", - "required": False, - "cardinality": "exactly_one", - }, - ], - ), - _make_schema_row( - "urn:osa:localhost:schema:b@1", - [ - { - "name": "method", - "type": "term", - "required": True, - "cardinality": "exactly_one", - "constraints": { - "type": "term", - "ontology_srn": "urn:osa:localhost:onto:methods@1", - }, - }, - ], - ), - ] - _setup_session_result(mock_session, rows) - - reader = 
PostgresFieldDefinitionReader(session=mock_session) - result = await reader.get_all_field_types() - - assert result == { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": FieldType.TERM, - } - - async def test_raises_on_conflicting_types(self, mock_session: AsyncMock) -> None: - rows = [ - _make_schema_row( - "urn:osa:localhost:schema:a@1", - [{"name": "value", "type": "text", "required": True, "cardinality": "exactly_one"}], - ), - _make_schema_row( - "urn:osa:localhost:schema:b@1", - [ - { - "name": "value", - "type": "number", - "required": True, - "cardinality": "exactly_one", - } - ], - ), - ] - _setup_session_result(mock_session, rows) - - reader = PostgresFieldDefinitionReader(session=mock_session) - with pytest.raises(ValidationError, match="value"): - await reader.get_all_field_types() - - async def test_returns_empty_map_when_no_schemas(self, mock_session: AsyncMock) -> None: - _setup_session_result(mock_session, []) - - reader = PostgresFieldDefinitionReader(session=mock_session) - result = await reader.get_all_field_types() - - assert result == {} - - async def test_same_field_same_type_across_schemas_ok(self, mock_session: AsyncMock) -> None: - rows = [ - _make_schema_row( - "urn:osa:localhost:schema:a@1", - [{"name": "title", "type": "text", "required": True, "cardinality": "exactly_one"}], - ), - _make_schema_row( - "urn:osa:localhost:schema:b@1", - [ - { - "name": "title", - "type": "text", - "required": False, - "cardinality": "exactly_one", - } - ], - ), - ] - _setup_session_result(mock_session, rows) - - reader = PostgresFieldDefinitionReader(session=mock_session) - result = await reader.get_all_field_types() - - assert result == {"title": FieldType.TEXT} diff --git a/server/tests/unit/test_field_ref_resolution.py b/server/tests/unit/test_field_ref_resolution.py deleted file mode 100644 index d647c3a..0000000 --- a/server/tests/unit/test_field_ref_resolution.py +++ /dev/null @@ -1,42 +0,0 @@ -"""US3 tests: typed 
field-reference parsing and tree validation.""" - -import pytest - -from osa.domain.discovery.model.refs import ( - FeatureFieldRef, - MetadataFieldRef, - parse_field_ref, -) - - -class TestParseFieldRef: - def test_parses_metadata_ref(self): - ref = parse_field_ref("metadata.species") - assert isinstance(ref, MetadataFieldRef) - assert ref.field == "species" - - def test_parses_feature_ref(self): - ref = parse_field_ref("features.cell_classifier.confidence") - assert isinstance(ref, FeatureFieldRef) - assert ref.hook == "cell_classifier" - assert ref.column == "confidence" - - def test_rejects_unknown_prefix(self): - with pytest.raises(ValueError, match="prefix"): - parse_field_ref("other.foo") - - def test_rejects_malformed_metadata(self): - with pytest.raises(ValueError): - parse_field_ref("metadata.a.b") - - def test_rejects_malformed_feature(self): - with pytest.raises(ValueError): - parse_field_ref("features.hook") - - def test_rejects_invalid_identifier(self): - with pytest.raises(ValueError): - parse_field_ref("metadata.Has-Dash") - - def test_dotted_round_trip(self): - assert parse_field_ref("metadata.species").dotted() == "metadata.species" - assert parse_field_ref("features.hook.col").dotted() == "features.hook.col" diff --git a/server/tests/unit/test_filter_expr_and_compile.py b/server/tests/unit/test_filter_expr_and_compile.py deleted file mode 100644 index 41a2ba9..0000000 --- a/server/tests/unit/test_filter_expr_and_compile.py +++ /dev/null @@ -1,197 +0,0 @@ -"""US1 tests: FilterExpr AND-tree validation via DiscoveryService bounds.""" - -from unittest.mock import AsyncMock - -import pytest - -from osa.config import Config -from osa.domain.discovery.model.refs import FeatureFieldRef, MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - FilterOperator, - Predicate, - SortOrder, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.semantics.model.value import FieldType -from 
osa.domain.shared.error import ValidationError -from osa.domain.shared.model.srn import SchemaId - - -SCHEMA = SchemaId.parse("bio-sample@1.0.0") - - -def _config(overrides: dict | None = None) -> Config: - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - cfg = Config() # type: ignore[call-arg] - if overrides: - for k, v in overrides.items(): - setattr(cfg, k, v) - return cfg - - -def _svc( - *, - field_map: dict[str, FieldType] | None = None, - max_depth: int | None = None, - max_preds: int | None = None, - max_joins: int | None = None, -) -> DiscoveryService: - read_store = AsyncMock() - read_store.search_records.return_value = [] - read_store.get_feature_catalog.return_value = [] - - reader = AsyncMock() - fm = field_map or { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - } - reader.get_all_field_types.return_value = fm - reader.get_fields_for_schema.return_value = fm - - overrides = {} - if max_depth is not None: - overrides["discovery_max_filter_depth"] = max_depth - if max_preds is not None: - overrides["discovery_max_predicates"] = max_preds - if max_joins is not None: - overrides["discovery_max_cross_domain_joins"] = max_joins - - return DiscoveryService(read_store=read_store, field_reader=reader, config=_config(overrides)) - - -def _pred(field: str, op: FilterOperator, value: object) -> Predicate: - return Predicate(field=MetadataFieldRef(field=field), op=op, value=value) - - -class TestAndOnlyTrees: - async def test_accepts_and_of_predicates(self) -> None: - svc = _svc() - tree = And( - operands=[ - _pred("title", FilterOperator.EQ, "x"), - _pred("resolution", FilterOperator.GTE, 3.0), - ] - ) - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - -class TestBoundsEnforced: - async def test_depth_exceeded(self) -> None: - svc 
= _svc(max_depth=3) - leaf = _pred("title", FilterOperator.EQ, "x") - tree = leaf - for _ in range(4): - tree = And(operands=[tree, leaf]) - - with pytest.raises(ValidationError, match="depth"): - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_predicates_exceeded(self) -> None: - svc = _svc(max_preds=2) - tree = And( - operands=[ - _pred("title", FilterOperator.EQ, "x"), - _pred("title", FilterOperator.EQ, "y"), - _pred("resolution", FilterOperator.GTE, 3.0), - ] - ) - with pytest.raises(ValidationError, match="predicate leaves"): - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_joins_exceeded(self) -> None: - svc = _svc(max_joins=1) - # Simulate catalog advertising multiple hooks with a column named score - svc.read_store.get_feature_catalog.return_value = [ # type: ignore[attr-defined] - type( - "E", - (), - { - "hook_name": "hook_a", - "columns": [ - type("C", (), {"name": "score", "type": "number", "required": True}) - ], - }, - ), - type( - "E", - (), - { - "hook_name": "hook_b", - "columns": [ - type("C", (), {"name": "score", "type": "number", "required": True}) - ], - }, - ), - ] - tree = And( - operands=[ - Predicate( - field=FeatureFieldRef(hook="hook_a", column="score"), - op=FilterOperator.GT, - value=0.0, - ), - Predicate( - field=FeatureFieldRef(hook="hook_b", column="score"), - op=FilterOperator.GT, - value=0.0, - ), - ] - ) - with pytest.raises(ValidationError, match="distinct feature hooks"): - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - -class TestUnknownField: - async def test_unknown_metadata_field_rejected(self) -> 
None: - svc = _svc() - with pytest.raises(ValidationError, match="Unknown metadata field"): - await svc.search_records( - filter_expr=_pred("bogus", FilterOperator.EQ, "x"), - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) diff --git a/server/tests/unit/test_filter_expr_or_not.py b/server/tests/unit/test_filter_expr_or_not.py deleted file mode 100644 index 191bc74..0000000 --- a/server/tests/unit/test_filter_expr_or_not.py +++ /dev/null @@ -1,131 +0,0 @@ -"""US2 tests: FilterExpr accepts OR/NOT trees and validation walks them correctly.""" - -from unittest.mock import AsyncMock - -import pytest - -from osa.config import Config -from osa.domain.discovery.model.refs import MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - FilterOperator, - Not, - Or, - Predicate, - SortOrder, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.srn import SchemaId - - -SCHEMA = SchemaId.parse("bio-sample@1.0.0") - - -def _config() -> Config: - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - return Config() # type: ignore[call-arg] - - -def _svc() -> DiscoveryService: - read_store = AsyncMock() - read_store.search_records.return_value = [] - reader = AsyncMock() - fm = { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - } - reader.get_all_field_types.return_value = fm - reader.get_fields_for_schema.return_value = fm - return DiscoveryService(read_store=read_store, field_reader=reader, config=_config()) - - -def _pred(field: str, op: FilterOperator, value: object) -> Predicate: - return Predicate(field=MetadataFieldRef(field=field), op=op, value=value) - - -class TestOrNot: - async def test_or_tree_accepted(self): - svc = 
_svc() - tree = Or( - operands=[ - _pred("title", FilterOperator.EQ, "A"), - _pred("title", FilterOperator.EQ, "B"), - ] - ) - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - ) - - async def test_not_tree_accepted(self): - svc = _svc() - tree = Not(operand=_pred("title", FilterOperator.EQ, "X")) - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - ) - - async def test_nested_mixed_tree(self): - svc = _svc() - tree = And( - operands=[ - _pred("title", FilterOperator.EQ, "X"), - Or( - operands=[ - _pred("resolution", FilterOperator.GTE, 3.0), - _pred("resolution", FilterOperator.LT, 1.0), - ] - ), - Not(operand=_pred("title", FilterOperator.EQ, "Bad")), - ] - ) - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - ) - - -class TestCompoundDisabledFlag: - async def test_or_rejected_when_compound_disabled(self): - svc = _svc() - tree = Or( - operands=[ - _pred("title", FilterOperator.EQ, "A"), - _pred("title", FilterOperator.EQ, "B"), - ] - ) - with pytest.raises(ValidationError, match="compound_disabled|Compound"): - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - allow_compound=False, - ) diff --git a/server/tests/unit/test_no_index_deps.py b/server/tests/unit/test_no_index_deps.py new file mode 100644 index 0000000..37164ca --- /dev/null +++ b/server/tests/unit/test_no_index_deps.py @@ -0,0 +1,33 @@ +"""US6 footprint guard: the vector-index stack is gone from the image. 
+ +`chromadb` and `sentence-transformers` were the heaviest runtime deps; the +unified `/data/` surface replaced the vector/keyword index domain, so neither +should be importable. If a future change reintroduces them as a transitive +dependency, these tests fail loudly rather than silently re-bloating the image. +""" + +import importlib + +import pytest + + +@pytest.mark.parametrize("module", ["chromadb", "sentence_transformers"]) +def test_index_dependency_not_importable(module: str): + with pytest.raises(ModuleNotFoundError): + importlib.import_module(module) + + +@pytest.mark.parametrize( + "module", + [ + "osa.domain.index", + "osa.domain.search", + "osa.domain.export", + "osa.domain.discovery", + "osa.infrastructure.index", + "osa.sdk.index", + ], +) +def test_deleted_domain_not_importable(module: str): + with pytest.raises(ModuleNotFoundError): + importlib.import_module(module) diff --git a/server/uv.lock b/server/uv.lock index 5953a12..2b9e5c1 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -395,6 +395,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/db/d291e30fdf7ea617a335531e72294e0c723356d7fdde8fba00610a76bda9/coverage-7.13.2-py3-none-any.whl", hash = "sha256:40ce1ea1e25125556d8e76bd0b61500839a07944cc287ac21d5626f3e620cad5", size = 210943, upload-time = "2026-01-25T13:00:02.388Z" }, ] +[[package]] +name = "deprecated" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wrapt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/85/12f0a49a7c4ffb70572b6c2ef13c90c88fd190debda93b23f026b25f9634/deprecated-1.3.1.tar.gz", hash = "sha256:b1b50e0ff0c1fddaa5708a2c6b0a6588bb09b892825ab2b214ac9ea9d92a5223", size = 2932523, upload-time = "2025-10-30T08:19:02.757Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" }, +] + [[package]] name = "dishka" version = "1.9.1" @@ -677,6 +689,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/b3/a8917d253763095fb8dcaaefc6a135ed31abbd13f681e78752e226e252fe/kubernetes_asyncio-35.0.1-py3-none-any.whl", hash = "sha256:244ef45943e89c5c5104276a646bfcbf1a9dc3d060876c2094aa601e932f1c03", size = 2868606, upload-time = "2026-02-25T20:40:41.191Z" }, ] +[[package]] +name = "limits" +version = "5.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "deprecated", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/69/826a5d1f45426c68d8f6539f8d275c0e4fcaa57f0c017ec3100986558a41/limits-5.8.0.tar.gz", hash = "sha256:c9e0d74aed837e8f6f50d1fcebcf5fd8130957287206bc3799adaee5092655da", size = 226104, upload-time = "2026-02-05T07:17:35.859Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/98/cb5ca20618d205a09d5bec7591fbc4130369c7e6308d9a676a28ff3ab22c/limits-5.8.0-py3-none-any.whl", hash = 
"sha256:ae1b008a43eb43073c3c579398bd4eb4c795de60952532dc24720ab45e1ac6b8", size = 60954, upload-time = "2026-02-05T07:17:34.425Z" }, +] + [[package]] name = "logfire" version = "4.21.0" @@ -1006,7 +1032,7 @@ wheels = [ [[package]] name = "osa" -version = "0.0.1" +version = "0.0.3" source = { editable = "." } dependencies = [ { name = "aiodocker", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1025,6 +1051,7 @@ dependencies = [ { name = "pydantic-settings", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyjwt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "python-multipart", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "slowapi", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sqlalchemy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "uvicorn", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] @@ -1069,6 +1096,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = 
">=2.12.0" }, { name = "pyjwt", specifier = ">=2.11.0" }, { name = "python-multipart", specifier = ">=0.0.22" }, + { name = "slowapi", specifier = ">=0.1.9" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, { name = "types-aiobotocore-s3", marker = "extra == 'k8s'", specifier = ">=3.3.0" }, { name = "uvicorn", specifier = ">=0.38.0" }, @@ -1477,6 +1505,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "slowapi" +version = "0.1.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "limits", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a0/99/adfc7f94ca024736f061257d39118e1542bade7a52e86415a4c4ae92d8ff/slowapi-0.1.9.tar.gz", hash = "sha256:639192d0f1ca01b1c6d95bf6c71d794c3a9ee189855337b4821f7f457dddad77", size = 14028, upload-time = "2024-02-05T12:11:52.13Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/bb/f71c4b7d7e7eb3fc1e8c0458a8979b912f40b58002b9fbf37729b8cb464b/slowapi-0.1.9-py3-none-any.whl", hash = "sha256:cfad116cfb84ad9d763ee155c1e5c5cbf00b0d47399a769b227865f5df576e36", size = 14670, upload-time = "2024-02-05T12:11:50.898Z" }, +] + [[package]] name = "sqlalchemy" version = "2.0.46"