From fd32f6377be8d3389945f38e060cbd833b825e81 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Tue, 9 Jun 2026 15:22:45 +0100 Subject: [PATCH 01/23] feat: add unified /data/ read surface; remove legacy index/search/export Replace the fragmented consumer-read surface (/discovery, /records/{srn}, /search) with a single /data/ URL family owned by a new `data` domain: - Node catalog (GET /data) and machine-readable schema manifests (GET /data/{schema}). - Single record fetch by bare id (GET /data/records/{id}[@version]), returning both id and srn. - Records and feature-table reads at /data/{schema}/records and /data/{schema}/{feature} as paginated JSON or streamed CSV / gzipped CSV, with bounded memory via a server-side cursor and pre-flight error surfacing before any bytes. - FilterExpr POST body (rate-limited per route), keyset pagination, and a reserved-name policy (records/datasets) enforced at registration and read. Delete the index (ChromaDB), search, export, and discovery domains, the sdk/index package, the discovery infra adapter, and the legacy routes; drop the /stats index-counts field in favour of a /data hyperlink. chromadb and sentence-transformers are no longer dependencies. 
--- server/osa/application/api/rest/app.py | 21 +- server/osa/application/api/v1/errors.py | 2 + .../api/v1/routes/data/__init__.py | 50 ++ .../api/v1/routes/data/_limiter.py | 20 + .../application/api/v1/routes/data/_params.py | 86 +++ .../api/v1/routes/data/_runtime.py | 32 + .../api/v1/routes/data/_streaming.py | 101 +++ .../application/api/v1/routes/data/catalog.py | 37 + .../api/v1/routes/data/features_table.py | 120 +++ .../application/api/v1/routes/data/models.py | 32 + .../application/api/v1/routes/data/records.py | 44 ++ .../api/v1/routes/data/records_table.py | 108 +++ .../application/api/v1/routes/data/tables.py | 89 +++ .../application/api/v1/routes/discovery.py | 157 ---- .../osa/application/api/v1/routes/records.py | 47 -- .../osa/application/api/v1/routes/search.py | 75 -- server/osa/application/api/v1/routes/stats.py | 34 +- server/osa/application/di.py | 6 +- .../{discovery/model => data}/__init__.py | 0 .../port => data/model}/__init__.py | 0 server/osa/domain/data/model/catalog.py | 35 + .../model/value.py => data/model/filter.py} | 165 +++-- server/osa/domain/data/model/format.py | 54 ++ server/osa/domain/data/model/manifest.py | 70 ++ server/osa/domain/data/model/query_plan.py | 105 +++ .../osa/domain/data/model/record_summary.py | 50 ++ .../query => data/port}/__init__.py | 0 .../osa/domain/data/port/data_read_store.py | 45 ++ .../service => data/query}/__init__.py | 0 .../util => data/serializer}/__init__.py | 0 server/osa/domain/data/serializer/csv.py | 61 ++ server/osa/domain/data/serializer/csv_gzip.py | 41 + server/osa/domain/data/serializer/json.py | 38 + server/osa/domain/data/serializer/protocol.py | 34 + .../{export => data/service}/__init__.py | 0 .../osa/domain/data/service/data_catalog.py | 75 ++ server/osa/domain/data/service/data_query.py | 96 +++ .../{export/adapter => data/util}/__init__.py | 0 server/osa/domain/data/util/di/__init__.py | 3 + server/osa/domain/data/util/di/provider.py | 20 + 
server/osa/domain/discovery/__init__.py | 1 - server/osa/domain/discovery/model/refs.py | 78 -- .../discovery/port/field_definition_reader.py | 28 - .../osa/domain/discovery/port/read_store.py | 50 -- .../discovery/query/get_feature_catalog.py | 32 - .../domain/discovery/query/search_features.py | 57 -- .../domain/discovery/query/search_records.py | 59 -- .../osa/domain/discovery/service/discovery.py | 423 ----------- .../osa/domain/discovery/util/di/__init__.py | 3 - .../osa/domain/discovery/util/di/provider.py | 33 - server/osa/domain/index/__init__.py | 1 - server/osa/domain/index/event/__init__.py | 5 - server/osa/domain/index/event/index_record.py | 26 - server/osa/domain/index/handler/__init__.py | 7 - .../index/handler/fanout_to_index_backends.py | 47 -- .../index/handler/keyword_index_handler.py | 38 - .../index/handler/vector_index_handler.py | 47 -- server/osa/domain/index/model/__init__.py | 5 - server/osa/domain/index/model/registry.py | 33 - server/osa/domain/index/service/__init__.py | 5 - server/osa/domain/index/service/index.py | 48 -- server/osa/domain/search/__init__.py | 1 - server/osa/domain/search/command/__init__.py | 0 server/osa/domain/search/event/__init__.py | 0 server/osa/domain/search/model/__init__.py | 0 server/osa/domain/search/port/__init__.py | 0 server/osa/domain/search/query/__init__.py | 0 server/osa/domain/search/service/__init__.py | 0 server/osa/domain/semantics/model/schema.py | 6 +- server/osa/domain/shared/error.py | 20 + server/osa/domain/shared/model/hook.py | 9 + server/osa/domain/shared/model/ids.py | 20 + server/osa/domain/shared/model/reserved.py | 16 + .../data}/__init__.py | 0 .../data/postgres_data_read_store.py | 615 +++++++++++++++ server/osa/infrastructure/index/__init__.py | 1 - server/osa/infrastructure/index/di.py | 21 - .../persistence/adapter/discovery.py | 701 ------------------ server/osa/infrastructure/persistence/di.py | 19 +- server/osa/sdk/index/__init__.py | 12 - server/osa/sdk/index/backend.py | 91 
--- server/osa/sdk/index/config.py | 12 - server/osa/sdk/index/result.py | 21 - server/pyproject.toml | 1 + .../tests/contract/test_worker_lifecycle.py | 218 ------ .../persistence/test_discovery_pagination.py | 80 -- .../integration/semantics}/__init__.py | 0 .../test_reserved_name_registration.py | 46 ++ .../tests/integration/test_crash_recovery.py | 211 ------ .../test_data_features_postgres.py | 219 ++++++ .../test_data_read_store_postgres.py | 265 +++++++ .../test_data_routes_e2e_postgres.py | 527 +++++++++++++ ...test_data_streaming_guarantees_postgres.py | 153 ++++ .../test_discovery_compound_postgres.py | 162 ---- .../test_discovery_cross_join_postgres.py | 254 ------- .../test_discovery_records_typed_and.py | 284 ------- .../test_event_batch_processing.py | 229 ------ .../application/api/v1/routes}/__init__.py | 0 .../api/v1/routes/data}/__init__.py | 0 .../api/v1/routes/data/test_params.py | 64 ++ .../api/v1/routes/data/test_streaming.py | 103 +++ .../api/v1/routes/data/test_tables_factory.py | 87 +++ .../unit/domain/data}/__init__.py | 0 .../unit/domain/data/serializer}/__init__.py | 0 .../unit/domain/data/serializer/test_csv.py | 64 ++ .../domain/data/serializer/test_csv_gzip.py | 56 ++ .../unit/domain/data/serializer/test_json.py | 53 ++ .../domain/data/test_data_query_service.py | 112 +++ .../tests/unit/domain/data/test_query_plan.py | 61 ++ .../deposition/test_hook_reserved_name.py | 36 + .../tests/unit/domain/discovery/__init__.py | 0 .../discovery/test_discovery_service.py | 476 ------------ .../discovery/test_get_feature_catalog.py | 114 --- .../domain/discovery/test_search_features.py | 275 ------- .../domain/discovery/test_search_records.py | 69 -- .../tests/unit/domain/discovery/test_value.py | 107 --- server/tests/unit/domain/index/__init__.py | 0 .../unit/domain/index/test_fanout_listener.py | 185 ----- .../unit/domain/index/test_index_service.py | 97 --- .../semantics/test_schema_reserved_name.py | 40 + .../unit/domain/shared}/__init__.py | 
0 .../tests/unit/domain/shared/test_reserved.py | 33 + .../test_field_definition_reader.py | 130 ---- .../tests/unit/test_field_ref_resolution.py | 42 -- .../unit/test_filter_expr_and_compile.py | 197 ----- server/tests/unit/test_filter_expr_or_not.py | 131 ---- server/tests/unit/test_no_index_deps.py | 33 + server/uv.lock | 42 +- 128 files changed, 4248 insertions(+), 5552 deletions(-) create mode 100644 server/osa/application/api/v1/routes/data/__init__.py create mode 100644 server/osa/application/api/v1/routes/data/_limiter.py create mode 100644 server/osa/application/api/v1/routes/data/_params.py create mode 100644 server/osa/application/api/v1/routes/data/_runtime.py create mode 100644 server/osa/application/api/v1/routes/data/_streaming.py create mode 100644 server/osa/application/api/v1/routes/data/catalog.py create mode 100644 server/osa/application/api/v1/routes/data/features_table.py create mode 100644 server/osa/application/api/v1/routes/data/models.py create mode 100644 server/osa/application/api/v1/routes/data/records.py create mode 100644 server/osa/application/api/v1/routes/data/records_table.py create mode 100644 server/osa/application/api/v1/routes/data/tables.py delete mode 100644 server/osa/application/api/v1/routes/discovery.py delete mode 100644 server/osa/application/api/v1/routes/records.py delete mode 100644 server/osa/application/api/v1/routes/search.py rename server/osa/domain/{discovery/model => data}/__init__.py (100%) rename server/osa/domain/{discovery/port => data/model}/__init__.py (100%) create mode 100644 server/osa/domain/data/model/catalog.py rename server/osa/domain/{discovery/model/value.py => data/model/filter.py} (55%) create mode 100644 server/osa/domain/data/model/format.py create mode 100644 server/osa/domain/data/model/manifest.py create mode 100644 server/osa/domain/data/model/query_plan.py create mode 100644 server/osa/domain/data/model/record_summary.py rename server/osa/domain/{discovery/query => 
data/port}/__init__.py (100%) create mode 100644 server/osa/domain/data/port/data_read_store.py rename server/osa/domain/{discovery/service => data/query}/__init__.py (100%) rename server/osa/domain/{discovery/util => data/serializer}/__init__.py (100%) create mode 100644 server/osa/domain/data/serializer/csv.py create mode 100644 server/osa/domain/data/serializer/csv_gzip.py create mode 100644 server/osa/domain/data/serializer/json.py create mode 100644 server/osa/domain/data/serializer/protocol.py rename server/osa/domain/{export => data/service}/__init__.py (100%) create mode 100644 server/osa/domain/data/service/data_catalog.py create mode 100644 server/osa/domain/data/service/data_query.py rename server/osa/domain/{export/adapter => data/util}/__init__.py (100%) create mode 100644 server/osa/domain/data/util/di/__init__.py create mode 100644 server/osa/domain/data/util/di/provider.py delete mode 100644 server/osa/domain/discovery/__init__.py delete mode 100644 server/osa/domain/discovery/model/refs.py delete mode 100644 server/osa/domain/discovery/port/field_definition_reader.py delete mode 100644 server/osa/domain/discovery/port/read_store.py delete mode 100644 server/osa/domain/discovery/query/get_feature_catalog.py delete mode 100644 server/osa/domain/discovery/query/search_features.py delete mode 100644 server/osa/domain/discovery/query/search_records.py delete mode 100644 server/osa/domain/discovery/service/discovery.py delete mode 100644 server/osa/domain/discovery/util/di/__init__.py delete mode 100644 server/osa/domain/discovery/util/di/provider.py delete mode 100644 server/osa/domain/index/__init__.py delete mode 100644 server/osa/domain/index/event/__init__.py delete mode 100644 server/osa/domain/index/event/index_record.py delete mode 100644 server/osa/domain/index/handler/__init__.py delete mode 100644 server/osa/domain/index/handler/fanout_to_index_backends.py delete mode 100644 server/osa/domain/index/handler/keyword_index_handler.py delete mode 
100644 server/osa/domain/index/handler/vector_index_handler.py delete mode 100644 server/osa/domain/index/model/__init__.py delete mode 100644 server/osa/domain/index/model/registry.py delete mode 100644 server/osa/domain/index/service/__init__.py delete mode 100644 server/osa/domain/index/service/index.py delete mode 100644 server/osa/domain/search/__init__.py delete mode 100644 server/osa/domain/search/command/__init__.py delete mode 100644 server/osa/domain/search/event/__init__.py delete mode 100644 server/osa/domain/search/model/__init__.py delete mode 100644 server/osa/domain/search/port/__init__.py delete mode 100644 server/osa/domain/search/query/__init__.py delete mode 100644 server/osa/domain/search/service/__init__.py create mode 100644 server/osa/domain/shared/model/ids.py create mode 100644 server/osa/domain/shared/model/reserved.py rename server/osa/{domain/export/command => infrastructure/data}/__init__.py (100%) create mode 100644 server/osa/infrastructure/data/postgres_data_read_store.py delete mode 100644 server/osa/infrastructure/index/__init__.py delete mode 100644 server/osa/infrastructure/index/di.py delete mode 100644 server/osa/infrastructure/persistence/adapter/discovery.py delete mode 100644 server/osa/sdk/index/__init__.py delete mode 100644 server/osa/sdk/index/backend.py delete mode 100644 server/osa/sdk/index/config.py delete mode 100644 server/osa/sdk/index/result.py delete mode 100644 server/tests/contract/test_worker_lifecycle.py delete mode 100644 server/tests/integration/persistence/test_discovery_pagination.py rename server/{osa/domain/export/event => tests/integration/semantics}/__init__.py (100%) create mode 100644 server/tests/integration/semantics/test_reserved_name_registration.py delete mode 100644 server/tests/integration/test_crash_recovery.py create mode 100644 server/tests/integration/test_data_features_postgres.py create mode 100644 server/tests/integration/test_data_read_store_postgres.py create mode 100644 
server/tests/integration/test_data_routes_e2e_postgres.py create mode 100644 server/tests/integration/test_data_streaming_guarantees_postgres.py delete mode 100644 server/tests/integration/test_discovery_compound_postgres.py delete mode 100644 server/tests/integration/test_discovery_cross_join_postgres.py delete mode 100644 server/tests/integration/test_discovery_records_typed_and.py delete mode 100644 server/tests/integration/test_event_batch_processing.py rename server/{osa/domain/export/model => tests/unit/application/api/v1/routes}/__init__.py (100%) rename server/{osa/domain/export/port => tests/unit/application/api/v1/routes/data}/__init__.py (100%) create mode 100644 server/tests/unit/application/api/v1/routes/data/test_params.py create mode 100644 server/tests/unit/application/api/v1/routes/data/test_streaming.py create mode 100644 server/tests/unit/application/api/v1/routes/data/test_tables_factory.py rename server/{osa/domain/export/query => tests/unit/domain/data}/__init__.py (100%) rename server/{osa/domain/export/service => tests/unit/domain/data/serializer}/__init__.py (100%) create mode 100644 server/tests/unit/domain/data/serializer/test_csv.py create mode 100644 server/tests/unit/domain/data/serializer/test_csv_gzip.py create mode 100644 server/tests/unit/domain/data/serializer/test_json.py create mode 100644 server/tests/unit/domain/data/test_data_query_service.py create mode 100644 server/tests/unit/domain/data/test_query_plan.py create mode 100644 server/tests/unit/domain/deposition/test_hook_reserved_name.py delete mode 100644 server/tests/unit/domain/discovery/__init__.py delete mode 100644 server/tests/unit/domain/discovery/test_discovery_service.py delete mode 100644 server/tests/unit/domain/discovery/test_get_feature_catalog.py delete mode 100644 server/tests/unit/domain/discovery/test_search_features.py delete mode 100644 server/tests/unit/domain/discovery/test_search_records.py delete mode 100644 
server/tests/unit/domain/discovery/test_value.py delete mode 100644 server/tests/unit/domain/index/__init__.py delete mode 100644 server/tests/unit/domain/index/test_fanout_listener.py delete mode 100644 server/tests/unit/domain/index/test_index_service.py create mode 100644 server/tests/unit/domain/semantics/test_schema_reserved_name.py rename server/{osa/domain/search/adapter => tests/unit/domain/shared}/__init__.py (100%) create mode 100644 server/tests/unit/domain/shared/test_reserved.py delete mode 100644 server/tests/unit/infrastructure/persistence/test_field_definition_reader.py delete mode 100644 server/tests/unit/test_field_ref_resolution.py delete mode 100644 server/tests/unit/test_filter_expr_and_compile.py delete mode 100644 server/tests/unit/test_filter_expr_or_not.py create mode 100644 server/tests/unit/test_no_index_deps.py diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index 6d175ae..63925cf 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -8,6 +8,7 @@ from dishka import Provider as DishkaProvider from fastapi import FastAPI, Request from fastapi.responses import JSONResponse +from slowapi.errors import RateLimitExceeded from sqlalchemy.ext.asyncio import AsyncEngine from osa.application.api.v1.errors import map_osa_error @@ -16,17 +17,16 @@ auth, conventions, depositions, - discovery, events, ingestions, health, ontologies, - records, schemas, - search, stats, validation, ) +from osa.application.api.v1.routes import data as data_routes +from osa.application.api.v1.routes.data._limiter import limiter from osa.application.di import create_container from osa.config import DEV_JWT_SECRET, Config from osa.domain.shared.authorization.startup import validate_all_handlers @@ -194,8 +194,6 @@ def create_app( app_instance.include_router(admin.router, prefix="/api/v1") app_instance.include_router(auth.router, prefix="/api/v1") 
app_instance.include_router(events.router, prefix="/api/v1") - app_instance.include_router(records.router, prefix="/api/v1") - app_instance.include_router(search.router, prefix="/api/v1") app_instance.include_router(stats.router, prefix="/api/v1") app_instance.include_router(ontologies.router, prefix="/api/v1") app_instance.include_router(schemas.router, prefix="/api/v1") @@ -203,7 +201,18 @@ def create_app( app_instance.include_router(depositions.router, prefix="/api/v1") app_instance.include_router(ingestions.router, prefix="/api/v1") app_instance.include_router(validation.router, prefix="/api/v1") - app_instance.include_router(discovery.router, prefix="/api/v1") + app_instance.include_router(data_routes.router, prefix="/api/v1") + + # POST /data/* rate limiting (slowapi). The shared limiter is attached to + # app state and its 429 handler registered; GET routes are not limited. + app_instance.state.limiter = limiter + + @app_instance.exception_handler(RateLimitExceeded) + async def rate_limit_handler(request: Request, exc: RateLimitExceeded): + return JSONResponse( + status_code=429, + content={"code": "rate_limited", "message": "POST rate limit exceeded (10/min/IP)."}, + ) # Global OSA error handler - maps domain and infrastructure errors to HTTP responses @app_instance.exception_handler(OSAError) diff --git a/server/osa/application/api/v1/errors.py b/server/osa/application/api/v1/errors.py index 8a9d4cf..dd4c790 100644 --- a/server/osa/application/api/v1/errors.py +++ b/server/osa/application/api/v1/errors.py @@ -15,6 +15,7 @@ InvalidStateError, NotFoundError, OSAError, + ReservedNameError, ValidationError, ) @@ -24,6 +25,7 @@ InvalidStateError: 409, ConflictError: 409, AuthorizationError: 403, + ReservedNameError: 400, } diff --git a/server/osa/application/api/v1/routes/data/__init__.py b/server/osa/application/api/v1/routes/data/__init__.py new file mode 100644 index 0000000..8109e25 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/__init__.py 
@@ -0,0 +1,50 @@ +"""Unified ``/data/`` read surface router. + +Subroutes are registered by the user-story phases: +- catalog + manifest (``GET /data``, ``GET /data/{schema}``) — US3 +- single record by ID (``GET /data/records/{id}``) — US4 +- records table matrix (``/data/{schema}/records*``) — US1/US2 via the factory +- feature table matrix (``/data/{schema}/{feature}*``) — US5 via the factory + +The router supersedes the legacy ``/discovery``, ``/records``, ``/search`` +routers, which are removed in the same change (research §10). +""" + +from __future__ import annotations + +from dishka.integrations.fastapi import DishkaRoute +from fastapi import APIRouter + +from osa.application.api.v1.routes.data import ( + catalog, + features_table, + records, + records_table, +) +from osa.domain.data.model.catalog import NodeCatalog + +router = APIRouter(prefix="/data", tags=["data"], route_class=DishkaRoute) + +# ``GET /data`` (empty sub-path) is registered directly on the prefixed router. +router.add_api_route( + "", + catalog.get_node_catalog, + methods=["GET"], + operation_id="data_get_node_catalog", + response_model=NodeCatalog, +) + +# Table-shaped routes (records + feature tables) are registered via the factory +# onto a DishkaRoute-enabled subrouter. Records is registered before the feature +# table's ``/{schema}/{feature}`` catch-all so ``/{schema}/records`` matches the +# records routes rather than being captured as a feature named "records". +tables_router = APIRouter(route_class=DishkaRoute) +records_table.register(tables_router) +features_table.register(tables_router) + +# Order matters: literal ``/records/{id}`` and the ``/{schema}/records*`` table +# routes must precede the manifest's ``/{schema}`` catch-all so a record fetch +# or table read isn't captured as a schema manifest lookup. 
+router.include_router(records.router) +router.include_router(tables_router) +router.include_router(catalog.manifest_router) diff --git a/server/osa/application/api/v1/routes/data/_limiter.py b/server/osa/application/api/v1/routes/data/_limiter.py new file mode 100644 index 0000000..fe924d8 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/_limiter.py @@ -0,0 +1,20 @@ +"""Shared slowapi limiter for ``/data/`` POST routes (research §5). + +POST routes accept structured ``FilterExpr`` bodies that can express expensive +queries, so they are rate-limited per source IP. GET routes are intentionally +unlimited — stable GET URLs are exactly what we want CDNs to cache and +consumers to curl on a cron. + +Default: 10 requests/minute/IP on POST routes. The limiter instance is shared +(module singleton) and registered on the FastAPI app at startup; route handlers +apply it via ``@limiter.limit(POST_RATE_LIMIT)``. +""" + +from __future__ import annotations + +from slowapi import Limiter +from slowapi.util import get_remote_address + +POST_RATE_LIMIT = "10/minute" + +limiter = Limiter(key_func=get_remote_address) diff --git a/server/osa/application/api/v1/routes/data/_params.py b/server/osa/application/api/v1/routes/data/_params.py new file mode 100644 index 0000000..d36c334 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/_params.py @@ -0,0 +1,86 @@ +"""Shared request parsing for table routes — sort spec, filter body, plan build.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + +from osa.domain.data.model.filter import FilterExpr +from osa.domain.data.model.query_plan import ( + PaginationCursor, + PaginationParams, + QueryPlan, + SortDirection, + SortSpec, + TableKind, +) +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.ids import HookName +from osa.domain.shared.model.srn import SchemaId + +# Page-size ceiling (mirrors PaginationParams.limit's bound). 
Requests above this +# are clamped, not rejected (FR — a consumer asking for "everything" with a big +# number gets the max page, not a 422). +MAX_LIMIT = 1000 + + +def clamp_limit(limit: int) -> int: + return max(1, min(limit, MAX_LIMIT)) + + +class FilterRequestBody(BaseModel): + """POST body shared by every table format (records + feature).""" + + filter: FilterExpr | None = None + cursor: str | None = None + # Unbounded at the edge so an over-large request is clamped (build_plan), + # not rejected; the canonical bound lives on PaginationParams. + limit: int = Field(default=50, ge=1) + sort: str | None = None + + +def parse_sort(raw: str | None) -> list[SortSpec]: + """Parse ``col[:asc|:desc],col2[:asc|:desc]`` → SortSpec list (empty if None).""" + if not raw: + return [] + specs: list[SortSpec] = [] + for part in raw.split(","): + token = part.strip() + if not token: + continue + if ":" in token: + column, direction = token.split(":", 1) + try: + dir_enum = SortDirection(direction.strip().lower()) + except ValueError as exc: + raise ValidationError( + f"Invalid sort direction in {token!r}; expected 'asc' or 'desc'.", + field="sort", + ) from exc + else: + column, dir_enum = token, SortDirection.ASC + specs.append(SortSpec(column=column.strip(), direction=dir_enum)) + return specs + + +def build_plan( + *, + schema_id: SchemaId, + table_kind: TableKind, + feature_name: HookName | None, + filter_expr: FilterExpr | None, + cursor: str | None, + limit: int, + sort: str | None, +) -> QueryPlan: + pagination = PaginationParams( + cursor=PaginationCursor(value=cursor) if cursor else None, + limit=clamp_limit(limit), + ) + return QueryPlan( + schema_id=schema_id, + table_kind=table_kind, + feature_name=feature_name, + filter=filter_expr, + pagination=pagination, + sort=parse_sort(sort), + ) diff --git a/server/osa/application/api/v1/routes/data/_runtime.py b/server/osa/application/api/v1/routes/data/_runtime.py new file mode 100644 index 0000000..9149f96 --- /dev/null 
+++ b/server/osa/application/api/v1/routes/data/_runtime.py @@ -0,0 +1,32 @@ +"""Per-route Postgres statement-timeout dependency (research §7). + +Each :class:`DataResponseFormat` carries its own ``statement_timeout`` (30s for +paginated JSON, 30min for streaming dumps). ``apply_statement_timeout`` issues +``SET LOCAL statement_timeout = '...'`` against the active session at the start +of the request; ``SET LOCAL`` reverts on commit/rollback so the value never +leaks to a later request reusing the pooled connection. + +``idle_in_transaction_session_timeout`` is set globally in Postgres config (see +quickstart), not per-route — it's the same everywhere and protects the pool +from any route's misbehaviour. +""" + +from __future__ import annotations + +import re + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.data.model.format import DataResponseFormat + +# Timeout literals are operator-controlled (from FORMATS), never user input, but +# validate anyway so a malformed registry value can't reach raw SQL. +_TIMEOUT_RE = re.compile(r"^\d+(ms|s|min)?$") + + +async def apply_statement_timeout(session: AsyncSession, fmt: DataResponseFormat) -> None: + timeout = fmt.statement_timeout + if not _TIMEOUT_RE.match(timeout): + raise ValueError(f"Invalid statement_timeout literal: {timeout!r}") + await session.execute(text(f"SET LOCAL statement_timeout = '{timeout}'")) diff --git a/server/osa/application/api/v1/routes/data/_streaming.py b/server/osa/application/api/v1/routes/data/_streaming.py new file mode 100644 index 0000000..7325719 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/_streaming.py @@ -0,0 +1,101 @@ +"""Response assembly for table reads — streaming and paginated paths. + +Two shapes share one entry point :func:`build_table_response`: + +* **Streaming** (``.csv`` / ``.csv.gz``): pre-flight pulls the first row inside + a try (research §4). 
A parse/validation/planner error raised before the first + row propagates to the route → mapped to a 4xx/404 *before any bytes*. On + success the first row is chained back in and the whole iterator is streamed. + +* **Paginated** (JSON): consume up to ``limit + 1`` rows; if a ``limit+1``-th + row exists there's a next page, so derive ``next_cursor`` from the last + returned row's ``(sort_value, srn)`` pair. The bounded page (``limit`` ≤ 1000) + is then rendered by the JSON serializer. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any + +from fastapi.responses import StreamingResponse + +from osa.domain.data.model.format import DataResponseFormat +from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.model.query_plan import QueryPlan, encode_cursor + + +async def build_table_response( + rows: AsyncIterator[Mapping[str, Any]], + fmt: DataResponseFormat, + columns: Sequence[ColumnSpec], + plan: QueryPlan, +) -> StreamingResponse: + if fmt.paginated: + return await _paginated_response(rows, fmt, columns, plan) + return await _streaming_response(rows, fmt, columns) + + +async def _streaming_response( + rows: AsyncIterator[Mapping[str, Any]], + fmt: DataResponseFormat, + columns: Sequence[ColumnSpec], +) -> StreamingResponse: + iterator = rows.__aiter__() + # Pre-flight: surface setup/validation errors before the first byte. 
+ try: + first = await iterator.__anext__() + empty = False + except StopAsyncIteration: + empty = True + + async def chained() -> AsyncIterator[Mapping[str, Any]]: + if not empty: + yield first + async for row in iterator: + yield row + + serializer = fmt.make_serializer() + return StreamingResponse( + serializer.stream(chained(), columns), + media_type=fmt.media_type, + ) + + +async def _paginated_response( + rows: AsyncIterator[Mapping[str, Any]], + fmt: DataResponseFormat, + columns: Sequence[ColumnSpec], + plan: QueryPlan, +) -> StreamingResponse: + limit = plan.pagination.limit + page: list[Mapping[str, Any]] = [] + has_more = False + async for row in rows: + if len(page) == limit: + has_more = True + break + page.append(row) + + next_cursor = _next_cursor(page, plan) if has_more else None + + async def page_iter() -> AsyncIterator[Mapping[str, Any]]: + for row in page: + yield row + + serializer = fmt.make_serializer() + return StreamingResponse( + serializer.stream(page_iter(), columns, next_cursor=next_cursor), + media_type=fmt.media_type, + ) + + +def _next_cursor(page: list[Mapping[str, Any]], plan: QueryPlan) -> str | None: + if not page: + return None + last = page[-1] + sort_column = plan.sort[0].column + sort_value = last.get(sort_column) + # The keyset tiebreaker is the record SRN (see PostgresDataReadStore), so the + # cursor's id component is the full srn, not the bare id. + return encode_cursor(sort_value, last.get("srn", last.get("id"))) diff --git a/server/osa/application/api/v1/routes/data/catalog.py b/server/osa/application/api/v1/routes/data/catalog.py new file mode 100644 index 0000000..8b1a3c5 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/catalog.py @@ -0,0 +1,37 @@ +"""Catalog & manifest handlers — ``GET /data`` and ``GET /data/{schema}``. + +JSON-only (no format suffix). Reserved schema names (``records``, ``datasets``) +and unknown schemas surface as 404 via the service's ``NotFoundError``. 
+ +The node-catalog handler is registered directly on the prefixed ``/data`` +router (an empty sub-path can't be ``include_router``-ed); the manifest handler +lives on ``manifest_router`` so its ``/{schema}`` catch-all can be ordered +after the literal ``/records/{id}`` route. +""" + +from __future__ import annotations + +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter + +from osa.domain.data.model.catalog import NodeCatalog +from osa.domain.data.model.manifest import SchemaManifest +from osa.domain.data.service.data_catalog import DataCatalogService + +manifest_router = APIRouter(route_class=DishkaRoute) + + +async def get_node_catalog(service: FromDishka[DataCatalogService]) -> NodeCatalog: + """List schemas hosted at this node.""" + return await service.get_node_catalog() + + +@manifest_router.get( + "/{schema}", operation_id="data_get_schema_manifest", response_model=SchemaManifest +) +async def get_schema_manifest( + schema: str, service: FromDishka[DataCatalogService] +) -> SchemaManifest: + """Machine-readable manifest for a schema (`<id>` or `<id>@<version>`).""" + schema_id = await service.resolve_schema(schema) + return await service.get_schema_manifest(schema_id) diff --git a/server/osa/application/api/v1/routes/data/features_table.py b/server/osa/application/api/v1/routes/data/features_table.py new file mode 100644 index 0000000..47797a0 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/features_table.py @@ -0,0 +1,120 @@ +"""Feature-table routes — ``/data/{schema}/{feature}[.csv|.csv.gz]`` (US5). + +Identical shape to the records table — same factory, same streaming/pagination +engine — but the path carries a ``{feature}`` segment and the plan is a +``TableKind.FEATURE`` plan. The feature must be a table resource on the resolved +schema's manifest; an unknown feature 404s before any bytes (T090). 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +from dishka.integrations.fastapi import FromDishka +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.application.api.v1.routes.data._limiter import POST_RATE_LIMIT, limiter +from osa.application.api.v1.routes.data._params import FilterRequestBody, build_plan +from osa.application.api.v1.routes.data._runtime import apply_statement_timeout +from osa.application.api.v1.routes.data._streaming import build_table_response +from osa.application.api.v1.routes.data.tables import format_key, register_table_routes +from osa.domain.data.model.format import DataResponseFormat +from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.model.query_plan import TableKind +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.data.service.data_query import DataQueryService +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.model.srn import SchemaId + + +async def _feature_columns( + catalog: DataCatalogService, schema: str, feature: str +) -> tuple[Sequence[ColumnSpec], SchemaId]: + """Resolve schema + feature (404 if unknown/reserved) → feature columns + SchemaId.""" + schema_id = await catalog.resolve_schema(schema) + manifest = await catalog.get_schema_manifest(schema_id) + resource = next( + ( + tr + for tr in manifest.table_resources + if tr.name == feature and tr.kind == TableKind.FEATURE + ), + None, + ) + if resource is None: + raise NotFoundError( + f"No feature table '{feature}' on schema '{schema}'. 
" + f"See /api/v1/data/{schema_id.render()} for its table resources.", + code="feature_not_found", + ) + return resource.columns, schema_id + + +def _make_get_endpoint(fmt: DataResponseFormat): + async def endpoint( + schema: str, + feature: str, + query_service: FromDishka[DataQueryService], + catalog_service: FromDishka[DataCatalogService], + session: FromDishka[AsyncSession], + cursor: str | None = None, + limit: int = 50, + sort: str | None = None, + ) -> StreamingResponse: + columns, schema_id = await _feature_columns(catalog_service, schema, feature) + await apply_statement_timeout(session, fmt) + plan = build_plan( + schema_id=schema_id, + table_kind=TableKind.FEATURE, + feature_name=feature, + filter_expr=None, + cursor=cursor, + limit=limit, + sort=sort, + ) + rows = query_service.stream_features(plan) + return await build_table_response(rows, fmt, columns, plan) + + return endpoint + + +def _make_post_endpoint(fmt: DataResponseFormat): + async def endpoint( + request: Request, + schema: str, + feature: str, + body: FilterRequestBody, + query_service: FromDishka[DataQueryService], + catalog_service: FromDishka[DataCatalogService], + session: FromDishka[AsyncSession], + ) -> StreamingResponse: + columns, schema_id = await _feature_columns(catalog_service, schema, feature) + await apply_statement_timeout(session, fmt) + plan = build_plan( + schema_id=schema_id, + table_kind=TableKind.FEATURE, + feature_name=feature, + filter_expr=body.filter, + cursor=body.cursor, + limit=body.limit, + sort=body.sort, + ) + rows = query_service.stream_features(plan) + return await build_table_response(rows, fmt, columns, plan) + + # Unique name before the limiter — see records_table._make_post_endpoint. 
+ endpoint.__name__ = f"feature_post_{format_key(fmt)}" + endpoint.__qualname__ = endpoint.__name__ + return limiter.limit(POST_RATE_LIMIT)(endpoint) + + +def register(router: APIRouter) -> None: + register_table_routes( + router, + "/{schema}/{feature}", + _make_get_endpoint, + _make_post_endpoint, + "feature", + ) diff --git a/server/osa/application/api/v1/routes/data/models.py b/server/osa/application/api/v1/routes/data/models.py new file mode 100644 index 0000000..7b7aaea --- /dev/null +++ b/server/osa/application/api/v1/routes/data/models.py @@ -0,0 +1,32 @@ +"""Shared Pydantic response models for the ``/data/`` routes.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel + +from osa.domain.data.model.record_summary import RecordSummary + + +class RecordResponse(BaseModel): + """Single-record response — carries BOTH the bare ``id`` and full ``srn``.""" + + id: str + srn: str + schema_id: str + version: int + metadata: dict[str, Any] + created_at: datetime + + @classmethod + def from_summary(cls, summary: RecordSummary) -> "RecordResponse": + return cls( + id=str(summary.id), + srn=str(summary.srn), + schema_id=summary.schema_id.render(), + version=summary.version, + metadata=summary.metadata, + created_at=summary.created_at, + ) diff --git a/server/osa/application/api/v1/routes/data/records.py b/server/osa/application/api/v1/routes/data/records.py new file mode 100644 index 0000000..0910e49 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/records.py @@ -0,0 +1,44 @@ +"""Single-record-by-ID route — ``GET /data/records/{id}[@{version}]`` (US4). + +Resolves a published record by its bare internal ID; the server finds the +schema via primary-key lookup. The response carries both ``id`` and ``srn``. +A bare ``GET /data/records`` (no id) is not defined — that slot is reserved for +the deferred cross-schema bulk read and 404s naturally. 
+""" + +from __future__ import annotations + +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter + +from osa.application.api.v1.routes.data.models import RecordResponse +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.ids import RecordId + +router = APIRouter(route_class=DishkaRoute) + + +def _parse_id_and_version(raw: str) -> tuple[RecordId, int | None]: + """Split ``{id}`` or ``{id}@{version}`` into a typed id + optional version.""" + if "@" in raw: + id_part, version_part = raw.split("@", 1) + try: + return RecordId(id_part), int(version_part) + except ValueError as exc: + raise ValidationError( + f"Invalid record version in {raw!r}; expected an integer.", + field="id", + ) from exc + return RecordId(raw), None + + +@router.get( + "/records/{record_id}", operation_id="data_get_record_by_id", response_model=RecordResponse +) +async def get_record_by_id( + record_id: str, service: FromDishka[DataCatalogService] +) -> RecordResponse: + rid, version = _parse_id_and_version(record_id) + summary = await service.get_record_by_id(rid, version) + return RecordResponse.from_summary(summary) diff --git a/server/osa/application/api/v1/routes/data/records_table.py b/server/osa/application/api/v1/routes/data/records_table.py new file mode 100644 index 0000000..9567d27 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/records_table.py @@ -0,0 +1,108 @@ +"""Records-table routes — ``/data/{schema}/records[.csv|.csv.gz]`` (US1 + US2). + +Endpoint *builders* capture a :class:`DataResponseFormat` by closure; the +generic :func:`register_table_routes` factory registers the GET/POST × format +matrix from them. GET streams (or paginates JSON) with no body; POST takes a +``FilterExpr`` body and is rate-limited per IP. Every route applies the +format's ``statement_timeout`` and resolves the schema (404 before bytes). 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +from dishka.integrations.fastapi import FromDishka +from fastapi import APIRouter, Request +from fastapi.responses import StreamingResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.application.api.v1.routes.data._limiter import POST_RATE_LIMIT, limiter +from osa.application.api.v1.routes.data._params import FilterRequestBody, build_plan +from osa.application.api.v1.routes.data._runtime import apply_statement_timeout +from osa.application.api.v1.routes.data._streaming import build_table_response +from osa.application.api.v1.routes.data.tables import format_key, register_table_routes +from osa.domain.data.model.format import DataResponseFormat +from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.model.query_plan import TableKind +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.data.service.data_query import DataQueryService +from osa.domain.shared.model.srn import SchemaId + + +async def _records_columns( + catalog: DataCatalogService, schema: str +) -> tuple[Sequence[ColumnSpec], SchemaId]: + """Resolve schema (404 if unknown/reserved) and return its records columns + SchemaId.""" + schema_id = await catalog.resolve_schema(schema) + manifest = await catalog.get_schema_manifest(schema_id) + records_resource = next(tr for tr in manifest.table_resources if tr.name == "records") + return records_resource.columns, schema_id + + +def _make_get_endpoint(fmt: DataResponseFormat): + async def endpoint( + schema: str, + query_service: FromDishka[DataQueryService], + catalog_service: FromDishka[DataCatalogService], + session: FromDishka[AsyncSession], + cursor: str | None = None, + limit: int = 50, + sort: str | None = None, + ) -> StreamingResponse: + columns, schema_id = await _records_columns(catalog_service, schema) + await apply_statement_timeout(session, fmt) + plan = build_plan( + schema_id=schema_id, + 
table_kind=TableKind.RECORDS, + feature_name=None, + filter_expr=None, + cursor=cursor, + limit=limit, + sort=sort, + ) + rows = query_service.stream_records(plan) + return await build_table_response(rows, fmt, columns, plan) + + return endpoint + + +def _make_post_endpoint(fmt: DataResponseFormat): + async def endpoint( + request: Request, + schema: str, + body: FilterRequestBody, + query_service: FromDishka[DataQueryService], + catalog_service: FromDishka[DataCatalogService], + session: FromDishka[AsyncSession], + ) -> StreamingResponse: + columns, schema_id = await _records_columns(catalog_service, schema) + await apply_statement_timeout(session, fmt) + plan = build_plan( + schema_id=schema_id, + table_kind=TableKind.RECORDS, + feature_name=None, + filter_expr=body.filter, + cursor=body.cursor, + limit=body.limit, + sort=body.sort, + ) + rows = query_service.stream_records(plan) + return await build_table_response(rows, fmt, columns, plan) + + # slowapi scopes a rate limit by the decorated function's ``module.__name__`` + # captured at decoration time. These builder closures are all named + # ``endpoint``, so name each uniquely BEFORE applying the limiter, or every + # POST route would collapse into a single shared limit bucket. + endpoint.__name__ = f"records_post_{format_key(fmt)}" + endpoint.__qualname__ = endpoint.__name__ + return limiter.limit(POST_RATE_LIMIT)(endpoint) + + +def register(router: APIRouter) -> None: + register_table_routes( + router, + "/{schema}/records", + _make_get_endpoint, + _make_post_endpoint, + "records", + ) diff --git a/server/osa/application/api/v1/routes/data/tables.py b/server/osa/application/api/v1/routes/data/tables.py new file mode 100644 index 0000000..9fcba32 --- /dev/null +++ b/server/osa/application/api/v1/routes/data/tables.py @@ -0,0 +1,89 @@ +"""Metaprogrammed table-route factory. 
+ +One call to :func:`register_table_routes` registers the full GET/POST × format +matrix (6 routes) for a table-shaped resource — the records table and every +feature table share this exact shape. The caller supplies endpoint builders +(``make_get_endpoint`` / ``make_post_endpoint``) that capture the +:class:`DataResponseFormat` by closure and declare the FastAPI-visible +signature (path params + injected dependencies), because path params differ +between resources (``/{schema}/records`` vs ``/{schema}/{feature}``). + +Operation IDs are stable and match the OpenAPI contract: +``{resource}_{method}_{format_key}`` where ``format_key`` is ``json`` / ``csv`` +/ ``csv_gz``. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from fastapi import APIRouter + +from osa.domain.data.model.format import FORMATS, DataResponseFormat + +EndpointBuilder = Callable[[DataResponseFormat], Callable[..., Any]] + + +def format_key(fmt: DataResponseFormat) -> str: + """``""`` → ``json``; ``csv`` → ``csv``; ``csv.gz`` → ``csv_gz``.""" + return "json" if fmt.suffix == "" else fmt.suffix.replace(".", "_") + + +def _existing_operation_ids(router: APIRouter) -> set[str]: + """Operation IDs already registered on *router* (from prior factory calls).""" + return {op for r in router.routes if (op := getattr(r, "operation_id", None))} + + +def path_for(base_path: str, fmt: DataResponseFormat) -> str: + return base_path if fmt.suffix == "" else f"{base_path}.{fmt.suffix}" + + +def register_table_routes( + router: APIRouter, + base_path: str, + make_get_endpoint: EndpointBuilder, + make_post_endpoint: EndpointBuilder, + resource_name: str, +) -> None: + """Register GET + POST routes for every format under ``base_path``. + + Formats are registered longest-suffix-first so the bare (JSON) path is added + last. This matters when ``base_path`` ends in a path parameter (e.g. 
the + feature table's ``/{schema}/{feature}``): Starlette's ``{feature}`` matches + ``[^/]+``, which includes dots, so a bare ``/{schema}/{feature}`` registered + first would greedily capture ``feature.csv.gz`` as the feature name and the + suffixed routes would never match. Resources with a literal segment (records) + are unaffected by the ordering. + """ + # Seed with op IDs already on this router so a second factory call for a + # different resource (records + feature share one router) can't silently mint + # a colliding id — operation IDs must be globally unique for SDK codegen. + seen_ids = _existing_operation_ids(router) + for fmt in sorted(FORMATS, key=lambda f: len(f.suffix), reverse=True): + key = format_key(fmt) + path = path_for(base_path, fmt) + get_id = f"{resource_name}_get_{key}" + post_id = f"{resource_name}_post_{key}" + for op_id in (get_id, post_id): + if op_id in seen_ids: + raise ValueError( + f"Duplicate operation_id generated by the table factory: {op_id!r}. " + f"Resource name {resource_name!r} collides with a route already on this " + "router — operation IDs must be globally unique for SDK codegen." 
+ ) + seen_ids.add(op_id) + router.add_api_route( + path, + make_get_endpoint(fmt), + methods=["GET"], + operation_id=get_id, + response_model=None, + ) + router.add_api_route( + path, + make_post_endpoint(fmt), + methods=["POST"], + operation_id=post_id, + response_model=None, + ) diff --git a/server/osa/application/api/v1/routes/discovery.py b/server/osa/application/api/v1/routes/discovery.py deleted file mode 100644 index fb4295a..0000000 --- a/server/osa/application/api/v1/routes/discovery.py +++ /dev/null @@ -1,157 +0,0 @@ -"""Discovery API routes — search and filter records and features.""" - -from typing import Any - -from dishka.integrations.fastapi import DishkaRoute, FromDishka -from fastapi import APIRouter -from pydantic import BaseModel, Field - -from osa.domain.discovery.model.value import FilterExpr, SortOrder -from osa.domain.discovery.query.get_feature_catalog import ( - GetFeatureCatalog, - GetFeatureCatalogHandler, - GetFeatureCatalogResult, -) -from osa.domain.discovery.query.search_features import ( - SearchFeatures, - SearchFeaturesHandler, - SearchFeaturesResult, -) -from osa.domain.discovery.query.search_records import ( - SearchRecords, - SearchRecordsHandler, - SearchRecordsResult, -) -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.srn import ConventionSRN, SchemaId - -router = APIRouter( - prefix="/discovery", - tags=["discovery"], - route_class=DishkaRoute, -) - - -# ── Request / Response models ── - - -class RecordSearchRequest(BaseModel): - schema: str | None = None - """Short-form schema identity: ``"@"`` (e.g. 
``"pdb-structure@1.0.0"``).""" - - convention_srn: ConventionSRN | None = None - filter: FilterExpr | None = None - q: str | None = None - sort: str = "published_at" - order: SortOrder = SortOrder.DESC - cursor: str | None = None - limit: int = Field(default=20, ge=1, le=100) - - -class RecordSearchResponse(BaseModel): - results: list[dict[str, Any]] - cursor: str | None - has_more: bool - - -class FeatureCatalogResponse(BaseModel): - tables: list[dict[str, Any]] - - -class FeatureSearchRequest(BaseModel): - schema: str | None = None - """Short-form schema identity, optional. See RecordSearchRequest.schema.""" - - filter: FilterExpr | None = None - record_srn: str | None = None - sort: str = "id" - order: SortOrder = SortOrder.DESC - cursor: str | None = None - limit: int = Field(default=50, ge=1, le=100) - - -class FeatureSearchResponse(BaseModel): - rows: list[dict[str, Any]] - cursor: str | None - has_more: bool - - -def _parse_schema(value: str | None) -> SchemaId | None: - if value is None: - return None - if "@" not in value: - raise ValidationError( - f"Schema {value!r} must be fully qualified as '@' " - "(e.g. 'pdb-structure@1.0.0'). 
Family-level scoping " - "(id alone, resolving to the latest version across a schema family) " - "is planned but not yet supported.", - field="schema", - code="cross_scope_not_yet_supported", - ) - try: - return SchemaId.parse(value) - except ValueError as exc: - raise ValidationError(str(exc), field="schema") from exc - - -# ── Routes ── - - -@router.post("/records") -async def search_records( - body: RecordSearchRequest, - handler: FromDishka[SearchRecordsHandler], -) -> RecordSearchResponse: - """Search and filter published records.""" - result: SearchRecordsResult = await handler.run( - SearchRecords( - filter_expr=body.filter, - schema_id=_parse_schema(body.schema), - convention_srn=body.convention_srn, - q=body.q, - sort=body.sort, - order=body.order, - cursor=body.cursor, - limit=body.limit, - ) - ) - return RecordSearchResponse( - results=result.results, - cursor=result.cursor, - has_more=result.has_more, - ) - - -@router.get("/features") -async def get_feature_catalog( - handler: FromDishka[GetFeatureCatalogHandler], -) -> FeatureCatalogResponse: - """List available feature tables with column schemas and record counts.""" - result: GetFeatureCatalogResult = await handler.run(GetFeatureCatalog()) - return FeatureCatalogResponse(tables=result.tables) - - -@router.post("/features/{hook_name}") -async def search_features( - hook_name: str, - body: FeatureSearchRequest, - handler: FromDishka[SearchFeaturesHandler], -) -> FeatureSearchResponse: - """Query and filter rows in a specific feature table.""" - result: SearchFeaturesResult = await handler.run( - SearchFeatures( - hook_name=hook_name, - filter_expr=body.filter, - schema_id=_parse_schema(body.schema), - record_srn=body.record_srn, - sort=body.sort, - order=body.order, - cursor=body.cursor, - limit=body.limit, - ) - ) - return FeatureSearchResponse( - rows=result.rows, - cursor=result.cursor, - has_more=result.has_more, - ) diff --git a/server/osa/application/api/v1/routes/records.py 
b/server/osa/application/api/v1/routes/records.py deleted file mode 100644 index 4182005..0000000 --- a/server/osa/application/api/v1/routes/records.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Records API routes.""" - -from typing import Any - -from dishka.integrations.fastapi import DishkaRoute, FromDishka -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel - -from osa.domain.record.query.get_record import GetRecord, GetRecordHandler, RecordDetail -from osa.domain.shared.model.srn import RecordSRN - -router = APIRouter( - prefix="/records", - tags=["records"], - route_class=DishkaRoute, -) - - -class RecordResponse(BaseModel): - """Single record response.""" - - record: dict[str, Any] - - -@router.get("/{srn:path}") -async def get_record( - srn: str, - handler: FromDishka[GetRecordHandler], -) -> RecordResponse: - """Get a record by SRN.""" - try: - record_srn = RecordSRN.parse(srn) - except ValueError as e: - raise HTTPException(status_code=400, detail=f"Invalid SRN: {e}") - - result: RecordDetail = await handler.run(GetRecord(srn=record_srn)) - - return RecordResponse( - record={ - "srn": str(result.srn), - "source": result.source.model_dump(), - "convention_srn": str(result.convention_srn), - "metadata": result.metadata, - "published_at": result.published_at.isoformat(), - "features": result.features, - } - ) diff --git a/server/osa/application/api/v1/routes/search.py b/server/osa/application/api/v1/routes/search.py deleted file mode 100644 index c71008e..0000000 --- a/server/osa/application/api/v1/routes/search.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Search API routes.""" - -from typing import Any - -from dishka.integrations.fastapi import DishkaRoute, FromDishka -from fastapi import APIRouter, HTTPException, Query -from pydantic import BaseModel - -from osa.domain.index.model.registry import IndexRegistry -from osa.sdk.index import QueryResult - -router = APIRouter( - prefix="/search", - tags=["search"], - route_class=DishkaRoute, -) - - 
-class SearchResponse(BaseModel): - """Search response model.""" - - query: str - index: str - total: int - has_more: bool - results: list[dict[str, Any]] - - -@router.get("/{index_name}") -async def search_index( - index_name: str, - indexes: FromDishka[IndexRegistry], - q: str = Query(..., description="Search query"), - offset: int = Query(0, ge=0, description="Number of results to skip"), - limit: int = Query(20, ge=1, le=100, description="Maximum number of results"), -) -> SearchResponse: - """Search a specific index by name.""" - if not len(indexes): - raise HTTPException(status_code=503, detail="No search indexes configured") - - backend = indexes.get(index_name) - if not backend: - raise HTTPException( - status_code=404, - detail=f"Index '{index_name}' not found. Available: {indexes.names()}", - ) - - # Naive pagination: fetch offset+limit+1 to determine has_more, then slice - result: QueryResult = await backend.query(q, limit=offset + limit + 1) - all_hits = result.hits[offset:] # Skip first `offset` results - has_more = len(all_hits) > limit - hits = all_hits[:limit] - - return SearchResponse( - query=q, - index=index_name, - total=len(hits), - has_more=has_more, - results=[ - { - "srn": hit.srn, - "score": hit.score, - "metadata": hit.metadata, - } - for hit in hits - ], - ) - - -@router.get("/") -async def list_indexes( - indexes: FromDishka[IndexRegistry], -) -> dict[str, list[str]]: - """List available search indexes.""" - return {"indexes": indexes.names()} diff --git a/server/osa/application/api/v1/routes/stats.py b/server/osa/application/api/v1/routes/stats.py index 7fe4c2e..64185ad 100644 --- a/server/osa/application/api/v1/routes/stats.py +++ b/server/osa/application/api/v1/routes/stats.py @@ -4,7 +4,6 @@ from fastapi import APIRouter from pydantic import BaseModel -from osa.domain.index.model.registry import IndexRegistry from osa.domain.record.port.repository import RecordRepository router = APIRouter( @@ -14,38 +13,23 @@ ) -class 
IndexStats(BaseModel): - """Stats for a single index.""" - - name: str - count: int - healthy: bool - - class StatsResponse(BaseModel): - """System statistics response.""" + """System statistics response. + + The legacy ``indexes`` field was removed with the index domain (the unified + ``/data/`` surface replaces vector/keyword index reads). Per-schema and + per-feature-table row counts now live in each schema manifest at + ``GET /api/v1/data/{schema}``; ``data_url`` points there. + """ records: int - indexes: list[IndexStats] + data_url: str = "/api/v1/data" @router.get("") async def get_stats( record_repo: FromDishka[RecordRepository], - indexes: FromDishka[IndexRegistry], ) -> StatsResponse: """Get system statistics.""" record_count = await record_repo.count() - - index_stats = [] - for name, backend in indexes.items(): - try: - count = await backend.count() - healthy = await backend.health() - except Exception: - count = 0 - healthy = False - - index_stats.append(IndexStats(name=name, count=count, healthy=healthy)) - - return StatsResponse(records=record_count, indexes=index_stats) + return StatsResponse(records=record_count) diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 5227635..6c22eab 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -5,8 +5,8 @@ from osa.config import Config from osa.domain.auth.util.di import AuthProvider +from osa.domain.data.util.di import DataProvider from osa.domain.deposition.util.di import DepositionProvider -from osa.domain.discovery.util.di import DiscoveryProvider from osa.domain.feature.util.di import FeatureProvider from osa.domain.metadata.util.di import MetadataProvider from osa.domain.semantics.util.di.provider import SemanticsProvider @@ -15,7 +15,6 @@ from osa.infrastructure.auth import AuthInfraProvider from osa.infrastructure.event.di import EventProvider from osa.infrastructure.http.di import HttpProvider -from osa.infrastructure.index.di import IndexProvider 
from osa.infrastructure.k8s.di import RunnerProvider from osa.infrastructure.persistence import PersistenceProvider from osa.infrastructure.ingest.di import IngestProvider @@ -44,7 +43,6 @@ def create_container( return make_async_container( PersistenceProvider(), RunnerProvider(), - IndexProvider(), IngestProvider(), EventProvider(extra_handlers=extra_handlers), HttpProvider(), @@ -55,7 +53,7 @@ def create_container( ValidationProvider(), AuthProvider(), AuthInfraProvider(), - DiscoveryProvider(), + DataProvider(), *extra_providers, context={Config: config, OSAPaths: paths}, scopes=Scope, # type: ignore[arg-type] # Custom scope class diff --git a/server/osa/domain/discovery/model/__init__.py b/server/osa/domain/data/__init__.py similarity index 100% rename from server/osa/domain/discovery/model/__init__.py rename to server/osa/domain/data/__init__.py diff --git a/server/osa/domain/discovery/port/__init__.py b/server/osa/domain/data/model/__init__.py similarity index 100% rename from server/osa/domain/discovery/port/__init__.py rename to server/osa/domain/data/model/__init__.py diff --git a/server/osa/domain/data/model/catalog.py b/server/osa/domain/data/model/catalog.py new file mode 100644 index 0000000..66f9d96 --- /dev/null +++ b/server/osa/domain/data/model/catalog.py @@ -0,0 +1,35 @@ +"""Node catalog response envelope. + +``GET /data`` returns the node's domain plus the published schemas, each with a +summary list of table resources (names + kinds only). Column-level detail comes +from the per-schema manifest at ``GET /data/{schema}``. 
+""" + +from __future__ import annotations + +from pydantic import BaseModel + +from osa.domain.data.model.query_plan import TableKind + + +class TableResourceSummary(BaseModel): + """Name + kind of an addressable table resource (no columns/counts).""" + + name: str + kind: TableKind + + +class CatalogEntry(BaseModel): + """One published schema in the node catalog.""" + + id: str # short schema id + version: str # SemVer + srn: str # full schema SRN + table_resources: list[TableResourceSummary] + + +class NodeCatalog(BaseModel): + """The node's published-schema catalog. Empty ``schemas`` is valid (200).""" + + node_domain: str + schemas: list[CatalogEntry] diff --git a/server/osa/domain/discovery/model/value.py b/server/osa/domain/data/model/filter.py similarity index 55% rename from server/osa/domain/discovery/model/value.py rename to server/osa/domain/data/model/filter.py index 65549a0..bacea17 100644 --- a/server/osa/domain/discovery/model/value.py +++ b/server/osa/domain/data/model/filter.py @@ -1,28 +1,95 @@ -"""Discovery domain value objects — filters, cursors, result types. +"""Filter DSL for the ``/data/`` read surface. -Feature 076 replaces the flat ``Filter`` list with a compound ``FilterExpr`` -discriminated union (``And``/``Or``/``Not``/``Predicate``). Field references -inside predicates are typed (:class:`MetadataFieldRef` or -:class:`FeatureFieldRef`); the dotted wire form is parsed at the API boundary. +Relocated from the ``discovery`` domain (research §3 / FR-041): the +``FilterExpr`` discriminated union, its typed field references, and the +operator-compatibility tables are reused wholesale with no behaviour change. +During the discovery→data coexistence window (research §10) the ``data`` +domain keeps its own copy so it has no dependency on ``discovery`` once the +latter is deleted. + +Two kinds of field references are supported: + +- :class:`MetadataFieldRef` — resolves to a column in + ``metadata._v``. 
+- :class:`FeatureFieldRef` — resolves to a column in ``features.<hook>``. + +Wire format for a reference is a dotted path (``metadata.<field>`` or +``features.<hook>.<column>``). :func:`parse_field_ref` parses the wire form +into a typed reference and validates identifier shape. """ from __future__ import annotations -import base64 -import json -from datetime import datetime +import re from enum import StrEnum from typing import Annotated, Any, Literal, Union from pydantic import BaseModel, Field, model_validator -from osa.domain.discovery.model.refs import ( - FeatureFieldRef, - MetadataFieldRef, - parse_field_ref, -) from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.model.srn import RecordSRN + +_IDENT = re.compile(r"^[a-z][a-z0-9_]*$") + + +# ── Field references ── + + +class MetadataFieldRef(BaseModel): + path: Literal["metadata"] = "metadata" + field: str + + def dotted(self) -> str: + return f"metadata.{self.field}" + + +class FeatureFieldRef(BaseModel): + path: Literal["features"] = "features" + hook: str + column: str + + def dotted(self) -> str: + return f"features.{self.hook}.{self.column}" + + +def parse_field_ref(dotted: str) -> "FieldRef": + """Parse a dotted-path field reference into its typed form. + + Raises :class:`ValueError` when the path shape or identifier doesn't match + the documented grammar. 
+ """ + if not isinstance(dotted, str): + raise ValueError(f"Expected dotted string, got {type(dotted).__name__}") + + parts = dotted.split(".") + if not parts: + raise ValueError(f"Empty field reference: {dotted!r}") + + head = parts[0] + if head == "metadata": + if len(parts) != 2: + raise ValueError(f"metadata.* refs must be exactly two dotted parts, got {dotted!r}") + field = parts[1] + if not _IDENT.match(field): + raise ValueError(f"Invalid metadata field identifier: {field!r}") + return MetadataFieldRef(field=field) + + if head == "features": + if len(parts) != 3: + raise ValueError(f"features.* refs must be exactly three dotted parts, got {dotted!r}") + hook, column = parts[1], parts[2] + if not _IDENT.match(hook): + raise ValueError(f"Invalid hook identifier: {hook!r}") + if not _IDENT.match(column): + raise ValueError(f"Invalid feature column identifier: {column!r}") + return FeatureFieldRef(hook=hook, column=column) + + raise ValueError( + f"Unknown field reference prefix {head!r} in {dotted!r}. " + "Expected 'metadata.' or 'features..'." + ) + + +# ── Operators ── class FilterOperator(StrEnum): @@ -51,6 +118,9 @@ class SortOrder(StrEnum): PredicateValue = Union[str, int, float, bool, list[str], list[float], None] +# ── Filter expression tree ── + + class Predicate(BaseModel): kind: Literal["predicate"] = "predicate" field: FieldRef @@ -88,13 +158,14 @@ class Not(BaseModel): Field(discriminator="kind"), ] -# Resolve forward references And.model_rebuild() Or.model_rebuild() Not.model_rebuild() -# Operators valid per column type for metadata/feature column validation. +# ── Operator compatibility tables ── + +# Operators valid per semantic column type for metadata/feature validation. 
VALID_OPERATORS: dict[FieldType, set[FilterOperator]] = { FieldType.TEXT: { FilterOperator.EQ, @@ -139,8 +210,8 @@ class Not(BaseModel): FieldType.BOOLEAN: {FilterOperator.EQ, FilterOperator.IS_NULL}, } -# Operators valid against raw JSON-schema primitive types (used for feature columns -# whose Column.json_type is a JSON Schema primitive rather than a semantic FieldType). +# Operators valid against raw JSON-schema primitive types (feature columns +# whose Column.json_type is a JSON Schema primitive rather than a FieldType). JSON_TYPE_OPERATORS: dict[str, set[FilterOperator]] = { "string": { FilterOperator.EQ, @@ -173,61 +244,3 @@ class Not(BaseModel): "array": {FilterOperator.EQ, FilterOperator.IS_NULL}, "object": {FilterOperator.EQ, FilterOperator.IS_NULL}, } - - -def encode_cursor(sort_value: Any, id_value: Any) -> str: - """Encode a cursor as base64 JSON.""" - payload = {"s": sort_value, "id": id_value} - return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() - - -def decode_cursor(cursor: str) -> dict[str, Any]: - """Decode a base64 JSON cursor. 
Raises ValueError on malformed input.""" - try: - raw = base64.urlsafe_b64decode(cursor.encode()) - data = json.loads(raw) - except Exception as exc: - raise ValueError(f"Malformed cursor: {exc}") from exc - if not isinstance(data, dict) or "s" not in data or "id" not in data: - raise ValueError("Cursor must contain 's' and 'id' keys") - return data - - -class RecordSummary(BaseModel): - srn: RecordSRN - published_at: datetime - metadata: dict[str, Any] - - -class RecordSearchResult(BaseModel): - results: list[RecordSummary] - cursor: str | None - has_more: bool - - -class ColumnInfo(BaseModel): - name: str - type: str - required: bool - - -class FeatureCatalogEntry(BaseModel): - hook_name: str - columns: list[ColumnInfo] - record_count: int - - -class FeatureCatalog(BaseModel): - tables: list[FeatureCatalogEntry] - - -class FeatureRow(BaseModel): - row_id: int - record_srn: RecordSRN - data: dict[str, Any] - - -class FeatureSearchResult(BaseModel): - rows: list[FeatureRow] - cursor: str | None - has_more: bool diff --git a/server/osa/domain/data/model/format.py b/server/osa/domain/data/model/format.py new file mode 100644 index 0000000..276d74c --- /dev/null +++ b/server/osa/domain/data/model/format.py @@ -0,0 +1,54 @@ +"""Response-format registry for the ``/data/`` surface. + +Each :class:`DataResponseFormat` ties a URL suffix to its serializer, media +type, pagination semantics, and per-route statement timeout (research §7). +Adding a format (e.g. ``ndjson``, ``parquet``) is a one-line append plus one +serializer class — the route factory and timeout dependency pick it up +automatically. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass + +from osa.domain.data.serializer.csv import CsvSerializer +from osa.domain.data.serializer.csv_gzip import CsvGzipSerializer +from osa.domain.data.serializer.json import JsonSerializer +from osa.domain.data.serializer.protocol import Serializer + + +@dataclass(frozen=True) +class DataResponseFormat: + serializer_cls: type[Serializer] + paginated: bool # True → JSON envelope + cursor; False → unbounded stream + suffix: str # "" (json), "csv", "csv.gz" + media_type: str + statement_timeout: str # "30s" paginated, "30min" streaming + + def make_serializer(self) -> Serializer: + return self.serializer_cls() + + +FORMATS: tuple[DataResponseFormat, ...] = ( + DataResponseFormat( + serializer_cls=JsonSerializer, + paginated=True, + suffix="", + media_type="application/json", + statement_timeout="30s", + ), + DataResponseFormat( + serializer_cls=CsvSerializer, + paginated=False, + suffix="csv", + media_type="text/csv", + statement_timeout="30min", + ), + DataResponseFormat( + serializer_cls=CsvGzipSerializer, + paginated=False, + suffix="csv.gz", + media_type="application/gzip", + statement_timeout="30min", + ), +) diff --git a/server/osa/domain/data/model/manifest.py b/server/osa/domain/data/model/manifest.py new file mode 100644 index 0000000..30046a7 --- /dev/null +++ b/server/osa/domain/data/model/manifest.py @@ -0,0 +1,70 @@ +"""Schema manifest response envelope (FR-002, research §9). + +A stable, machine-readable cross-schema shape suitable for SDK code +generation and agent affordance: typed field definitions plus the set of +addressable table resources (the ``records`` table and every feature table), +each with its column schema, row count, and the URL-exposed format suffixes. 
+""" + +from __future__ import annotations + +from pydantic import BaseModel + +from osa.domain.data.model.query_plan import TableKind +from osa.domain.semantics.model.value import FieldType + + +class FieldSpec(BaseModel): + """A schema-declared metadata field.""" + + name: str + type: FieldType + ontology_id: str | None = None # required iff type == TERM + ontology_version: str | None = None # required iff type == TERM + + +class ColumnSpec(BaseModel): + """A physical column on an addressable table resource.""" + + name: str + type: FieldType + + +# Implicit columns present on every records-table response, in wire order, +# ahead of the schema's declared metadata fields (data-model.md §ColumnSpec). +IMPLICIT_RECORD_COLUMN_SPECS: tuple[ColumnSpec, ...] = ( + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="srn", type=FieldType.TEXT), + ColumnSpec(name="schema_id", type=FieldType.TEXT), + ColumnSpec(name="version", type=FieldType.NUMBER), + ColumnSpec(name="created_at", type=FieldType.DATE), +) + +# Implicit columns present on every feature-table response, in wire order, ahead +# of the hook's declared output columns. Feature rows have no SRN in v1 (data- +# model.md §RecordSummary); ``record_srn`` is the join key back to ``records``. +IMPLICIT_FEATURE_COLUMN_SPECS: tuple[ColumnSpec, ...] = ( + ColumnSpec(name="id", type=FieldType.NUMBER), + ColumnSpec(name="record_srn", type=FieldType.TEXT), + ColumnSpec(name="created_at", type=FieldType.DATE), +) + + +class TableResource(BaseModel): + """One addressable table under a schema: the records table or a feature table.""" + + name: str # "records" or the feature table name + kind: TableKind + columns: list[ColumnSpec] + row_count: int + formats: list[str] # URL suffixes, e.g. ["", "csv", "csv.gz"] + + +class SchemaManifest(BaseModel): + """Full machine-readable manifest for a single schema version.""" + + id: str # short schema id, e.g. "compound" + version: str # SemVer, e.g. 
"1.0.0" + srn: str # full schema SRN + fields: list[FieldSpec] + table_resources: list[TableResource] diff --git a/server/osa/domain/data/model/query_plan.py b/server/osa/domain/data/model/query_plan.py new file mode 100644 index 0000000..5a54adc --- /dev/null +++ b/server/osa/domain/data/model/query_plan.py @@ -0,0 +1,105 @@ +"""Query IR for the ``/data/`` read surface. + +``QueryPlan`` is the internal intermediate representation produced by both URL +parsing and POST-body parsing and consumed by the read engine. It is the pure +query description; the *output format* (which serializer, which statement +timeout) is intentionally NOT part of the plan — it is a route/serialization +concern carried alongside the plan, which also avoids a query_plan→format→ +serializer→manifest→query_plan import cycle. + +Also hosts the opaque cursor codec relocated from the discovery domain +(research §3 / FR-041) with no behaviour change. +""" + +from __future__ import annotations + +import base64 +import json +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field, model_validator + +from osa.domain.data.model.filter import FilterExpr +from osa.domain.shared.model.ids import HookName +from osa.domain.shared.model.srn import SchemaId + + +class TableKind(StrEnum): + RECORDS = "records" + FEATURE = "feature" + + +class SortDirection(StrEnum): + ASC = "asc" + DESC = "desc" + + +class SortSpec(BaseModel): + """A single sort key — column plus direction (no bare tuples at boundaries).""" + + column: str + direction: SortDirection + + +class PaginationCursor(BaseModel): + """Opaque base64 wrapper around the last row's ``(sort_value, id)`` pair.""" + + value: str + + def __str__(self) -> str: + return self.value + + +class PaginationParams(BaseModel): + cursor: PaginationCursor | None = None + limit: int = Field(default=50, ge=1, le=1000) + + +# Default sort keys per table kind (data-model.md §PaginationParams). 
+_DEFAULT_SORTS: dict[TableKind, list[SortSpec]] = { + TableKind.RECORDS: [ + SortSpec(column="created_at", direction=SortDirection.DESC), + SortSpec(column="id", direction=SortDirection.DESC), + ], + TableKind.FEATURE: [SortSpec(column="id", direction=SortDirection.ASC)], +} + + +class QueryPlan(BaseModel): + schema_id: SchemaId + table_kind: TableKind + feature_name: HookName | None = None + filter: FilterExpr | None = None + pagination: PaginationParams = Field(default_factory=PaginationParams) + sort: list[SortSpec] = Field(default_factory=list) + + @model_validator(mode="after") + def _validate_and_default(self) -> "QueryPlan": + # feature_name present iff FEATURE + if self.table_kind == TableKind.FEATURE and self.feature_name is None: + raise ValueError("feature_name is required when table_kind is FEATURE") + if self.table_kind == TableKind.RECORDS and self.feature_name is not None: + raise ValueError("feature_name must be None when table_kind is RECORDS") + # Apply default sort per table kind when none was supplied. + if not self.sort: + self.sort = list(_DEFAULT_SORTS[self.table_kind]) + return self + + +def encode_cursor(sort_value: Any, id_value: Any) -> str: + """Encode a cursor as urlsafe base64 of ``{"s": sort_value, "id": id_value}``.""" + payload = {"s": sort_value, "id": id_value} + return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() + + +def decode_cursor(cursor: str) -> dict[str, Any]: + """Decode a base64 JSON cursor. 
Raises ``ValueError`` on malformed input.""" + try: + raw = base64.urlsafe_b64decode(cursor.encode()) + data = json.loads(raw) + except Exception as exc: + raise ValueError(f"Malformed cursor: {exc}") from exc + if not isinstance(data, dict) or "s" not in data or "id" not in data: + raise ValueError("Cursor must contain 's' and 'id' keys") + return data diff --git a/server/osa/domain/data/model/record_summary.py b/server/osa/domain/data/model/record_summary.py new file mode 100644 index 0000000..cae4507 --- /dev/null +++ b/server/osa/domain/data/model/record_summary.py @@ -0,0 +1,50 @@ +"""Row types yielded by the read engine. + +``RecordSummary`` is the records-table row. Per data-model.md every records +response carries BOTH the bare internal ``id`` and the full ``srn``; ``id`` and +``version`` are derivable from the SRN, ``schema_id`` and ``created_at`` from +the ``records`` table columns. Feature-table rows do not have SRNs in v1 — they +flow through the engine as plain column→value mappings, not ``RecordSummary``. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel + +from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.srn import RecordSRN, SchemaId + +# Implicit columns present on every records-table response, in wire order, +# ahead of the schema's declared metadata fields. +IMPLICIT_RECORD_COLUMNS: tuple[str, ...] = ("id", "srn", "schema_id", "version", "created_at") + + +class RecordSummary(BaseModel): + """A single published record as projected by the read engine.""" + + id: RecordId + srn: RecordSRN + schema_id: SchemaId + version: int + metadata: dict[str, Any] + created_at: datetime + + def flatten(self) -> dict[str, Any]: + """Flatten into a column→value mapping for serialization. + + Implicit columns come first (in :data:`IMPLICIT_RECORD_COLUMNS` order), + then the metadata fields. 
Metadata keys never shadow implicit columns + because schema field names cannot collide with the reserved implicit + set at registration time. + """ + return { + "id": str(self.id), + "srn": str(self.srn), + "schema_id": self.schema_id.render(), + "version": self.version, + "created_at": self.created_at.isoformat(), + **self.metadata, + } diff --git a/server/osa/domain/discovery/query/__init__.py b/server/osa/domain/data/port/__init__.py similarity index 100% rename from server/osa/domain/discovery/query/__init__.py rename to server/osa/domain/data/port/__init__.py diff --git a/server/osa/domain/data/port/data_read_store.py b/server/osa/domain/data/port/data_read_store.py new file mode 100644 index 0000000..60a9c2d --- /dev/null +++ b/server/osa/domain/data/port/data_read_store.py @@ -0,0 +1,45 @@ +"""DataReadStore port — read-only access feeding every ``/data/`` format. + +``stream_rows`` is the single streaming primitive behind both the paginated +JSON path and the unbounded CSV/CSV.gz path; the concrete adapter backs it with +a Postgres server-side cursor (research §2) so memory stays bounded regardless +of result size. Rows are yielded as column→value mappings (already projected): +records-table rows include the implicit ``id``/``srn``/``schema_id``/ +``version``/``created_at`` columns plus metadata fields; feature-table rows +carry the hook's declared columns. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping +from typing import TYPE_CHECKING, Any, Protocol + +if TYPE_CHECKING: + from osa.domain.data.model.catalog import NodeCatalog + from osa.domain.data.model.manifest import SchemaManifest + from osa.domain.data.model.query_plan import QueryPlan + from osa.domain.data.model.record_summary import RecordSummary + from osa.domain.shared.model.ids import RecordId + from osa.domain.shared.model.srn import SchemaId + + +class DataReadStore(Protocol): + def stream_rows(self, plan: "QueryPlan") -> AsyncIterator[Mapping[str, Any]]: + """Stream projected rows for the plan via a server-side cursor.""" + ... + + async def get_record_by_id(self, id: "RecordId", version: int | None) -> "RecordSummary | None": + """Resolve a single record by bare ID (schema resolved via PK). ``None`` if absent.""" + ... + + async def get_node_catalog(self) -> "NodeCatalog": + """List published schemas with summary table resources.""" + ... + + async def get_schema_manifest(self, schema_id: "SchemaId") -> "SchemaManifest | None": + """Full manifest for a schema. ``None`` if unknown.""" + ... + + async def get_latest_schema_id(self, schema_short_id: str) -> "SchemaId | None": + """Resolve a bare schema id to its latest published version. ``None`` if unknown.""" + ... 
diff --git a/server/osa/domain/discovery/service/__init__.py b/server/osa/domain/data/query/__init__.py similarity index 100% rename from server/osa/domain/discovery/service/__init__.py rename to server/osa/domain/data/query/__init__.py diff --git a/server/osa/domain/discovery/util/__init__.py b/server/osa/domain/data/serializer/__init__.py similarity index 100% rename from server/osa/domain/discovery/util/__init__.py rename to server/osa/domain/data/serializer/__init__.py diff --git a/server/osa/domain/data/serializer/csv.py b/server/osa/domain/data/serializer/csv.py new file mode 100644 index 0000000..710fa23 --- /dev/null +++ b/server/osa/domain/data/serializer/csv.py @@ -0,0 +1,61 @@ +"""CSV serializer — header row from columns, then one row per record. + +Streams incrementally via a reusable :class:`_RowEncoder` that holds a single +``io.StringIO`` + ``csv.writer``, so only one row is ever materialised at a +time. An empty result still yields the header row followed by EOF. Quoting uses +``csv.QUOTE_MINIMAL``. +""" + +from __future__ import annotations + +import csv +import io +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any, ClassVar + + +from osa.domain.data.model.manifest import ColumnSpec + + +class _RowEncoder: + """Encodes one CSV row at a time, reusing a single buffer + writer. + + The writer's type is inferred from the ``csv.writer(...)`` assignment, so + no untyped parameter is threaded through the serializer. 
+ """ + + def __init__(self) -> None: + self._buf = io.StringIO() + self._writer = csv.writer(self._buf, quoting=csv.QUOTE_MINIMAL) + + def encode(self, values: Sequence[Any]) -> bytes: + self._buf.seek(0) + self._buf.truncate(0) + self._writer.writerow(values) + return self._buf.getvalue().encode() + + +class CsvSerializer: + media_type: ClassVar[str] = "text/csv" + + async def stream( + self, + rows: AsyncIterator[Mapping[str, Any]], + columns: Sequence[ColumnSpec], + *, + next_cursor: str | None = None, + ) -> AsyncIterator[bytes]: + names = [col.name for col in columns] + encoder = _RowEncoder() + + yield encoder.encode(names) + + async for row in rows: + yield encoder.encode([_stringify(row.get(name)) for name in names]) + + +def _stringify(value: Any) -> Any: + """Render non-scalar values deterministically; let csv handle scalars/None.""" + if value is None or isinstance(value, (str, int, float, bool)): + return value + return str(value) diff --git a/server/osa/domain/data/serializer/csv_gzip.py b/server/osa/domain/data/serializer/csv_gzip.py new file mode 100644 index 0000000..f6534d7 --- /dev/null +++ b/server/osa/domain/data/serializer/csv_gzip.py @@ -0,0 +1,41 @@ +"""Gzip-while-streaming CSV serializer (research §1). + +Wraps :class:`CsvSerializer` through ``zlib.compressobj(level=6, +wbits=MAX_WBITS|16)``. The ``MAX_WBITS|16`` flag selects gzip-stream mode +(proper gzip header + trailer) so the byte stream is exactly what ``gunzip`` +and HTTP gzip clients expect. Memory footprint is the 32KB DEFLATE window plus +one CSV row, regardless of result size — the basis of the SC-001 bounded-memory +target. 
+""" + +from __future__ import annotations + +import zlib +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any, ClassVar + +from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.serializer.csv import CsvSerializer + + +class CsvGzipSerializer: + media_type: ClassVar[str] = "application/gzip" + + def __init__(self) -> None: + self._csv = CsvSerializer() + + async def stream( + self, + rows: AsyncIterator[Mapping[str, Any]], + columns: Sequence[ColumnSpec], + *, + next_cursor: str | None = None, + ) -> AsyncIterator[bytes]: + compressor = zlib.compressobj(level=6, wbits=zlib.MAX_WBITS | 16) + async for chunk in self._csv.stream(rows, columns): + compressed = compressor.compress(chunk) + if compressed: + yield compressed + tail = compressor.flush(zlib.Z_FINISH) + if tail: + yield tail diff --git a/server/osa/domain/data/serializer/json.py b/server/osa/domain/data/serializer/json.py new file mode 100644 index 0000000..43ff0fb --- /dev/null +++ b/server/osa/domain/data/serializer/json.py @@ -0,0 +1,38 @@ +"""JSON serializer — paginated envelope ``{"rows": [...], "next_cursor": ...}``. + +Built row-by-row so the whole page is never held twice in memory. The +``next_cursor`` is precomputed by the query service and passed in; an empty +result yields ``{"rows": [], "next_cursor": null}`` with HTTP 200. 
+""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any, ClassVar + +from osa.domain.data.model.manifest import ColumnSpec + + +class JsonSerializer: + media_type: ClassVar[str] = "application/json" + + async def stream( + self, + rows: AsyncIterator[Mapping[str, Any]], + columns: Sequence[ColumnSpec], + *, + next_cursor: str | None = None, + ) -> AsyncIterator[bytes]: + yield b'{"rows":[' + first = True + async for row in rows: + projected = {col.name: row.get(col.name) for col in columns} + chunk = json.dumps(projected, default=str).encode() + if first: + first = False + else: + chunk = b"," + chunk + yield chunk + cursor_json = json.dumps(next_cursor).encode() + yield b'],"next_cursor":' + cursor_json + b"}" diff --git a/server/osa/domain/data/serializer/protocol.py b/server/osa/domain/data/serializer/protocol.py new file mode 100644 index 0000000..bfcffcf --- /dev/null +++ b/server/osa/domain/data/serializer/protocol.py @@ -0,0 +1,34 @@ +"""Serializer protocol — rows in, bytes out. + +Serializers are stateless and have no I/O dependencies beyond stdlib (``json``, +``csv``, ``zlib``). They consume an async iterator of already-projected rows +(column→value mappings) and the column schema that fixes wire order, and yield +response bytes incrementally so streaming formats stay memory-bounded. + +Note on ``next_cursor``: cursor derivation lives in the query service (matching +the proven discovery engine), not in the serializer. Paginated serializers +(JSON) receive the precomputed ``next_cursor`` to embed in the envelope; +streaming serializers (CSV, CSV.gz) ignore it. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping, Sequence +from typing import Any, ClassVar, Protocol, runtime_checkable + +from osa.domain.data.model.manifest import ColumnSpec + + +@runtime_checkable +class Serializer(Protocol): + media_type: ClassVar[str] + + def stream( + self, + rows: AsyncIterator[Mapping[str, Any]], + columns: Sequence[ColumnSpec], + *, + next_cursor: str | None = None, + ) -> AsyncIterator[bytes]: + """Render ``rows`` as response bytes, yielded incrementally.""" + ... diff --git a/server/osa/domain/export/__init__.py b/server/osa/domain/data/service/__init__.py similarity index 100% rename from server/osa/domain/export/__init__.py rename to server/osa/domain/data/service/__init__.py diff --git a/server/osa/domain/data/service/data_catalog.py b/server/osa/domain/data/service/data_catalog.py new file mode 100644 index 0000000..ba2fb49 --- /dev/null +++ b/server/osa/domain/data/service/data_catalog.py @@ -0,0 +1,75 @@ +"""DataCatalogService — catalog, manifest, and single-record-by-ID reads. + +Read-only orchestration over the :class:`DataReadStore` port. Catalog + manifest +(US3) and record-by-ID (US4). The reserved-name 404 is enforced defensively +here in addition to the write-side aggregate invariants (research §6). 
+""" + +from __future__ import annotations + +from osa.domain.data.model.catalog import NodeCatalog +from osa.domain.data.model.manifest import SchemaManifest +from osa.domain.data.model.record_summary import RecordSummary +from osa.domain.data.port.data_read_store import DataReadStore +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.reserved import RESERVED_NAMES +from osa.domain.shared.model.srn import SchemaId +from osa.domain.shared.service import Service + + +class DataCatalogService(Service): + read_store: DataReadStore + + async def resolve_schema(self, raw: str) -> SchemaId: + """Resolve a URL schema segment (``<id>`` or ``<id>@<version>``) to a SchemaId. + + A bare id resolves to the latest published version. Reserved names and + unknown schemas raise ``NotFoundError`` (404 at the route). + """ + short_id = raw.split("@", 1)[0] + if short_id in RESERVED_NAMES: + raise NotFoundError( + f"The path '{short_id}' is reserved under /data/.", code="reserved_name" + ) + if "@" in raw: + try: + return SchemaId.parse(raw) + except ValueError as exc: + raise NotFoundError( + f"No schema '{raw}'. See /api/v1/data for the catalog.", + code="schema_not_found", + ) from exc + resolved = await self.read_store.get_latest_schema_id(short_id) + if resolved is None: + raise NotFoundError( + f"No schema '{raw}'. See /api/v1/data for the catalog.", + code="schema_not_found", + ) + return resolved + + async def get_node_catalog(self) -> NodeCatalog: + return await self.read_store.get_node_catalog() + + async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest: + # Defense in depth: a reserved id can't be registered (write-side + # invariant) but the read side handles the URL slot deliberately. 
+ if schema_id.id.root in RESERVED_NAMES: + raise NotFoundError( + f"The path '{schema_id.id.root}' is reserved under /data/.", + code="reserved_name", + ) + manifest = await self.read_store.get_schema_manifest(schema_id) + if manifest is None: + raise NotFoundError( + f"No schema '{schema_id.render()}'. See /api/v1/data for the catalog.", + code="schema_not_found", + ) + return manifest + + async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary: + record = await self.read_store.get_record_by_id(id, version) + if record is None: + suffix = f"@{version}" if version is not None else "" + raise NotFoundError(f"No record with id '{id}{suffix}'.", code="record_not_found") + return record diff --git a/server/osa/domain/data/service/data_query.py b/server/osa/domain/data/service/data_query.py new file mode 100644 index 0000000..520a562 --- /dev/null +++ b/server/osa/domain/data/service/data_query.py @@ -0,0 +1,96 @@ +"""DataQueryService — streaming read business logic for records and features. + +Validates the filter tree's bounds (depth, predicate count, distinct feature +joins) against config, then delegates row streaming to the +:class:`DataReadStore` port. The returned async iterators are wrapped by the +route layer in a serializer + ``StreamingResponse``. Validation runs at the top +of the generator body, so it surfaces on the route's pre-flight ``__anext__`` +pull — before any response bytes (research §4). + +Field/operator-compatibility validation against the resolved table's columns is +performed by the read-store adapter during SQL compilation (where the column +types are known); that path raises ``ValidationError`` before the first row. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Iterator, Mapping +from typing import Any + +from osa.config import Config +from osa.domain.data.model.filter import And, FeatureFieldRef, FilterExpr, Not, Or, Predicate +from osa.domain.data.model.query_plan import QueryPlan, TableKind +from osa.domain.data.port.data_read_store import DataReadStore +from osa.domain.shared.error import ValidationError +from osa.domain.shared.service import Service + + +class DataQueryService(Service): + read_store: DataReadStore + config: Config + + async def stream_records(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + if plan.table_kind != TableKind.RECORDS: + raise ValidationError("stream_records requires a RECORDS plan", field="table_kind") + self._validate_filter_bounds(plan.filter) + async for row in self.read_store.stream_rows(plan): + yield row + + async def stream_features(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + if plan.table_kind != TableKind.FEATURE: + raise ValidationError("stream_features requires a FEATURE plan", field="table_kind") + self._validate_filter_bounds(plan.filter) + async for row in self.read_store.stream_rows(plan): + yield row + + # ------------------------------------------------------------------ # + # Filter-tree bounds (ported from DiscoveryService, config-driven) + # ------------------------------------------------------------------ # + + def _validate_filter_bounds(self, expr: FilterExpr | None) -> None: + if expr is None: + return + depth = _tree_depth(expr) + if depth > self.config.discovery_max_filter_depth: + raise ValidationError( + f"Filter tree depth {depth} exceeds maximum " + f"{self.config.discovery_max_filter_depth}.", + field="filter", + code="filter_depth_exceeded", + ) + predicates = list(_iter_predicates(expr)) + if len(predicates) > self.config.discovery_max_predicates: + raise ValidationError( + f"Filter tree has {len(predicates)} predicates, exceeds maximum 
" + f"{self.config.discovery_max_predicates}.", + field="filter", + code="filter_predicates_exceeded", + ) + distinct_hooks = {p.field.hook for p in predicates if isinstance(p.field, FeatureFieldRef)} + if len(distinct_hooks) > self.config.discovery_max_cross_domain_joins: + raise ValidationError( + f"Filter joins {len(distinct_hooks)} feature hooks, exceeds maximum " + f"{self.config.discovery_max_cross_domain_joins}.", + field="filter", + code="filter_joins_exceeded", + ) + + +def _tree_depth(expr: FilterExpr) -> int: + if isinstance(expr, Predicate): + return 1 + if isinstance(expr, Not): + return 1 + _tree_depth(expr.operand) + if isinstance(expr, (And, Or)): + return 1 + max(_tree_depth(op) for op in expr.operands) + return 1 + + +def _iter_predicates(expr: FilterExpr) -> Iterator[Predicate]: + if isinstance(expr, Predicate): + yield expr + elif isinstance(expr, Not): + yield from _iter_predicates(expr.operand) + elif isinstance(expr, (And, Or)): + for op in expr.operands: + yield from _iter_predicates(op) diff --git a/server/osa/domain/export/adapter/__init__.py b/server/osa/domain/data/util/__init__.py similarity index 100% rename from server/osa/domain/export/adapter/__init__.py rename to server/osa/domain/data/util/__init__.py diff --git a/server/osa/domain/data/util/di/__init__.py b/server/osa/domain/data/util/di/__init__.py new file mode 100644 index 0000000..1fd666e --- /dev/null +++ b/server/osa/domain/data/util/di/__init__.py @@ -0,0 +1,3 @@ +from osa.domain.data.util.di.provider import DataProvider + +__all__ = ["DataProvider"] diff --git a/server/osa/domain/data/util/di/provider.py b/server/osa/domain/data/util/di/provider.py new file mode 100644 index 0000000..3230397 --- /dev/null +++ b/server/osa/domain/data/util/di/provider.py @@ -0,0 +1,20 @@ +"""Dishka DI provider for the data domain (services).""" + +from dishka import provide + +from osa.config import Config +from osa.domain.data.port.data_read_store import DataReadStore +from 
osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.data.service.data_query import DataQueryService +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + + +class DataProvider(Provider): + @provide(scope=Scope.UOW) + def get_data_query_service(self, read_store: DataReadStore, config: Config) -> DataQueryService: + return DataQueryService(read_store=read_store, config=config) + + @provide(scope=Scope.UOW) + def get_data_catalog_service(self, read_store: DataReadStore) -> DataCatalogService: + return DataCatalogService(read_store=read_store) diff --git a/server/osa/domain/discovery/__init__.py b/server/osa/domain/discovery/__init__.py deleted file mode 100644 index c6d5d02..0000000 --- a/server/osa/domain/discovery/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Discovery domain — read-only search and filter API for records and features.""" diff --git a/server/osa/domain/discovery/model/refs.py b/server/osa/domain/discovery/model/refs.py deleted file mode 100644 index f5783a0..0000000 --- a/server/osa/domain/discovery/model/refs.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Typed field references used inside Predicate.field. - -Two kinds of references are supported: - -- :class:`MetadataFieldRef` — resolves to a column in ``metadata._v``. -- :class:`FeatureFieldRef` — resolves to a column in ``features.``. - -Wire format is a dotted path (``metadata.`` or -``features..``). :func:`parse_field_ref` parses the wire form -into a typed reference and validates identifier shape. 
-""" - -from __future__ import annotations - -import re -from typing import Literal, Union - -from pydantic import BaseModel - -_IDENT = re.compile(r"^[a-z][a-z0-9_]*$") - - -class MetadataFieldRef(BaseModel): - path: Literal["metadata"] = "metadata" - field: str - - def dotted(self) -> str: - return f"metadata.{self.field}" - - -class FeatureFieldRef(BaseModel): - path: Literal["features"] = "features" - hook: str - column: str - - def dotted(self) -> str: - return f"features.{self.hook}.{self.column}" - - -FieldRef = Union[MetadataFieldRef, FeatureFieldRef] - - -def parse_field_ref(dotted: str) -> FieldRef: - """Parse a dotted-path field reference into its typed form. - - Raises :class:`ValueError` when the path shape or identifier doesn't match - the documented grammar. - """ - if not isinstance(dotted, str): - raise ValueError(f"Expected dotted string, got {type(dotted).__name__}") - - parts = dotted.split(".") - if not parts: - raise ValueError(f"Empty field reference: {dotted!r}") - - head = parts[0] - if head == "metadata": - if len(parts) != 2: - raise ValueError(f"metadata.* refs must be exactly two dotted parts, got {dotted!r}") - field = parts[1] - if not _IDENT.match(field): - raise ValueError(f"Invalid metadata field identifier: {field!r}") - return MetadataFieldRef(field=field) - - if head == "features": - if len(parts) != 3: - raise ValueError(f"features.* refs must be exactly three dotted parts, got {dotted!r}") - hook, column = parts[1], parts[2] - if not _IDENT.match(hook): - raise ValueError(f"Invalid hook identifier: {hook!r}") - if not _IDENT.match(column): - raise ValueError(f"Invalid feature column identifier: {column!r}") - return FeatureFieldRef(hook=hook, column=column) - - raise ValueError( - f"Unknown field reference prefix {head!r} in {dotted!r}. " - "Expected 'metadata.' or 'features..'." 
- ) diff --git a/server/osa/domain/discovery/port/field_definition_reader.py b/server/osa/domain/discovery/port/field_definition_reader.py deleted file mode 100644 index c2a6bfa..0000000 --- a/server/osa/domain/discovery/port/field_definition_reader.py +++ /dev/null @@ -1,28 +0,0 @@ -"""FieldDefinitionReader port — cross-domain read port for schema field lookups.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Protocol - -if TYPE_CHECKING: - from osa.domain.semantics.model.value import FieldType - from osa.domain.shared.model.srn import SchemaId - - -class FieldDefinitionReader(Protocol): - async def get_all_field_types(self) -> dict[str, FieldType]: - """Return global field_name -> FieldType map across all schemas. - - Raises ValidationError if same field name has conflicting types across schemas. - """ - ... - - async def get_fields_for_schema(self, schema_id: "SchemaId") -> dict[str, FieldType]: - """Return field_name -> FieldType for a specific schema's current major version. - - Returns an empty dict when the schema is unknown to the node. Callers - that treat "unknown schema" as an error condition must check for an - empty map and raise ``NotFoundError`` themselves — the port stays - neutral so that non-user-facing callers can handle absence explicitly. - """ - ... 
diff --git a/server/osa/domain/discovery/port/read_store.py b/server/osa/domain/discovery/port/read_store.py deleted file mode 100644 index 762364e..0000000 --- a/server/osa/domain/discovery/port/read_store.py +++ /dev/null @@ -1,50 +0,0 @@ -"""DiscoveryReadStore port — read-only access to records, features, metadata.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Protocol - -if TYPE_CHECKING: - from osa.domain.discovery.model.value import ( - FeatureCatalogEntry, - FeatureRow, - FilterExpr, - RecordSummary, - SortOrder, - ) - from osa.domain.semantics.model.value import FieldType - from osa.domain.shared.model.srn import ConventionSRN, RecordSRN, SchemaId - - -class DiscoveryReadStore(Protocol): - async def search_records( - self, - filter_expr: "FilterExpr | None", - schema_id: "SchemaId | None", - convention_srn: "ConventionSRN | None", - text_fields: list[str], - q: str | None, - sort: str, - order: "SortOrder", - cursor: dict | None, - limit: int, - field_types: "dict[str, FieldType] | None" = None, - ) -> "list[RecordSummary]": - """Search published records with a compound filter.""" - ... - - async def get_feature_catalog(self) -> "list[FeatureCatalogEntry]": ... - - async def get_feature_table_schema(self, hook_name: str) -> "FeatureCatalogEntry | None": ... - - async def search_features( - self, - hook_name: str, - filter_expr: "FilterExpr | None", - schema_id: "SchemaId | None", - record_srn: "RecordSRN | None", - sort: str, - order: "SortOrder", - cursor: dict | None, - limit: int, - ) -> "list[FeatureRow]": ... 
diff --git a/server/osa/domain/discovery/query/get_feature_catalog.py b/server/osa/domain/discovery/query/get_feature_catalog.py deleted file mode 100644 index 30f891d..0000000 --- a/server/osa/domain/discovery/query/get_feature_catalog.py +++ /dev/null @@ -1,32 +0,0 @@ -"""GetFeatureCatalog query — list available feature tables with schemas and counts.""" - -from osa.domain.discovery.model.value import FeatureCatalog -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.shared.authorization.gate import public -from osa.domain.shared.query import Query, QueryHandler, Result - - -class GetFeatureCatalog(Query): - pass - - -class GetFeatureCatalogResult(Result): - tables: list[dict] - - -class GetFeatureCatalogHandler(QueryHandler[GetFeatureCatalog, GetFeatureCatalogResult]): - __auth__ = public() - discovery_service: DiscoveryService - - async def run(self, cmd: GetFeatureCatalog) -> GetFeatureCatalogResult: - catalog: FeatureCatalog = await self.discovery_service.get_feature_catalog() - return GetFeatureCatalogResult( - tables=[ - { - "hook_name": entry.hook_name, - "columns": [c.model_dump() for c in entry.columns], - "record_count": entry.record_count, - } - for entry in catalog.tables - ] - ) diff --git a/server/osa/domain/discovery/query/search_features.py b/server/osa/domain/discovery/query/search_features.py deleted file mode 100644 index 42dde9a..0000000 --- a/server/osa/domain/discovery/query/search_features.py +++ /dev/null @@ -1,57 +0,0 @@ -"""SearchFeatures query — query and filter rows in a specific feature table.""" - -from osa.domain.discovery.model.value import ( - FeatureSearchResult, - FilterExpr, - SortOrder, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.shared.authorization.gate import public -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.srn import RecordSRN, SchemaId -from osa.domain.shared.query import Query, QueryHandler, Result - - 
-class SearchFeatures(Query): - hook_name: str - filter_expr: FilterExpr | None = None - schema_id: SchemaId | None = None - record_srn: str | None = None - sort: str = "id" - order: SortOrder = SortOrder.DESC - cursor: str | None = None - limit: int = 50 - - -class SearchFeaturesResult(Result): - rows: list[dict] - cursor: str | None - has_more: bool - - -class SearchFeaturesHandler(QueryHandler[SearchFeatures, SearchFeaturesResult]): - __auth__ = public() - discovery_service: DiscoveryService - - async def run(self, cmd: SearchFeatures) -> SearchFeaturesResult: - record_srn: RecordSRN | None = None - if cmd.record_srn: - try: - record_srn = RecordSRN.parse(cmd.record_srn) - except ValueError as exc: - raise ValidationError(str(exc), field="record_srn") from exc - result: FeatureSearchResult = await self.discovery_service.search_features( - hook_name=cmd.hook_name, - filter_expr=cmd.filter_expr, - schema_id=cmd.schema_id, - record_srn=record_srn, - sort=cmd.sort, - order=cmd.order, - cursor=cmd.cursor, - limit=cmd.limit, - ) - return SearchFeaturesResult( - rows=[{"record_srn": str(r.record_srn), **r.data} for r in result.rows], - cursor=result.cursor, - has_more=result.has_more, - ) diff --git a/server/osa/domain/discovery/query/search_records.py b/server/osa/domain/discovery/query/search_records.py deleted file mode 100644 index 515d980..0000000 --- a/server/osa/domain/discovery/query/search_records.py +++ /dev/null @@ -1,59 +0,0 @@ -"""SearchRecords query — search and filter published records.""" - -from typing import Any - -from osa.domain.discovery.model.value import ( - FilterExpr, - RecordSearchResult, - SortOrder, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.shared.authorization.gate import public -from osa.domain.shared.model.srn import ConventionSRN, SchemaId -from osa.domain.shared.query import Query, QueryHandler, Result - - -class SearchRecords(Query): - filter_expr: FilterExpr | None = None - schema_id: 
SchemaId | None = None - convention_srn: ConventionSRN | None = None - q: str | None = None - sort: str = "published_at" - order: SortOrder = SortOrder.DESC - cursor: str | None = None - limit: int = 20 - - -class SearchRecordsResult(Result): - results: list[dict[str, Any]] - cursor: str | None - has_more: bool - - -class SearchRecordsHandler(QueryHandler[SearchRecords, SearchRecordsResult]): - __auth__ = public() - discovery_service: DiscoveryService - - async def run(self, cmd: SearchRecords) -> SearchRecordsResult: - result: RecordSearchResult = await self.discovery_service.search_records( - filter_expr=cmd.filter_expr, - schema_id=cmd.schema_id, - convention_srn=cmd.convention_srn, - q=cmd.q, - sort=cmd.sort, - order=cmd.order, - cursor=cmd.cursor, - limit=cmd.limit, - ) - return SearchRecordsResult( - results=[ - { - "srn": str(r.srn), - "published_at": r.published_at.isoformat(), - "metadata": r.metadata, - } - for r in result.results - ], - cursor=result.cursor, - has_more=result.has_more, - ) diff --git a/server/osa/domain/discovery/service/discovery.py b/server/osa/domain/discovery/service/discovery.py deleted file mode 100644 index cd8db52..0000000 --- a/server/osa/domain/discovery/service/discovery.py +++ /dev/null @@ -1,423 +0,0 @@ -"""DiscoveryService — read-only business logic for record and feature search. - -Validates the compound ``FilterExpr`` tree (bounds, field resolution, operator -compatibility) before handing it to the read store for SQL compilation. 
-""" - -from __future__ import annotations - -import logging -from typing import Any - -from osa.config import Config -from osa.domain.discovery.model.refs import ( - FeatureFieldRef, - MetadataFieldRef, -) -from osa.domain.discovery.model.value import ( - JSON_TYPE_OPERATORS, - VALID_OPERATORS, - And, - FeatureCatalog, - FeatureSearchResult, - FilterExpr, - FilterOperator, - Not, - Or, - Predicate, - RecordSearchResult, - SortOrder, - decode_cursor, - encode_cursor, -) -from osa.domain.discovery.port.field_definition_reader import FieldDefinitionReader -from osa.domain.discovery.port.read_store import DiscoveryReadStore -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import NotFoundError, ValidationError -from osa.domain.shared.model.srn import ConventionSRN, RecordSRN, SchemaId -from osa.domain.shared.service import Service - -logger = logging.getLogger(__name__) - - -class DiscoveryService(Service): - """Orchestrates validation and delegation for discovery queries.""" - - read_store: DiscoveryReadStore - field_reader: FieldDefinitionReader - config: Config - - async def search_records( - self, - filter_expr: FilterExpr | None, - schema_id: SchemaId | None, - convention_srn: ConventionSRN | None, - q: str | None, - sort: str, - order: SortOrder, - cursor: str | None, - limit: int, - *, - allow_compound: bool = True, - ) -> RecordSearchResult: - """Validate the filter tree and delegate record search to the read store. - - ``allow_compound`` is a staged flag — US1 delivers AND-only + Predicate - support; US2 flips this to allow OR/NOT. Callers should leave it True - once US2 lands. - """ - if limit < 1 or limit > 100: - raise ValidationError("limit must be between 1 and 100", field="limit") - - if sort != "published_at" and schema_id is None: - raise ValidationError( - f"Sorting by '{sort}' requires the request to pin a 'schema' " - "('@'). 
Plain listings must sort by 'published_at'.", - field="sort", - code="schema_required_for_metadata_sort", - ) - - if q and schema_id is None: - raise ValidationError( - "Free-text search ('q') requires the request to pin a 'schema' " - "('@'). Without a schema, the server cannot resolve " - "which metadata fields are text-indexed.", - field="q", - code="schema_required_for_free_text_search", - ) - - schema_field_map: dict[str, FieldType] = {} - if schema_id is not None: - schema_field_map = await self.field_reader.get_fields_for_schema(schema_id) - if not schema_field_map: - raise NotFoundError( - f"Schema not found: {schema_id.render()}. " - "Pin an '@' that matches a registered schema." - ) - - if filter_expr is not None: - self._validate_tree(filter_expr, allow_compound=allow_compound) - await self._validate_refs(filter_expr, schema_id, schema_field_map) - - # Sort field validation (against pinned schema) - if sort != "published_at" and sort not in schema_field_map: - raise ValidationError( - f"Unknown sort field '{sort}': not defined in the pinned schema.", - field="sort", - code="unknown_sort_field", - ) - - decoded_cursor: dict[str, Any] | None = None - if cursor is not None: - try: - decoded_cursor = decode_cursor(cursor) - except ValueError as exc: - raise ValidationError(str(exc), field="cursor") from exc - - text_fields = [ - name for name, ft in schema_field_map.items() if ft in (FieldType.TEXT, FieldType.URL) - ] - if q and not text_fields: - raise ValidationError( - "Free-text search is unavailable: the pinned schema defines no text or URL fields.", - field="q", - code="no_text_fields_in_schema", - ) - - results = await self.read_store.search_records( - filter_expr=filter_expr, - schema_id=schema_id, - convention_srn=convention_srn, - text_fields=text_fields, - q=q, - sort=sort, - order=order, - cursor=decoded_cursor, - limit=limit + 1, - field_types=schema_field_map, - ) - - has_more = len(results) > limit - results = results[:limit] - next_cursor = 
None - if has_more and results: - last = results[-1] - if sort == "published_at": - sort_val = last.published_at.isoformat() - else: - sort_val = last.metadata.get(sort) - next_cursor = encode_cursor(sort_val, str(last.srn)) - - return RecordSearchResult( - results=results, - cursor=next_cursor, - has_more=has_more, - ) - - async def get_feature_catalog(self) -> FeatureCatalog: - entries = await self.read_store.get_feature_catalog() - return FeatureCatalog(tables=entries) - - async def search_features( - self, - hook_name: str, - filter_expr: FilterExpr | None, - schema_id: SchemaId | None, - record_srn: RecordSRN | None, - sort: str, - order: SortOrder, - cursor: str | None, - limit: int, - *, - allow_compound: bool = True, - ) -> FeatureSearchResult: - if limit < 1 or limit > 100: - raise ValidationError("limit must be between 1 and 100", field="limit") - - entry = await self.read_store.get_feature_table_schema(hook_name) - if entry is None: - raise NotFoundError(f"Feature table not found: {hook_name}") - - col_map: dict[str, str] = {col.name: col.type for col in entry.columns} - col_map["record_srn"] = "string" - - schema_field_map: dict[str, FieldType] = {} - if schema_id is not None: - schema_field_map = await self.field_reader.get_fields_for_schema(schema_id) - if not schema_field_map: - raise NotFoundError( - f"Schema not found: {schema_id.render()}. " - "Pin an '@' that matches a registered schema." 
- ) - - if filter_expr is not None: - self._validate_tree(filter_expr, allow_compound=allow_compound) - self._validate_feature_refs( - filter_expr, - this_hook=hook_name, - feature_col_map=col_map, - schema_field_map=schema_field_map, - ) - - if sort != "id" and sort not in col_map: - raise ValidationError( - f"Unknown sort column '{sort}' in feature table '{hook_name}'", - field="sort", - ) - - try: - decoded_cursor = decode_cursor(cursor) if cursor else None - except ValueError as exc: - raise ValidationError(str(exc), field="cursor") from exc - - rows = await self.read_store.search_features( - hook_name=hook_name, - filter_expr=filter_expr, - schema_id=schema_id, - record_srn=record_srn, - sort=sort, - order=order, - cursor=decoded_cursor, - limit=limit + 1, - ) - - has_more = len(rows) > limit - rows = rows[:limit] - next_cursor = None - if has_more and rows: - last = rows[-1] - if sort == "id": - sort_val = last.row_id - else: - sort_val = last.data.get(sort) - next_cursor = encode_cursor(sort_val, last.row_id) - - return FeatureSearchResult(rows=rows, cursor=next_cursor, has_more=has_more) - - # ------------------------- internal helpers ------------------------- - - def _validate_tree(self, expr: FilterExpr, *, allow_compound: bool) -> None: - """Enforce tree bounds (depth, predicate count, joins) + compound gating.""" - depth = _tree_depth(expr) - predicates = list(_iter_predicates(expr)) - - if depth > self.config.discovery_max_filter_depth: - raise ValidationError( - f"Filter tree depth {depth} exceeds configured maximum " - f"{self.config.discovery_max_filter_depth} (OSA_DISCOVERY_MAX_FILTER_DEPTH).", - field="filter", - code="filter_depth_exceeded", - ) - if len(predicates) > self.config.discovery_max_predicates: - raise ValidationError( - f"Filter tree has {len(predicates)} predicate leaves, exceeds " - f"configured maximum {self.config.discovery_max_predicates} " - "(OSA_DISCOVERY_MAX_PREDICATES).", - field="filter", - 
code="filter_predicates_exceeded", - ) - - distinct_hooks: set[str] = set() - for p in predicates: - if isinstance(p.field, FeatureFieldRef): - distinct_hooks.add(p.field.hook) - if len(distinct_hooks) > self.config.discovery_max_cross_domain_joins: - raise ValidationError( - f"Filter tree joins {len(distinct_hooks)} distinct feature hooks, " - f"exceeds configured maximum " - f"{self.config.discovery_max_cross_domain_joins} " - "(OSA_DISCOVERY_MAX_CROSS_DOMAIN_JOINS).", - field="filter", - code="filter_joins_exceeded", - ) - - if not allow_compound: - for node in _iter_nodes(expr): - if isinstance(node, (Or, Not)): - raise ValidationError( - "Compound OR/NOT filters are not enabled in this build.", - field="filter", - code="compound_disabled", - ) - - async def _validate_refs( - self, - expr: FilterExpr, - schema_id: SchemaId | None, - field_map: dict[str, FieldType], - ) -> None: - """Resolve each predicate's field and check operator compatibility.""" - feature_catalog: dict[str, dict[str, str]] | None = None - for p in _iter_predicates(expr): - if isinstance(p.field, MetadataFieldRef): - if schema_id is None: - raise ValidationError( - f"Metadata predicate on {p.field.dotted()!r} requires " - "the request to pin a 'schema' ('@'). 
" - "Unscoped metadata filtering is not supported — the typed " - "metadata table is the only filter path.", - field=p.field.dotted(), - code="schema_required_for_metadata_query", - ) - field_name = p.field.field - if field_name not in field_map: - raise ValidationError( - f"Unknown metadata field '{field_name}' for the pinned schema.", - field=p.field.dotted(), - code="unknown_field", - ) - self._check_operator_for_field_type( - p, field_type=field_map[field_name], path=p.field.dotted() - ) - elif isinstance(p.field, FeatureFieldRef): - if feature_catalog is None: - feature_catalog = await self._load_feature_catalog() - cols = feature_catalog.get(p.field.hook) - if cols is None: - raise ValidationError( - f"Unknown feature hook '{p.field.hook}'.", - field=p.field.dotted(), - code="unknown_hook", - ) - if p.field.column not in cols: - raise ValidationError( - f"Unknown feature column '{p.field.column}' on hook '{p.field.hook}'.", - field=p.field.dotted(), - code="unknown_field", - ) - json_type = cols[p.field.column] - self._check_operator_for_json_type(p, json_type=json_type, path=p.field.dotted()) - - def _validate_feature_refs( - self, - expr: FilterExpr, - *, - this_hook: str, - feature_col_map: dict[str, str], - schema_field_map: dict[str, FieldType], - ) -> None: - """Variant of ref validation for feature search — local hook columns by default.""" - for p in _iter_predicates(expr): - if isinstance(p.field, MetadataFieldRef): - if p.field.field not in schema_field_map: - raise ValidationError( - f"Unknown metadata field '{p.field.field}' for the provided schema.", - field=p.field.dotted(), - code="unknown_field", - ) - self._check_operator_for_field_type( - p, field_type=schema_field_map[p.field.field], path=p.field.dotted() - ) - elif isinstance(p.field, FeatureFieldRef): - if p.field.hook != this_hook: - # Cross-hook joins handled by US3 — accepted here, resolved in adapter. 
- continue - if p.field.column not in feature_col_map: - raise ValidationError( - f"Unknown feature column '{p.field.column}' on hook '{this_hook}'.", - field=p.field.dotted(), - code="unknown_field", - ) - self._check_operator_for_json_type( - p, json_type=feature_col_map[p.field.column], path=p.field.dotted() - ) - - async def _load_feature_catalog(self) -> dict[str, dict[str, str]]: - """Build hook_name → {column_name → json_type} map from the catalog.""" - catalog = await self.read_store.get_feature_catalog() - return {entry.hook_name: {col.name: col.type for col in entry.columns} for entry in catalog} - - @staticmethod - def _check_operator_for_field_type( - predicate: Predicate, *, field_type: FieldType, path: str - ) -> None: - valid = VALID_OPERATORS.get(field_type, set()) - if predicate.op not in valid: - raise ValidationError( - f"Operator '{predicate.op}' is not valid for field '{path}' " - f"(type '{field_type}'). Valid: {sorted(valid)}.", - field=path, - code="operator_not_valid_for_type", - ) - - @staticmethod - def _check_operator_for_json_type(predicate: Predicate, *, json_type: str, path: str) -> None: - valid = JSON_TYPE_OPERATORS.get(json_type, {FilterOperator.EQ}) - if predicate.op not in valid: - raise ValidationError( - f"Operator '{predicate.op}' is not valid for column '{path}' " - f"(json type '{json_type}'). 
Valid: {sorted(valid)}.", - field=path, - code="operator_not_valid_for_type", - ) - - -def _tree_depth(expr: FilterExpr) -> int: - if isinstance(expr, Predicate): - return 1 - if isinstance(expr, Not): - return 1 + _tree_depth(expr.operand) - if isinstance(expr, (And, Or)): - return 1 + max(_tree_depth(op) for op in expr.operands) - return 1 - - -def _iter_predicates(expr: FilterExpr): - if isinstance(expr, Predicate): - yield expr - return - if isinstance(expr, Not): - yield from _iter_predicates(expr.operand) - return - if isinstance(expr, (And, Or)): - for op in expr.operands: - yield from _iter_predicates(op) - - -def _iter_nodes(expr: FilterExpr): - yield expr - if isinstance(expr, Not): - yield from _iter_nodes(expr.operand) - elif isinstance(expr, (And, Or)): - for op in expr.operands: - yield from _iter_nodes(op) diff --git a/server/osa/domain/discovery/util/di/__init__.py b/server/osa/domain/discovery/util/di/__init__.py deleted file mode 100644 index 8672799..0000000 --- a/server/osa/domain/discovery/util/di/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from osa.domain.discovery.util.di.provider import DiscoveryProvider - -__all__ = ["DiscoveryProvider"] diff --git a/server/osa/domain/discovery/util/di/provider.py b/server/osa/domain/discovery/util/di/provider.py deleted file mode 100644 index 9715b8c..0000000 --- a/server/osa/domain/discovery/util/di/provider.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Dishka DI provider for the discovery domain.""" - -from dishka import provide - -from osa.config import Config -from osa.domain.discovery.port.field_definition_reader import FieldDefinitionReader -from osa.domain.discovery.port.read_store import DiscoveryReadStore -from osa.domain.discovery.query.get_feature_catalog import GetFeatureCatalogHandler -from osa.domain.discovery.query.search_features import SearchFeaturesHandler -from osa.domain.discovery.query.search_records import SearchRecordsHandler -from osa.domain.discovery.service.discovery import 
DiscoveryService -from osa.util.di.base import Provider -from osa.util.di.scope import Scope - - -class DiscoveryProvider(Provider): - @provide(scope=Scope.UOW) - def get_discovery_service( - self, - read_store: DiscoveryReadStore, - field_reader: FieldDefinitionReader, - config: Config, - ) -> DiscoveryService: - return DiscoveryService( - read_store=read_store, - field_reader=field_reader, - config=config, - ) - - # Query Handlers - search_records_handler = provide(SearchRecordsHandler, scope=Scope.UOW) - get_feature_catalog_handler = provide(GetFeatureCatalogHandler, scope=Scope.UOW) - search_features_handler = provide(SearchFeaturesHandler, scope=Scope.UOW) diff --git a/server/osa/domain/index/__init__.py b/server/osa/domain/index/__init__.py deleted file mode 100644 index 96aacc8..0000000 --- a/server/osa/domain/index/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Index domain - configuration and orchestration for storage backends.""" diff --git a/server/osa/domain/index/event/__init__.py b/server/osa/domain/index/event/__init__.py deleted file mode 100644 index b350a4e..0000000 --- a/server/osa/domain/index/event/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Index domain events.""" - -from osa.domain.index.event.index_record import IndexRecord - -__all__ = ["IndexRecord"] diff --git a/server/osa/domain/index/event/index_record.py b/server/osa/domain/index/event/index_record.py deleted file mode 100644 index 7f438e0..0000000 --- a/server/osa/domain/index/event/index_record.py +++ /dev/null @@ -1,26 +0,0 @@ -"""IndexRecord event - per-backend indexing request for a single record.""" - -from typing import Any - -from osa.domain.shared.event import Event, EventId -from osa.domain.shared.model.srn import RecordSRN - - -class IndexRecord(Event): - """Request to index a single record into a specific backend. - - This event is created by FanOutToIndexBackends when a RecordPublished - event is received. 
Each backend gets its own IndexRecord event, - enabling independent retry and failure isolation. - - Attributes: - id: Unique event identifier (inherited from Event). - backend_name: Target backend name (e.g., "vector", "keyword"). - record_srn: Structured Resource Name of the record. - metadata: Record metadata to index. - """ - - id: EventId - backend_name: str - record_srn: RecordSRN - metadata: dict[str, Any] diff --git a/server/osa/domain/index/handler/__init__.py b/server/osa/domain/index/handler/__init__.py deleted file mode 100644 index 75cfab4..0000000 --- a/server/osa/domain/index/handler/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Index domain event handlers.""" - -from osa.domain.index.handler.fanout_to_index_backends import FanOutToIndexBackends -from osa.domain.index.handler.keyword_index_handler import KeywordIndexHandler -from osa.domain.index.handler.vector_index_handler import VectorIndexHandler - -__all__ = ["FanOutToIndexBackends", "KeywordIndexHandler", "VectorIndexHandler"] diff --git a/server/osa/domain/index/handler/fanout_to_index_backends.py b/server/osa/domain/index/handler/fanout_to_index_backends.py deleted file mode 100644 index 4cb265d..0000000 --- a/server/osa/domain/index/handler/fanout_to_index_backends.py +++ /dev/null @@ -1,47 +0,0 @@ -"""FanOutToIndexBackends - creates per-backend IndexRecord events from RecordPublished.""" - -import logging -from typing import ClassVar -from uuid import uuid4 - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.record.event.record_published import RecordPublished -from osa.domain.shared.event import EventHandler, EventId -from osa.domain.shared.outbox import Outbox - -logger = logging.getLogger(__name__) - - -class FanOutToIndexBackends(EventHandler[RecordPublished]): - """Creates per-backend IndexRecord events from RecordPublished. 
- - When a record is published, this handler creates one IndexRecord event - per registered backend. Each IndexRecord is stored in the outbox, - enabling independent retry and failure isolation per backend. - - This replaces the previous pattern where a single RecordPublished event - triggered immediate indexing to all backends in a single transaction. - - Batch processing is used for efficiency when multiple records are - published in quick succession. - """ - - __batch_size__: ClassVar[int] = 10 - - indexes: IndexRegistry - outbox: Outbox - - async def handle(self, event: RecordPublished) -> None: - """Create IndexRecord events for each registered backend.""" - backend_names = list(self.indexes) - logger.debug(f"FanOut: {event.record_srn} -> {len(backend_names)} backends") - - for backend_name in backend_names: - index_event = IndexRecord( - id=EventId(uuid4()), - backend_name=backend_name, - record_srn=event.record_srn, - metadata=event.metadata, - ) - await self.outbox.append(index_event) diff --git a/server/osa/domain/index/handler/keyword_index_handler.py b/server/osa/domain/index/handler/keyword_index_handler.py deleted file mode 100644 index a6f5791..0000000 --- a/server/osa/domain/index/handler/keyword_index_handler.py +++ /dev/null @@ -1,38 +0,0 @@ -"""KeywordIndexHandler - processes IndexRecord events for keyword backends.""" - -import logging -from typing import ClassVar - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.error import SkippedEvents -from osa.domain.shared.event import EventHandler - -logger = logging.getLogger(__name__) - - -class KeywordIndexHandler(EventHandler[IndexRecord]): - """Processes IndexRecord events for the keyword backend. - - Processes events immediately (batch_size=1) since keyword indexing - doesn't benefit from batching. 
- """ - - __batch_size__: ClassVar[int] = 1 - - indexes: IndexRegistry - - async def handle(self, event: IndexRecord) -> None: - """Process a single IndexRecord event.""" - backend = self.indexes.get("keyword") - if backend is None: - raise SkippedEvents( - event_ids=[event.id], - reason="Keyword backend not available", - ) - - record = (str(event.record_srn), event.metadata) - - await backend.ingest_batch([record]) - - logger.debug(f"KeywordIndexHandler: indexed event {event.id}") diff --git a/server/osa/domain/index/handler/vector_index_handler.py b/server/osa/domain/index/handler/vector_index_handler.py deleted file mode 100644 index aa3d293..0000000 --- a/server/osa/domain/index/handler/vector_index_handler.py +++ /dev/null @@ -1,47 +0,0 @@ -"""VectorIndexHandler - processes IndexRecord events for vector backends.""" - -import logging -from typing import ClassVar - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.error import SkippedEvents -from osa.domain.shared.event import EventHandler - -logger = logging.getLogger(__name__) - - -class VectorIndexHandler(EventHandler[IndexRecord]): - """Processes IndexRecord events for the vector backend. - - Processes events in batches for efficient embedding generation. - """ - - __batch_size__: ClassVar[int] = 100 - __batch_timeout__: ClassVar[float] = 5.0 - - indexes: IndexRegistry - - async def handle_batch(self, events: list[IndexRecord]) -> None: - """Process a batch of IndexRecord events. - - Converts events to records and calls ingest_batch on the backend. 
- """ - if not events: - return - - backend = self.indexes.get("vector") - if backend is None: - raise SkippedEvents( - event_ids=[e.id for e in events], - reason="Vector backend not available", - ) - - # Prepare records for batch ingestion - records = [(str(e.record_srn), e.metadata) for e in events] - - logger.debug(f"VectorIndexHandler: ingesting {len(records)} records to backend") - - await backend.ingest_batch(records) - - logger.debug(f"VectorIndexHandler: ingested {len(events)} records") diff --git a/server/osa/domain/index/model/__init__.py b/server/osa/domain/index/model/__init__.py deleted file mode 100644 index 4f64a80..0000000 --- a/server/osa/domain/index/model/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Index domain models.""" - -from osa.domain.index.model.registry import IndexRegistry - -__all__ = ["IndexRegistry"] diff --git a/server/osa/domain/index/model/registry.py b/server/osa/domain/index/model/registry.py deleted file mode 100644 index 9bb01bf..0000000 --- a/server/osa/domain/index/model/registry.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Index registry - typed container for available storage backends.""" - -from collections.abc import Iterator - -from osa.sdk.index.backend import StorageBackend - - -class IndexRegistry: - """Registry of available index backends.""" - - def __init__(self, backends: dict[str, StorageBackend]) -> None: - self._backends = backends - - def get(self, name: str) -> StorageBackend | None: - """Get a backend by name.""" - return self._backends.get(name) - - def __contains__(self, name: str) -> bool: - return name in self._backends - - def __iter__(self) -> Iterator[str]: - return iter(self._backends) - - def __len__(self) -> int: - return len(self._backends) - - def names(self) -> list[str]: - """List all available index names.""" - return list(self._backends.keys()) - - def items(self) -> Iterator[tuple[str, StorageBackend]]: - """Iterate over (name, backend) pairs.""" - return iter(self._backends.items()) diff --git 
a/server/osa/domain/index/service/__init__.py b/server/osa/domain/index/service/__init__.py deleted file mode 100644 index 852bec9..0000000 --- a/server/osa/domain/index/service/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Index service module.""" - -from osa.domain.index.service.index import IndexService - -__all__ = ["IndexService"] diff --git a/server/osa/domain/index/service/index.py b/server/osa/domain/index/service/index.py deleted file mode 100644 index b4a485b..0000000 --- a/server/osa/domain/index/service/index.py +++ /dev/null @@ -1,48 +0,0 @@ -"""IndexService - orchestrates indexing of records into storage backends.""" - -import logging - -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.service import Service - -logger = logging.getLogger(__name__) - - -class IndexService(Service): - """Service for index-related operations. - - Note: Direct indexing (index_record, flush_all) has been replaced by the - event-driven approach using FanOutToIndexBackends and IndexRecordBatch - listeners. This service is retained for query operations and future - index management commands. - """ - - indexes: IndexRegistry - - async def get_count(self, backend_name: str) -> int | None: - """Get the document count for a specific backend. - - Args: - backend_name: Name of the backend to query. - - Returns: - Document count, or None if backend not found. - """ - backend = self.indexes.get(backend_name) - if backend is None: - return None - return await backend.count() - - async def check_health(self, backend_name: str) -> bool | None: - """Check health of a specific backend. - - Args: - backend_name: Name of the backend to check. - - Returns: - True if healthy, False if unhealthy, None if backend not found. 
- """ - backend = self.indexes.get(backend_name) - if backend is None: - return None - return await backend.health() diff --git a/server/osa/domain/search/__init__.py b/server/osa/domain/search/__init__.py deleted file mode 100644 index cec05f3..0000000 --- a/server/osa/domain/search/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Search domain - provides search capabilities over indexed records.""" diff --git a/server/osa/domain/search/command/__init__.py b/server/osa/domain/search/command/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/osa/domain/search/event/__init__.py b/server/osa/domain/search/event/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/osa/domain/search/model/__init__.py b/server/osa/domain/search/model/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/osa/domain/search/port/__init__.py b/server/osa/domain/search/port/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/osa/domain/search/query/__init__.py b/server/osa/domain/search/query/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/osa/domain/search/service/__init__.py b/server/osa/domain/search/service/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/osa/domain/semantics/model/schema.py b/server/osa/domain/semantics/model/schema.py index 20af362..a139eb5 100644 --- a/server/osa/domain/semantics/model/schema.py +++ b/server/osa/domain/semantics/model/schema.py @@ -1,8 +1,9 @@ from datetime import datetime from osa.domain.semantics.model.value import FieldDefinition -from osa.domain.shared.error import ValidationError +from osa.domain.shared.error import ReservedNameError, ValidationError from osa.domain.shared.model.aggregate import Aggregate +from osa.domain.shared.model.reserved import RESERVED_NAMES from osa.domain.shared.model.srn import SchemaId @@ -15,6 +16,9 @@ class Schema(Aggregate): created_at: datetime def 
model_post_init(self, __context: object) -> None: + if self.id.id.root in RESERVED_NAMES: + raise ReservedNameError(self.id.id.root, "schema") + if len(self.fields) < 1: raise ValidationError("Schema must have at least one field") diff --git a/server/osa/domain/shared/error.py b/server/osa/domain/shared/error.py index f03951b..1ba718f 100644 --- a/server/osa/domain/shared/error.py +++ b/server/osa/domain/shared/error.py @@ -52,6 +52,26 @@ class ConflictError(DomainError): """Resource already exists or version conflict.""" +class ReservedNameError(DomainError): + """A schema ID or hook/feature name collides with a reserved URL slot. + + Raised at aggregate construction (``Schema``, ``HookDefinition``) when a + name equals one of :data:`osa.domain.shared.model.reserved.RESERVED_NAMES`. + Surfaced as HTTP 400 with the structured ``code`` field. + """ + + def __init__(self, name: str, kind: str) -> None: + from osa.domain.shared.model.reserved import RESERVED_NAMES + + self.name = name + self.kind = kind + super().__init__( + f"The {kind} name '{name}' is reserved for a URL slot under /data/. " + f"Reserved names: {sorted(RESERVED_NAMES)}.", + code="reserved_name", + ) + + class AuthorizationError(DomainError): """User not authorized for this operation.""" diff --git a/server/osa/domain/shared/model/hook.py b/server/osa/domain/shared/model/hook.py index 3691dbc..8c08dee 100644 --- a/server/osa/domain/shared/model/hook.py +++ b/server/osa/domain/shared/model/hook.py @@ -127,6 +127,15 @@ class HookDefinition(ValueObject): runtime: Annotated[OciConfig, Field(discriminator="type")] feature: Annotated[TableFeatureSpec, Field(discriminator="kind")] + def model_post_init(self, __context: object) -> None: + # A hook name becomes a feature-table URL slot under /data/{schema}/. + # Reject names that would shadow a fixed slot (research §6). 
+ from osa.domain.shared.error import ReservedNameError + from osa.domain.shared.model.reserved import RESERVED_NAMES + + if self.name in RESERVED_NAMES: + raise ReservedNameError(self.name, "hook") + def with_memory(self, memory: str) -> "HookDefinition": """Return a copy with a different memory limit.""" new_limits = self.runtime.limits.model_copy(update={"memory": memory}) diff --git a/server/osa/domain/shared/model/ids.py b/server/osa/domain/shared/model/ids.py new file mode 100644 index 0000000..f48c6c9 --- /dev/null +++ b/server/osa/domain/shared/model/ids.py @@ -0,0 +1,20 @@ +"""Central semantic ID types used across the ``/data/`` read surface. + +Per OSA's type-safety convention, semantic identifiers cross module +boundaries as ``NewType`` aliases rather than bare ``str``. ``HookName`` is +re-exported from :mod:`osa.domain.shared.model.hook` (its source of truth, +where the PG-identifier pattern validation lives) so callers have a single +import location for the IDs this surface deals in. +""" + +from __future__ import annotations + +from typing import NewType + +from osa.domain.shared.model.hook import HookName + +# Bare internal record identifier (UUIDv7 / ULID). Validation of the exact +# charset/length happens at the API boundary; internally it is opaque. +RecordId = NewType("RecordId", str) + +__all__ = ["RecordId", "HookName"] diff --git a/server/osa/domain/shared/model/reserved.py b/server/osa/domain/shared/model/reserved.py new file mode 100644 index 0000000..7281f8c --- /dev/null +++ b/server/osa/domain/shared/model/reserved.py @@ -0,0 +1,16 @@ +"""Reserved names that collide with fixed URL slots under ``/data/``. + +The unified ``/data/`` read surface exposes a small set of fixed path +segments (``/data/records/...``, and the deferred ``/data/datasets`` slot). +A schema ID or hook/feature name equal to one of these would shadow the +fixed route, so they are forbidden at aggregate construction time. 
+ +This module is the single source of truth for the reserved set. Adding a new +reserved slot (e.g. ``events``) later is a one-line change here. +""" + +from __future__ import annotations + +# Lowercase ASCII; only ever grows over the project's lifetime (shrinking +# would silently re-enable a name that a consumer's URL may already depend on). +RESERVED_NAMES: frozenset[str] = frozenset({"records", "datasets"}) diff --git a/server/osa/domain/export/command/__init__.py b/server/osa/infrastructure/data/__init__.py similarity index 100% rename from server/osa/domain/export/command/__init__.py rename to server/osa/infrastructure/data/__init__.py diff --git a/server/osa/infrastructure/data/postgres_data_read_store.py b/server/osa/infrastructure/data/postgres_data_read_store.py new file mode 100644 index 0000000..7a4786e --- /dev/null +++ b/server/osa/infrastructure/data/postgres_data_read_store.py @@ -0,0 +1,615 @@ +"""Postgres adapter for the ``DataReadStore`` port. + +The streaming primitive (:meth:`stream_rows`) builds the records / feature +SELECT, then iterates it through ``AsyncSession.stream()`` — a server-side +cursor (research §2). Rows are yielded one at a time as flattened column→value +mappings, so memory stays bounded regardless of result size. A ``try``/``finally`` +around the streaming result closes the cursor on client disconnect, returning +the connection to the pool. + +The records filter compilation reuses the proven approach from the discovery +adapter, ported onto the ``data`` domain's own filter model so the adapter has +no dependency on ``discovery`` once that domain is deleted (research §10). 
+""" + +from __future__ import annotations + +import logging +from collections.abc import AsyncIterator, Mapping +from datetime import date, datetime +from typing import Any + +import sqlalchemy as sa +from sqlalchemy import ( + RowMapping, + String, + and_, + cast, + false, + func, + not_, + or_, + select, +) +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.data.model.catalog import ( + CatalogEntry, + NodeCatalog, + TableResourceSummary, +) +from osa.domain.data.model.filter import ( + And, + FeatureFieldRef, + FilterExpr, + FilterOperator, + MetadataFieldRef, + Not, + Or, + Predicate, +) +from osa.domain.data.model.manifest import ( + IMPLICIT_FEATURE_COLUMN_SPECS, + IMPLICIT_RECORD_COLUMN_SPECS, + ColumnSpec, + FieldSpec, + SchemaManifest, + TableResource, +) +from osa.domain.data.model.query_plan import QueryPlan, SortDirection, TableKind +from osa.domain.data.model.record_summary import RecordSummary +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.error import NotFoundError, ValidationError +from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId +from osa.infrastructure.persistence.feature_table import ( + FeatureSchema, + build_feature_table, + data_columns, +) +from osa.infrastructure.persistence.keyset import KeysetPage, SortKey +from osa.infrastructure.persistence.metadata_table import ( + MetadataSchema, + build_metadata_table, +) +from osa.infrastructure.persistence.tables import ( + conventions_table, + feature_tables_table, + metadata_tables_table, + records_table, + schemas_table, +) + +logger = logging.getLogger(__name__) + +# A feature column's JSON-primitive type → the manifest's semantic FieldType. 
+_JSON_TYPE_TO_FIELD_TYPE: dict[str, FieldType] = { + "string": FieldType.TEXT, + "number": FieldType.NUMBER, + "integer": FieldType.NUMBER, + "boolean": FieldType.BOOLEAN, + "array": FieldType.TEXT, + "object": FieldType.TEXT, +} + +# All URL-exposed format suffixes (mirrors model.format.FORMATS suffixes). +_ALL_FORMATS = ["", "csv", "csv.gz"] + + +def _escape_like(value: str) -> str: + return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + +class PostgresDataReadStore: + def __init__(self, session: AsyncSession, node_domain: Domain) -> None: + self.session = session + # Only the node's DNS domain is needed (to render SRNs in the catalog / + # manifest) — not the whole Config. + self.node_domain = node_domain + + # ------------------------------------------------------------------ # + # Streaming primitive + # ------------------------------------------------------------------ # + + async def stream_rows(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + if plan.table_kind == TableKind.RECORDS: + async for row in self._stream_records(plan): + yield row + else: # TableKind.FEATURE + async for row in self._stream_features(plan): + yield row + + async def _stream_records(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + catalog = await self._metadata_catalog_for(plan.schema_id) + if catalog is None: + raise NotFoundError( + f"Schema not found: {plan.schema_id.render()}. See /api/v1/data for the catalog." 
+ ) + metadata_schema = MetadataSchema.model_validate(catalog["metadata_schema"]) + metadata_table = build_metadata_table(catalog["pg_table"], metadata_schema) + + t = records_table + conditions: list[Any] = [ + t.c.schema_id == plan.schema_id.id.root, + t.c.schema_version == plan.schema_id.version.root, + ] + if plan.filter is not None: + conditions.append(self._compile_filter(plan.filter, metadata_t=metadata_table)) + + order_keys, cursor_after = self._records_sort(plan, metadata_table) + if cursor_after is not None: + conditions.append(cursor_after) + + select_cols = [ + t.c.srn, + t.c.schema_id, + t.c.schema_version, + t.c.published_at, + ] + [metadata_table.c[c.name].label(c.name) for c in metadata_schema.columns] + stmt = ( + select(*select_cols) + .select_from(t.join(metadata_table, metadata_table.c.record_srn == t.c.srn)) + .where(and_(*conditions)) + .order_by(*order_keys) + ) + + col_names = [c.name for c in metadata_schema.columns] + # ``stream()`` opens a server-side cursor. The try/finally closes it on + # client disconnect (the generator is thrown a CancelledError), returning + # the connection to the pool (research §2). + result = await self.session.stream(stmt) + try: + async for row in result.mappings(): + yield self._records_row_to_mapping(row, col_names) + finally: + await result.close() + + def _records_row_to_mapping(self, row: RowMapping, col_names: list[str]) -> dict[str, Any]: + srn = RecordSRN.parse(row["srn"]) + summary = RecordSummary( + id=RecordId(srn.id.root), + srn=srn, + schema_id=SchemaId.parse(f"{row['schema_id']}@{row['schema_version']}"), + version=int(srn.version.root), + metadata={name: row[name] for name in col_names if name in row}, + created_at=row["published_at"], + ) + return summary.flatten() + + def _records_sort(self, plan: QueryPlan, metadata_table: Any) -> tuple[list[Any], Any | None]: + t = records_table + # First sort key drives the cursor value; ``id`` (srn) is the tiebreaker. 
+ primary = plan.sort[0] + is_desc = primary.direction == SortDirection.DESC + if primary.column in ("created_at", "published_at"): + sort_expr = t.c.published_at + elif primary.column == "id": + sort_expr = t.c.srn + elif primary.column in metadata_table.c: + sort_expr = metadata_table.c[primary.column] + else: + raise ValidationError( + f"Unknown sort column '{primary.column}'.", field="sort", code="unknown_sort_field" + ) + page = KeysetPage( + [ + SortKey(sort_expr, descending=is_desc, nulls_last=True), + SortKey(t.c.srn, descending=is_desc), + ] + ) + cursor_after = None + if plan.pagination.cursor is not None: + from osa.domain.data.model.query_plan import decode_cursor + + decoded = decode_cursor(str(plan.pagination.cursor)) + sort_value = self._coerce_cursor_value(decoded["s"], primary.column) + cursor_after = page.after((sort_value, decoded["id"])) + return page.order_by(), cursor_after + + @staticmethod + def _coerce_cursor_value(value: Any, column: str) -> Any: + if column in ("created_at", "published_at") and isinstance(value, str): + return datetime.fromisoformat(value) + return value + + # ------------------------------------------------------------------ # + # Feature-table streaming (US5) + # ------------------------------------------------------------------ # + + async def _stream_features(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + if plan.feature_name is None: # guarded by QueryPlan, narrowed for the type checker + raise ValidationError("feature_name is required for a FEATURE plan", field="feature") + ft, fschema = await self._resolve_feature_table(plan.schema_id, plan.feature_name) + + conditions: list[Any] = [] + if plan.filter is not None: + conditions.append( + self._compile_feature_filter(plan.filter, ft=ft, feature_name=plan.feature_name) + ) + + order_keys, cursor_after = self._features_sort(plan, ft, fschema) + if cursor_after is not None: + conditions.append(cursor_after) + + # Implicit columns (id, record_srn, 
created_at) precede the hook's + # declared data columns — this is the CSV header order. + select_cols = [ft.c.id, ft.c.record_srn, ft.c.created_at, *data_columns(ft)] + stmt = select(*select_cols).select_from(ft) + if conditions: + stmt = stmt.where(and_(*conditions)) + stmt = stmt.order_by(*order_keys) + + result = await self.session.stream(stmt) + try: + async for row in result.mappings(): + yield dict(row) + finally: + await result.close() + + async def _resolve_feature_table( + self, schema_id: SchemaId, feature_name: str + ) -> tuple[sa.Table, FeatureSchema]: + """Resolve a feature table that belongs to ``schema_id``. + + A feature (hook) belongs to a schema via the convention that registers + it. Streaming a feature not registered on the schema is a 404, not a + leak of another schema's table. + """ + for hook_name, fschema in await self._schema_feature_tables(schema_id): + if hook_name == feature_name: + return build_feature_table(hook_name, fschema), fschema + raise NotFoundError( + f"No feature table '{feature_name}' on schema {schema_id.render()}. 
" + f"See /api/v1/data/{schema_id.render()} for its table resources.", + code="feature_not_found", + ) + + async def _schema_hook_names(self, schema_id: SchemaId) -> set[str]: + """Hook names registered on the schema's conventions (the schema→feature link).""" + stmt = select(conventions_table.c.hooks).where( + conventions_table.c.schema_id == schema_id.id.root, + conventions_table.c.schema_version == schema_id.version.root, + ) + result = await self.session.execute(stmt) + names: set[str] = set() + for (hooks,) in result.all(): + for hook in hooks or []: + names.add(hook["name"]) + return names + + async def _schema_feature_tables(self, schema_id: SchemaId) -> list[tuple[str, FeatureSchema]]: + """(hook_name, FeatureSchema) for every materialized feature table on the schema.""" + hook_names = await self._schema_hook_names(schema_id) + if not hook_names: + return [] + stmt = select( + feature_tables_table.c.hook_name, feature_tables_table.c.feature_schema + ).where(feature_tables_table.c.hook_name.in_(hook_names)) + result = await self.session.execute(stmt) + return [ + (row["hook_name"], FeatureSchema.model_validate(row["feature_schema"])) + for row in result.mappings() + ] + + async def _feature_count(self, ft: sa.Table) -> int: + return int((await self.session.execute(select(func.count()).select_from(ft))).scalar_one()) + + def _features_sort( + self, plan: QueryPlan, ft: sa.Table, fschema: FeatureSchema + ) -> tuple[list[Any], Any | None]: + primary = plan.sort[0] + is_desc = primary.direction == SortDirection.DESC + if primary.column == "id": + sort_expr = ft.c.id + elif primary.column in ft.c: + sort_expr = ft.c[primary.column] + else: + raise ValidationError( + f"Unknown sort column '{primary.column}'.", + field="sort", + code="unknown_sort_field", + ) + page = KeysetPage( + [ + SortKey(sort_expr, descending=is_desc, nulls_last=True), + SortKey(ft.c.id, descending=is_desc), + ] + ) + cursor_after = None + if plan.pagination.cursor is not None: + from 
osa.domain.data.model.query_plan import decode_cursor + + decoded = decode_cursor(str(plan.pagination.cursor)) + sort_value = self._coerce_feature_cursor_value(decoded["s"], primary.column, fschema) + cursor_after = page.after((sort_value, int(decoded["id"]))) + return page.order_by(), cursor_after + + @staticmethod + def _coerce_feature_cursor_value(value: Any, column: str, fschema: FeatureSchema) -> Any: + if column == "id": + return None if value is None else int(value) + col_def = next((c for c in fschema.columns if c.name == column), None) + if col_def is not None and col_def.json_type == "string" and isinstance(value, str): + if col_def.format == "date-time": + return datetime.fromisoformat(value) + if col_def.format == "date": + return date.fromisoformat(value) + return value + + def _compile_feature_filter(self, expr: FilterExpr, *, ft: sa.Table, feature_name: str) -> Any: + if isinstance(expr, Predicate): + return self._compile_feature_predicate(expr, ft=ft, feature_name=feature_name) + if isinstance(expr, And): + return and_( + *[ + self._compile_feature_filter(op, ft=ft, feature_name=feature_name) + for op in expr.operands + ] + ) + if isinstance(expr, Or): + return or_( + *[ + self._compile_feature_filter(op, ft=ft, feature_name=feature_name) + for op in expr.operands + ] + ) + if isinstance(expr, Not): + inner = self._compile_feature_filter(expr.operand, ft=ft, feature_name=feature_name) + return not_(func.coalesce(inner, false())) + raise ValidationError(f"Unsupported filter node: {type(expr).__name__}") + + def _compile_feature_predicate( + self, predicate: Predicate, *, ft: sa.Table, feature_name: str + ) -> Any: + if isinstance(predicate.field, FeatureFieldRef): + if predicate.field.hook != feature_name: + raise ValidationError( + f"Cross-feature predicates are not supported on the " + f"'{feature_name}' stream (referenced '{predicate.field.hook}').", + field=predicate.field.dotted(), + code="cross_feature_predicate_unsupported", + ) + if 
predicate.field.column not in ft.c: + raise ValidationError( + f"Unknown feature column '{predicate.field.column}'.", + field=predicate.field.dotted(), + code="unknown_feature_column", + ) + return _apply_scalar_op(ft.c[predicate.field.column], predicate.op, predicate.value) + if isinstance(predicate.field, MetadataFieldRef): + raise ValidationError( + "Metadata-field predicates are not supported on a feature stream.", + field=predicate.field.dotted(), + code="metadata_predicate_unsupported", + ) + raise TypeError(f"Unexpected field ref type: {type(predicate.field).__name__}") + + # ------------------------------------------------------------------ # + # Single record by ID + # ------------------------------------------------------------------ # + + async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary | None: + # The records PK is the SRN ``urn:osa:{domain}:rec:{id}@{version}``. + # Match the id segment; resolve version (pin or latest published). + pattern = f"urn:osa:%:rec:{_escape_like(str(id))}@%" + t = records_table + stmt = ( + select(t.c.srn, t.c.schema_id, t.c.schema_version, t.c.published_at, t.c.metadata) + .where(t.c.srn.like(pattern, escape="\\")) + .order_by(t.c.published_at.desc()) + ) + result = await self.session.execute(stmt) + rows = result.mappings().all() + if not rows: + return None + + chosen = None + for row in rows: + srn = RecordSRN.parse(row["srn"]) + if srn.id.root != str(id): + continue + if version is None: + chosen = (srn, row) + break + if int(srn.version.root) == version: + chosen = (srn, row) + break + if chosen is None: + return None + srn, row = chosen + return RecordSummary( + id=RecordId(srn.id.root), + srn=srn, + schema_id=SchemaId.parse(f"{row['schema_id']}@{row['schema_version']}"), + version=int(srn.version.root), + metadata=row["metadata"] or {}, + created_at=row["published_at"], + ) + + # ------------------------------------------------------------------ # + # Catalog & manifest + # 
------------------------------------------------------------------ # + + async def get_node_catalog(self) -> NodeCatalog: + stmt = select(schemas_table.c.id, schemas_table.c.version) + result = await self.session.execute(stmt) + schema_rows = [(row["id"], row["version"]) for row in result.mappings()] + entries: list[CatalogEntry] = [] + for short_id, version in schema_rows: + schema_id = SchemaId.parse(f"{short_id}@{version}") + resources = [TableResourceSummary(name="records", kind=TableKind.RECORDS)] + for hook_name, _ in await self._schema_feature_tables(schema_id): + resources.append(TableResourceSummary(name=hook_name, kind=TableKind.FEATURE)) + entries.append( + CatalogEntry( + id=short_id, + version=version, + srn=schema_id.to_srn(self.node_domain).render(), + table_resources=resources, + ) + ) + return NodeCatalog(node_domain=self.node_domain.root, schemas=entries) + + async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest | None: + stmt = select(schemas_table.c.fields).where( + schemas_table.c.id == schema_id.id.root, + schemas_table.c.version == schema_id.version.root, + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + if row is None: + return None + + field_specs: list[FieldSpec] = [] + column_specs: list[ColumnSpec] = [] + for f in row["fields"]: + ftype = FieldType(f["type"]) + field_specs.append( + FieldSpec( + name=f["name"], + type=ftype, + ontology_id=f.get("ontology_id"), + ontology_version=f.get("ontology_version"), + ) + ) + column_specs.append(ColumnSpec(name=f["name"], type=ftype)) + + record_count = await self._records_count(schema_id) + records_resource = TableResource( + name="records", + kind=TableKind.RECORDS, + # Implicit columns (id, srn, schema_id, version, created_at) precede + # the schema's declared metadata fields — this is the CSV header order. 
+ columns=[*IMPLICIT_RECORD_COLUMN_SPECS, *column_specs], + row_count=record_count, + formats=list(_ALL_FORMATS), + ) + feature_resources = await self._feature_resources(schema_id) + return SchemaManifest( + id=schema_id.id.root, + version=schema_id.version.root, + srn=schema_id.to_srn(self.node_domain).render(), + fields=field_specs, + table_resources=[records_resource, *feature_resources], + ) + + async def _feature_resources(self, schema_id: SchemaId) -> list[TableResource]: + """Build a TableResource for each feature table registered on the schema.""" + resources: list[TableResource] = [] + for hook_name, fschema in await self._schema_feature_tables(schema_id): + ft = build_feature_table(hook_name, fschema) + resources.append( + TableResource( + name=hook_name, + kind=TableKind.FEATURE, + # Implicit columns (id, record_srn, created_at) precede the + # hook's declared data columns — this is the CSV header order. + columns=[*IMPLICIT_FEATURE_COLUMN_SPECS, *_feature_column_specs(fschema)], + row_count=await self._feature_count(ft), + formats=list(_ALL_FORMATS), + ) + ) + return resources + + async def get_latest_schema_id(self, schema_short_id: str) -> SchemaId | None: + stmt = select(schemas_table.c.version).where(schemas_table.c.id == schema_short_id) + result = await self.session.execute(stmt) + versions = [row[0] for row in result.all()] + if not versions: + return None + # Pick the highest SemVer (string sort is wrong for e.g. 1.10.0 vs 1.9.0). 
+ latest = max(versions, key=lambda v: tuple(int(p) for p in v.split("-")[0].split("."))) + return SchemaId.parse(f"{schema_short_id}@{latest}") + + async def _records_count(self, schema_id: SchemaId) -> int: + t = records_table + stmt = ( + select(func.count()) + .select_from(t) + .where( + t.c.schema_id == schema_id.id.root, + t.c.schema_version == schema_id.version.root, + ) + ) + return int((await self.session.execute(stmt)).scalar_one()) + + # ------------------------------------------------------------------ # + # Helpers (ported from the discovery adapter, records-only) + # ------------------------------------------------------------------ # + + async def _metadata_catalog_for(self, schema_id: SchemaId) -> dict[str, Any] | None: + stmt = select(metadata_tables_table).where( + metadata_tables_table.c.schema_id == schema_id.id.root, + metadata_tables_table.c.schema_major == schema_id.major, + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + return dict(row) if row is not None else None + + def _compile_filter(self, expr: FilterExpr, *, metadata_t: Any) -> Any: + if isinstance(expr, Predicate): + return self._compile_predicate(expr, metadata_t=metadata_t) + if isinstance(expr, And): + return and_(*[self._compile_filter(op, metadata_t=metadata_t) for op in expr.operands]) + if isinstance(expr, Or): + return or_(*[self._compile_filter(op, metadata_t=metadata_t) for op in expr.operands]) + if isinstance(expr, Not): + inner = self._compile_filter(expr.operand, metadata_t=metadata_t) + # NULL → FALSE before negating so records with NULL metadata survive NOT. + return not_(func.coalesce(inner, false())) + raise ValidationError(f"Unsupported filter node: {type(expr).__name__}") + + def _compile_predicate(self, predicate: Predicate, *, metadata_t: Any) -> Any: + if isinstance(predicate.field, MetadataFieldRef): + if predicate.field.field not in metadata_t.c: + raise ValidationError( + f"Unknown metadata field '{predicate.field.field}'. 
" + "Filterable fields are listed in the schema manifest.", + field=predicate.field.dotted(), + code="unknown_metadata_field", + ) + col = metadata_t.c[predicate.field.field] + return _apply_scalar_op(col, predicate.op, predicate.value) + if isinstance(predicate.field, FeatureFieldRef): + # Feature predicates in the records stream are a US5 extension. + raise ValidationError( + "Feature-field predicates are not yet supported on the records stream.", + field=predicate.field.dotted(), + code="feature_predicate_unsupported", + ) + raise TypeError(f"Unexpected field ref type: {type(predicate.field).__name__}") + + +def _feature_column_specs(fschema: FeatureSchema) -> list[ColumnSpec]: + """Map a feature table's declared columns to manifest ColumnSpecs.""" + return [ + ColumnSpec(name=c.name, type=_JSON_TYPE_TO_FIELD_TYPE[c.json_type]) for c in fschema.columns + ] + + +def _apply_scalar_op(col: Any, op: FilterOperator, value: Any) -> Any: + if op == FilterOperator.EQ: + return col == value + if op == FilterOperator.NEQ: + return or_(col != value, col.is_(None)) + if op == FilterOperator.GT: + return col > value + if op == FilterOperator.GTE: + return col >= value + if op == FilterOperator.LT: + return col < value + if op == FilterOperator.LTE: + return col <= value + if op == FilterOperator.IN: + if not isinstance(value, list): + raise ValidationError( + "Operator 'in' requires a list value.", field=col.key, code="invalid_value_for_op" + ) + return col.in_(value) + if op == FilterOperator.CONTAINS: + return cast(col, String).ilike(f"%{_escape_like(str(value))}%", escape="\\") + if op == FilterOperator.IS_NULL: + return col.is_(None) + raise ValidationError( + f"Unsupported operator: {op}", field="filter", code="unsupported_operator" + ) diff --git a/server/osa/infrastructure/index/__init__.py b/server/osa/infrastructure/index/__init__.py deleted file mode 100644 index 70e4b1a..0000000 --- a/server/osa/infrastructure/index/__init__.py +++ /dev/null @@ -1 +0,0 @@ 
-"""Index infrastructure - storage backend implementations.""" diff --git a/server/osa/infrastructure/index/di.py b/server/osa/infrastructure/index/di.py deleted file mode 100644 index 787426d..0000000 --- a/server/osa/infrastructure/index/di.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Dependency injection provider for index backends.""" - -from dishka import Provider, provide - -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.index.service import IndexService -from osa.util.di.scope import Scope - - -class IndexProvider(Provider): - """Provides configured index backends.""" - - @provide(scope=Scope.APP) - def get_backends(self) -> IndexRegistry: - """Build empty index registry (indexes are dormant).""" - return IndexRegistry({}) - - @provide(scope=Scope.UOW) - def get_index_service(self, indexes: IndexRegistry) -> IndexService: - """Provide IndexService for UOW scope.""" - return IndexService(indexes=indexes) diff --git a/server/osa/infrastructure/persistence/adapter/discovery.py b/server/osa/infrastructure/persistence/adapter/discovery.py deleted file mode 100644 index be9d067..0000000 --- a/server/osa/infrastructure/persistence/adapter/discovery.py +++ /dev/null @@ -1,701 +0,0 @@ -"""Infrastructure adapters for the discovery domain — read-only SQL queries.""" - -from __future__ import annotations - -import logging -from collections.abc import Callable -from datetime import date, datetime -from typing import Any - -from sqlalchemy import ( - Date, - Float, - String, - and_, - cast, - false, - func, - literal, - not_, - or_, - select, - true, - union_all, -) -from sqlalchemy.ext.asyncio import AsyncSession - -from osa.domain.discovery.model.refs import FeatureFieldRef, MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - ColumnInfo, - FeatureCatalogEntry, - FeatureRow, - FilterExpr, - FilterOperator, - Not, - Or, - Predicate, - RecordSummary, - SortOrder, -) -from osa.domain.semantics.model.value import FieldType -from 
osa.domain.shared.error import ValidationError -from osa.domain.shared.model.hook import ColumnDef -from osa.domain.shared.model.srn import ConventionSRN, RecordSRN, SchemaId -from osa.infrastructure.persistence.feature_table import ( - FeatureSchema, - build_feature_table, - data_columns, -) -from osa.infrastructure.persistence.keyset import KeysetPage, SortKey -from osa.infrastructure.persistence.metadata_table import ( - MetadataSchema, - build_metadata_table, -) -from osa.infrastructure.persistence.tables import ( - feature_tables_table, - metadata_tables_table, - records_table, - schemas_table, -) - -logger = logging.getLogger(__name__) - - -def _escape_like(value: str) -> str: - """Escape LIKE metacharacters so user input is matched literally.""" - return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - - -# Cursor-value coercers — cursor payloads round-trip through base64 JSON as -# plain strings/numbers, but keyset predicates compare against typed columns. -# Without this, ``published_at < 'iso-string'::VARCHAR`` fails on Postgres. 
- -CursorCoercer = Callable[[Any], Any] - - -def _coerce_identity(value: Any) -> Any: - return value - - -def _coerce_datetime(value: Any) -> Any: - if isinstance(value, str): - return datetime.fromisoformat(value) - return value - - -def _coerce_date(value: Any) -> Any: - if isinstance(value, str): - return date.fromisoformat(value) - return value - - -def _coerce_float(value: Any) -> Any: - return None if value is None else float(value) - - -def _coerce_int(value: Any) -> Any: - return None if value is None else int(value) - - -def _coercer_for_column(col_def: ColumnDef) -> CursorCoercer: - """Pick a coercer matching the Postgres type chosen by ``column_mapper``.""" - if col_def.json_type == "number": - return _coerce_float - if col_def.json_type == "integer": - return _coerce_int - if col_def.json_type == "string": - if col_def.format == "date-time": - return _coerce_datetime - if col_def.format == "date": - return _coerce_date - return _coerce_identity - - -def _to_column_info(columns: list[Any]) -> list[ColumnInfo]: - return [ColumnInfo(name=c.name, type=c.json_type, required=c.required) for c in columns] - - -class PostgresFieldDefinitionReader: - """Builds field name → FieldType maps from registered schemas.""" - - def __init__(self, session: AsyncSession) -> None: - self.session = session - - async def get_all_field_types(self) -> dict[str, FieldType]: - stmt = select(schemas_table.c.fields) - result = await self.session.execute(stmt) - rows = result.mappings().all() - - field_map: dict[str, FieldType] = {} - for row in rows: - for field_def in row["fields"]: - name = field_def["name"] - field_type = FieldType(field_def["type"]) - if name in field_map and field_map[name] != field_type: - raise ValidationError( - f"Conflicting types for field '{name}': " - f"'{field_map[name]}' vs '{field_type}'", - field=name, - ) - field_map[name] = field_type - - return field_map - - async def get_fields_for_schema(self, schema_id: SchemaId) -> dict[str, FieldType]: - 
stmt = select(schemas_table.c.fields).where( - schemas_table.c.id == schema_id.id.root, - schemas_table.c.version == schema_id.version.root, - ) - result = await self.session.execute(stmt) - row = result.mappings().first() - if row is None: - return {} - return {f["name"]: FieldType(f["type"]) for f in row["fields"]} - - -class PostgresDiscoveryReadStore: - """Compiles FilterExpr trees into SQLAlchemy queries over records / metadata / features.""" - - def __init__(self, session: AsyncSession) -> None: - self.session = session - - async def search_records( - self, - filter_expr: FilterExpr | None, - schema_id: SchemaId | None, - convention_srn: ConventionSRN | None, - text_fields: list[str], - q: str | None, - sort: str, - order: SortOrder, - cursor: dict[str, Any] | None, - limit: int, - field_types: dict[str, FieldType] | None = None, - ) -> list[RecordSummary]: - t = records_table - ft_map = field_types or {} - - metadata_table = None - metadata_schema: MetadataSchema | None = None - if schema_id is not None: - catalog = await self._metadata_catalog_for(schema_id) - if catalog is not None: - metadata_schema = MetadataSchema.model_validate(catalog["metadata_schema"]) - metadata_table = build_metadata_table(catalog["pg_table"], metadata_schema) - - feature_joins = await self._collect_feature_joins(filter_expr) - - conditions: list[Any] = [] - - if convention_srn is not None: - conditions.append(t.c.convention_srn == str(convention_srn)) - - if filter_expr is not None: - conditions.append( - self._compile_filter_for_records( - filter_expr, - records_t=t, - metadata_t=metadata_table, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - ) - - if q and text_fields and metadata_table is not None and metadata_schema is not None: - pattern = f"%{_escape_like(q)}%" - text_col_names = {c.name for c in metadata_schema.columns if c.json_type == "string"} - text_clauses = [ - cast(metadata_table.c[name], String).ilike(pattern, escape="\\") - for name in 
text_fields - if name in text_col_names - ] - if text_clauses: - conditions.append(or_(*text_clauses)) - - # Sort expression + matching cursor-value coercer - if sort == "published_at": - sort_expr = t.c.published_at - coerce_cursor: CursorCoercer = _coerce_datetime - elif metadata_table is not None and sort in metadata_table.c: - col = metadata_table.c[sort] - if ft_map.get(sort) == FieldType.NUMBER: - sort_expr = cast(col, Float) - coerce_cursor = _coerce_float - elif ft_map.get(sort) == FieldType.DATE: - sort_expr = cast(col, Date) - coerce_cursor = _coerce_date - else: - sort_expr = col - coerce_cursor = _coerce_identity - else: - sort_expr = t.c.published_at - coerce_cursor = _coerce_datetime - - is_desc = order == SortOrder.DESC - page = KeysetPage( - [ - SortKey(sort_expr, descending=is_desc, nulls_last=True), - SortKey(t.c.srn, descending=is_desc), - ] - ) - order_clauses = page.order_by() - if cursor is not None: - sort_value = coerce_cursor(cursor["s"]) - conditions.append(page.after((sort_value, cursor["id"]))) - - where_clause = and_(*conditions) if conditions else true() - - if metadata_table is not None and metadata_schema is not None: - select_cols = [t.c.srn, t.c.published_at] + [ - metadata_table.c[c.name].label(c.name) for c in metadata_schema.columns - ] - stmt = select(*select_cols).select_from( - t.join(metadata_table, metadata_table.c.record_srn == t.c.srn) - ) - else: - # No schema pinned — project the canonical JSONB metadata column. - # Typed tables are a query-optimized projection; JSONB remains the - # authoritative source for presentation (and for cross-schema - # listings where no single typed table applies). 
- stmt = select(t.c.srn, t.c.published_at, t.c.metadata) - - for hook, ft in feature_joins.items(): - stmt = stmt.join(ft, ft.c.record_srn == t.c.srn, isouter=True) - - stmt = stmt.where(where_clause).order_by(*order_clauses).limit(limit) - - result = await self.session.execute(stmt) - summaries: list[RecordSummary] = [] - if metadata_table is not None and metadata_schema is not None: - for row in result.mappings(): - meta = {c.name: row[c.name] for c in metadata_schema.columns if c.name in row} - summaries.append( - RecordSummary( - srn=RecordSRN.parse(row["srn"]), - published_at=row["published_at"], - metadata=meta, - ) - ) - else: - for row in result.mappings(): - summaries.append( - RecordSummary( - srn=RecordSRN.parse(row["srn"]), - published_at=row["published_at"], - metadata=row.get("metadata") or {}, - ) - ) - return summaries - - async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: - stmt = select( - feature_tables_table.c.hook_name, - feature_tables_table.c.pg_table, - feature_tables_table.c.feature_schema, - ) - result = await self.session.execute(stmt) - catalog_rows = result.mappings().all() - - if not catalog_rows: - return [] - - parsed = [ - (row["hook_name"], FeatureSchema.model_validate(row["feature_schema"])) - for row in catalog_rows - ] - - count_parts = [] - for hook_name, schema in parsed: - ft = build_feature_table(hook_name, schema) - count_parts.append( - select( - literal(hook_name).label("hook_name"), - func.count(func.distinct(ft.c.record_srn)).label("cnt"), - ).select_from(ft) - ) - counts_result = await self.session.execute(union_all(*count_parts)) - counts_by_hook = {r["hook_name"]: r["cnt"] for r in counts_result.mappings()} - - return [ - FeatureCatalogEntry( - hook_name=hook_name, - columns=_to_column_info(schema.columns), - record_count=counts_by_hook.get(hook_name, 0), - ) - for hook_name, schema in parsed - ] - - async def get_feature_table_schema(self, hook_name: str) -> FeatureCatalogEntry | None: - stmt = 
select( - feature_tables_table.c.hook_name, - feature_tables_table.c.feature_schema, - ).where(feature_tables_table.c.hook_name == hook_name) - result = await self.session.execute(stmt) - row = result.mappings().first() - if row is None: - return None - - schema = FeatureSchema.model_validate(row["feature_schema"]) - return FeatureCatalogEntry( - hook_name=row["hook_name"], - columns=_to_column_info(schema.columns), - record_count=0, - ) - - async def search_features( - self, - hook_name: str, - filter_expr: FilterExpr | None, - schema_id: SchemaId | None, - record_srn: RecordSRN | None, - sort: str, - order: SortOrder, - cursor: dict[str, Any] | None, - limit: int, - ) -> list[FeatureRow]: - pg_table_stmt = select( - feature_tables_table.c.feature_schema, - ).where(feature_tables_table.c.hook_name == hook_name) - pg_result = await self.session.execute(pg_table_stmt) - pg_row = pg_result.mappings().first() - if pg_row is None: - return [] - schema = FeatureSchema.model_validate(pg_row["feature_schema"]) - - ft = build_feature_table(hook_name, schema) - - metadata_table = None - metadata_schema: MetadataSchema | None = None - if schema_id is not None: - catalog = await self._metadata_catalog_for(schema_id) - if catalog is not None: - metadata_schema = MetadataSchema.model_validate(catalog["metadata_schema"]) - metadata_table = build_metadata_table(catalog["pg_table"], metadata_schema) - - feature_joins: dict[str, Any] = {} - if filter_expr is not None: - extra = await self._collect_feature_joins(filter_expr) - for hook, tbl in extra.items(): - if hook != hook_name: - feature_joins[hook] = tbl - - conditions: list[Any] = [] - - if record_srn is not None: - conditions.append(ft.c.record_srn == str(record_srn)) - - if filter_expr is not None: - conditions.append( - self._compile_filter_for_features( - filter_expr, - this_hook=hook_name, - this_ft=ft, - metadata_t=metadata_table, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - ) - - if sort == 
"id": - sort_expr = ft.c.id - coerce_cursor: CursorCoercer = _coerce_int - else: - sort_expr = ft.c[sort] - col_def = next((c for c in schema.columns if c.name == sort), None) - coerce_cursor = ( - _coercer_for_column(col_def) if col_def is not None else _coerce_identity - ) - - is_desc = order == SortOrder.DESC - page = KeysetPage( - [ - SortKey(sort_expr, descending=is_desc, nulls_last=True), - SortKey(ft.c.id, descending=is_desc), - ] - ) - order_clauses = page.order_by() - if cursor is not None: - sort_value = coerce_cursor(cursor["s"]) - conditions.append(page.after((sort_value, cursor["id"]))) - - where_clause = and_(*conditions) if conditions else true() - - stmt = select(ft.c.id, ft.c.record_srn, *data_columns(ft)) - select_from = ft - if metadata_table is not None: - select_from = select_from.join( - metadata_table, metadata_table.c.record_srn == ft.c.record_srn, isouter=True - ) - for hook, other_ft in feature_joins.items(): - select_from = select_from.join( - other_ft, other_ft.c.record_srn == ft.c.record_srn, isouter=True - ) - stmt = ( - stmt.select_from(select_from).where(where_clause).order_by(*order_clauses).limit(limit) - ) - - result = await self.session.execute(stmt) - feature_rows: list[FeatureRow] = [] - for row in result.mappings(): - row_dict = dict(row) - row_id = row_dict.pop("id") - rsrn = RecordSRN.parse(row_dict.pop("record_srn")) - feature_rows.append(FeatureRow(row_id=row_id, record_srn=rsrn, data=row_dict)) - - return feature_rows - - # ---------------- compilation helpers ---------------- - - async def _metadata_catalog_for(self, schema_id: SchemaId) -> dict[str, Any] | None: - """Look up the metadata table catalog row for a SchemaId.""" - stmt = select(metadata_tables_table).where( - metadata_tables_table.c.schema_id == schema_id.id.root, - metadata_tables_table.c.schema_major == schema_id.major, - ) - result = await self.session.execute(stmt) - row = result.mappings().first() - return dict(row) if row is not None else None - - 
async def _collect_feature_joins(self, filter_expr: FilterExpr | None) -> dict[str, Any]: - """Build {hook_name: SQLA Table} for every distinct feature ref in the tree.""" - if filter_expr is None: - return {} - hooks: set[str] = set() - for p in _iter_predicates(filter_expr): - if isinstance(p.field, FeatureFieldRef): - hooks.add(p.field.hook) - if not hooks: - return {} - stmt = select( - feature_tables_table.c.hook_name, - feature_tables_table.c.feature_schema, - ).where(feature_tables_table.c.hook_name.in_(hooks)) - result = await self.session.execute(stmt) - joins: dict[str, Any] = {} - for row in result.mappings(): - schema = FeatureSchema.model_validate(row["feature_schema"]) - joins[row["hook_name"]] = build_feature_table(row["hook_name"], schema) - missing = hooks - joins.keys() - if missing: - raise ValidationError( - f"Unknown feature hook(s): {sorted(missing)}", - field="filter", - code="unknown_hook", - ) - return joins - - def _compile_filter_for_records( - self, - expr: FilterExpr, - *, - records_t: Any, - metadata_t: Any, - metadata_schema: MetadataSchema | None, - feature_joins: dict[str, Any], - ) -> Any: - if isinstance(expr, Predicate): - return self._compile_predicate( - expr, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - if isinstance(expr, And): - return and_( - *[ - self._compile_filter_for_records( - op, - records_t=records_t, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - for op in expr.operands - ] - ) - if isinstance(expr, Or): - return or_( - *[ - self._compile_filter_for_records( - op, - records_t=records_t, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - for op in expr.operands - ] - ) - if isinstance(expr, Not): - inner = self._compile_filter_for_records( - expr.operand, - records_t=records_t, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) 
- # Coalesce NULL → FALSE before negating so records with NULL - # feature/metadata values (including rows missing from outer- - # joined feature tables) survive a NOT predicate. Without this, - # ``NOT (score = 5)`` reads NULL for records with no score and - # three-valued logic silently drops them. - return not_(func.coalesce(inner, false())) - raise ValidationError(f"Unsupported filter node: {type(expr).__name__}") - - def _compile_filter_for_features( - self, - expr: FilterExpr, - *, - this_hook: str, - this_ft: Any, - metadata_t: Any, - metadata_schema: MetadataSchema | None, - feature_joins: dict[str, Any], - ) -> Any: - if isinstance(expr, Predicate): - if isinstance(expr.field, MetadataFieldRef): - if metadata_t is None: - raise ValidationError( - f"Metadata ref {expr.field.dotted()!r} requires schema_id to be set.", - field=expr.field.dotted(), - code="metadata_ref_requires_schema", - ) - col = metadata_t.c[expr.field.field] - return _apply_scalar_op(col, expr.op, expr.value) - if not isinstance(expr.field, FeatureFieldRef): - raise TypeError(f"Unexpected field ref type: {type(expr.field).__name__}") - if expr.field.hook == this_hook: - col = this_ft.c[expr.field.column] - else: - tbl = feature_joins.get(expr.field.hook) - if tbl is None: - raise ValidationError( - f"Unknown feature hook '{expr.field.hook}'.", - field=expr.field.dotted(), - code="unknown_hook", - ) - col = tbl.c[expr.field.column] - return _apply_scalar_op(col, expr.op, expr.value) - if isinstance(expr, And): - return and_( - *[ - self._compile_filter_for_features( - op, - this_hook=this_hook, - this_ft=this_ft, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - for op in expr.operands - ] - ) - if isinstance(expr, Or): - return or_( - *[ - self._compile_filter_for_features( - op, - this_hook=this_hook, - this_ft=this_ft, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - for op in expr.operands - ] 
- ) - if isinstance(expr, Not): - inner = self._compile_filter_for_features( - expr.operand, - this_hook=this_hook, - this_ft=this_ft, - metadata_t=metadata_t, - metadata_schema=metadata_schema, - feature_joins=feature_joins, - ) - # See ``_compile_filter_for_records`` — NULL → FALSE coalesce so - # NOT over outer-joined feature / optional metadata columns - # includes records with missing values. - return not_(func.coalesce(inner, false())) - raise ValidationError(f"Unsupported filter node: {type(expr).__name__}") - - def _compile_predicate( - self, - predicate: Predicate, - *, - metadata_t: Any, - metadata_schema: MetadataSchema | None, - feature_joins: dict[str, Any], - ) -> Any: - if isinstance(predicate.field, MetadataFieldRef): - if metadata_t is None or metadata_schema is None: - raise ValidationError( - f"Metadata predicate on {predicate.field.dotted()!r} requires " - "the request to pin a 'schema' ('@'). " - "Unscoped metadata filtering is not supported — the typed table " - "is the only filter path.", - field=predicate.field.dotted(), - code="schema_required_for_metadata_query", - ) - col = metadata_t.c[predicate.field.field] - return _apply_scalar_op(col, predicate.op, predicate.value) - - if not isinstance(predicate.field, FeatureFieldRef): - raise TypeError(f"Unexpected field ref type: {type(predicate.field).__name__}") - tbl = feature_joins.get(predicate.field.hook) - if tbl is None: - raise ValidationError( - f"Unknown feature hook '{predicate.field.hook}'.", - field=predicate.field.dotted(), - code="unknown_hook", - ) - col = tbl.c[predicate.field.column] - return _apply_scalar_op(col, predicate.op, predicate.value) - - -def _apply_scalar_op(col: Any, op: FilterOperator, value: Any) -> Any: - if op == FilterOperator.EQ: - return col == value - if op == FilterOperator.NEQ: - # Feature tables are outer-joined, so a missing feature row makes - # ``col`` NULL. 
Plain ``col != value`` yields NULL (falsy) and - # silently excludes those records from the result. Users reading - # ``!= X`` expect "anything except X, including missing", so treat - # NULL as non-equal explicitly. - return or_(col != value, col.is_(None)) - if op == FilterOperator.GT: - return col > value - if op == FilterOperator.GTE: - return col >= value - if op == FilterOperator.LT: - return col < value - if op == FilterOperator.LTE: - return col <= value - if op == FilterOperator.IN: - if not isinstance(value, list): - raise ValidationError( - "Operator 'in' requires a list value.", - field=col.key, - code="invalid_value_for_op", - ) - return col.in_(value) - if op == FilterOperator.CONTAINS: - return cast(col, String).ilike(f"%{_escape_like(str(value))}%", escape="\\") - if op == FilterOperator.IS_NULL: - return col.is_(None) - raise ValidationError( - f"Unsupported operator: {op}", field="filter", code="unsupported_operator" - ) - - -def _iter_predicates(expr: FilterExpr): - if isinstance(expr, Predicate): - yield expr - return - if isinstance(expr, Not): - yield from _iter_predicates(expr.operand) - return - if isinstance(expr, (And, Or)): - for op in expr.operands: - yield from _iter_predicates(op) diff --git a/server/osa/infrastructure/persistence/di.py b/server/osa/infrastructure/persistence/di.py index 12873e2..9f5316b 100644 --- a/server/osa/infrastructure/persistence/di.py +++ b/server/osa/infrastructure/persistence/di.py @@ -27,12 +27,8 @@ from osa.domain.shared.port.event_repository import EventRepository from osa.domain.feature.port.feature_store import FeatureStore from osa.domain.validation.port.repository import ValidationRunRepository -from osa.domain.discovery.port.field_definition_reader import FieldDefinitionReader -from osa.domain.discovery.port.read_store import DiscoveryReadStore -from osa.infrastructure.persistence.adapter.discovery import ( - PostgresDiscoveryReadStore, - PostgresFieldDefinitionReader, -) +from 
osa.domain.data.port.data_read_store import DataReadStore +from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore from osa.infrastructure.persistence.adapter.readers import ( OntologyReaderAdapter, SchemaReaderAdapter, @@ -173,13 +169,10 @@ def get_record_service( feature_reader=feature_reader, ) - # Discovery adapters - discovery_read_store = provide( - PostgresDiscoveryReadStore, scope=Scope.UOW, provides=DiscoveryReadStore - ) - field_definition_reader = provide( - PostgresFieldDefinitionReader, scope=Scope.UOW, provides=FieldDefinitionReader - ) + # Data read surface adapter (unified /data/* engine) + @provide(scope=Scope.UOW, provides=DataReadStore) + def get_data_read_store(self, session: AsyncSession, config: Config) -> PostgresDataReadStore: + return PostgresDataReadStore(session=session, node_domain=Domain(config.domain)) # Record query handlers get_record_handler = provide(GetRecordHandler, scope=Scope.UOW) diff --git a/server/osa/sdk/index/__init__.py b/server/osa/sdk/index/__init__.py deleted file mode 100644 index 8c862ca..0000000 --- a/server/osa/sdk/index/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Index SDK - Protocols and types for pluggable storage backends.""" - -from osa.sdk.index.backend import StorageBackend -from osa.sdk.index.config import BackendConfig -from osa.sdk.index.result import QueryResult, SearchHit - -__all__ = [ - "StorageBackend", - "BackendConfig", - "QueryResult", - "SearchHit", -] diff --git a/server/osa/sdk/index/backend.py b/server/osa/sdk/index/backend.py deleted file mode 100644 index 5b8e029..0000000 --- a/server/osa/sdk/index/backend.py +++ /dev/null @@ -1,91 +0,0 @@ -"""StorageBackend protocol for pluggable index backends.""" - -from typing import Any, Protocol - -from osa.sdk.index.result import QueryResult - - -class StorageBackend(Protocol): - """Protocol for pluggable index storage backends. 
- - Implement this protocol to create custom storage backends - (e.g., vector search, keyword search, graph databases). - """ - - @property - def name(self) -> str: - """Unique name for this index instance.""" - ... - - async def ingest(self, srn: str, record: dict[str, Any]) -> None: - """Store a record in the index. - - Args: - srn: Structured Resource Name identifying the record. - record: The record metadata to index. - """ - ... - - async def delete(self, srn: str) -> None: - """Remove a record from the index. - - Args: - srn: Structured Resource Name of the record to remove. - """ - ... - - async def query(self, q: str, limit: int = 20) -> QueryResult: - """Execute a query and return structured results. - - Args: - q: The query string. - limit: Maximum number of results to return. - - Returns: - QueryResult containing matching hits. - """ - ... - - async def health(self) -> bool: - """Check if the backend is operational. - - Returns: - True if the backend is healthy, False otherwise. - """ - ... - - async def count(self) -> int: - """Return the number of documents in the index. - - Returns: - Number of indexed documents. - """ - ... - - async def flush(self) -> None: - """Flush any buffered operations to storage. - - For backends that buffer writes for batch efficiency (e.g., batch embedding - generation), this method ensures all pending records are persisted. - - Backends without internal buffering can implement this as a no-op. - - Deprecated: With event-level batching, backends should be stateless. - This method will be removed in a future version. - """ - ... - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - """Index a batch of records atomically. - - All records in the batch must succeed or all must fail together. - This method should be optimized for batch operations (e.g., batch - embedding generation, bulk database inserts). 
- - Args: - records: List of (srn, metadata) tuples to index - - Raises: - Exception: If any record fails to index (none are committed) - """ - ... diff --git a/server/osa/sdk/index/config.py b/server/osa/sdk/index/config.py deleted file mode 100644 index 8a96628..0000000 --- a/server/osa/sdk/index/config.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Base configuration for storage backends.""" - -from pydantic import BaseModel - - -class BackendConfig(BaseModel): - """Base configuration for storage backends. - - Extend this class for vendor-specific configuration. - """ - - pass diff --git a/server/osa/sdk/index/result.py b/server/osa/sdk/index/result.py deleted file mode 100644 index d1bb063..0000000 --- a/server/osa/sdk/index/result.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Type-safe result types for index queries.""" - -from typing import Any - -from pydantic import BaseModel - - -class SearchHit(BaseModel, frozen=True): - """A single search result from an index query.""" - - srn: str - score: float - metadata: dict[str, Any] - - -class QueryResult(BaseModel, frozen=True): - """Result of an index query.""" - - hits: list[SearchHit] - total: int - query: str diff --git a/server/pyproject.toml b/server/pyproject.toml index c9b69e2..c63ca30 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "pyjwt>=2.11.0", "openpyxl>=3.1.5", "python-multipart>=0.0.22", + "slowapi>=0.1.9", ] [project.entry-points."osa.sources"] diff --git a/server/tests/contract/test_worker_lifecycle.py b/server/tests/contract/test_worker_lifecycle.py deleted file mode 100644 index 9032429..0000000 --- a/server/tests/contract/test_worker_lifecycle.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Contract tests for WorkerPool lifecycle in FastAPI lifespan. - -Tests for Phase 7: Migration. 
-""" - -import asyncio -from datetime import UTC, datetime -from typing import Any -from unittest.mock import AsyncMock, MagicMock -from uuid import uuid4 - -import pytest - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.handler.keyword_index_handler import KeywordIndexHandler -from osa.domain.index.handler.vector_index_handler import VectorIndexHandler -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.event import ClaimResult, EventId -from osa.domain.shared.model.srn import Domain, LocalId, RecordSRN, RecordVersion -from osa.infrastructure.event.worker import WorkerPool - - -class FakeBackend: - """Fake storage backend for testing.""" - - def __init__(self, name: str): - self._name = name - self.ingested: list[tuple[str, dict]] = [] - - @property - def name(self) -> str: - return self._name - - async def ingest_batch(self, records: list[tuple[str, dict]]) -> None: - self.ingested.extend(records) - - -def make_mock_container(handler_type: type, handler_instance: Any): - """Create a mock DI container that provides scoped dependencies. - - Creates a container mock that returns scoped Outbox and AsyncSession - when called as an async context manager with scope parameter. 
- """ - from osa.domain.shared.outbox import Outbox - - outbox = AsyncMock() - outbox.claim.return_value = ClaimResult(deliveries=[], claimed_at=datetime.now(UTC)) - outbox.reset_stale_claims.return_value = 0 - - session = AsyncMock() - session.commit = AsyncMock() - session.rollback = AsyncMock() - - async def get_dependency(cls: type) -> Any: - """Return the appropriate dependency based on the requested class.""" - if cls == Outbox: - return outbox - if cls == handler_type: - return handler_instance - return session - - # Create scope that returns dependencies - scope = AsyncMock() - scope.get = AsyncMock(side_effect=get_dependency) - - # Create async context manager - context = MagicMock() - context.__aenter__ = AsyncMock(return_value=scope) - context.__aexit__ = AsyncMock(return_value=None) - - # Container callable returns the context manager - container = MagicMock() - container.return_value = context - - return container, outbox, session - - -class TestWorkerPoolLifecycle: - """Tests for WorkerPool lifecycle management.""" - - @pytest.mark.asyncio - async def test_worker_pool_starts_and_stops_cleanly(self): - """WorkerPool should start and stop without errors.""" - vector_backend = FakeBackend("vector") - keyword_backend = FakeBackend("keyword") - registry = IndexRegistry({"vector": vector_backend, "keyword": keyword_backend}) - - # Create handler instance - vector_handler = VectorIndexHandler(indexes=registry) - - container, outbox, session = make_mock_container(VectorIndexHandler, vector_handler) - - pool = WorkerPool(container=container, stale_claim_interval=60.0) - pool.register(VectorIndexHandler) - pool.register(KeywordIndexHandler) - - # Start - await pool.start() - await asyncio.sleep(0.05) - - # Verify workers are running - assert len(pool.workers) == 2 - for worker in pool.workers: - assert worker._task is not None - assert not worker._task.done() - - # Stop - await pool.stop() - - # Verify workers are stopped - for worker in pool.workers: - assert 
worker._shutdown is True - - @pytest.mark.asyncio - async def test_worker_pool_as_context_manager(self): - """WorkerPool should work as async context manager.""" - vector_backend = FakeBackend("vector") - registry = IndexRegistry({"vector": vector_backend}) - vector_handler = VectorIndexHandler(indexes=registry) - - container, outbox, session = make_mock_container(VectorIndexHandler, vector_handler) - - pool = WorkerPool(container=container, stale_claim_interval=60.0) - pool.register(VectorIndexHandler) - - async with pool: - # Workers should be running - assert pool.workers[0]._task is not None - await asyncio.sleep(0.02) - - # After exit, workers should be stopped - assert pool.workers[0]._shutdown is True - - -class TestIndexHandlers: - """Tests for concrete index handlers.""" - - @pytest.mark.asyncio - async def test_vector_handler_processes_batch(self): - """VectorIndexHandler should process IndexRecord events in batches.""" - backend = FakeBackend("vector") - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Create test events - events = [ - IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(f"rec-{i}"), - version=RecordVersion(1), - ), - metadata={"title": f"Record {i}"}, - ) - for i in range(5) - ] - - # Process - await handler.handle_batch(events) - - # Verify backend received all records - assert len(backend.ingested) == 5 - - @pytest.mark.asyncio - async def test_keyword_handler_processes_individually(self): - """KeywordIndexHandler should process IndexRecord events one at a time.""" - backend = FakeBackend("keyword") - registry = IndexRegistry({"keyword": backend}) - handler = KeywordIndexHandler(indexes=registry) - - # batch_size should be 1 - assert KeywordIndexHandler.__batch_size__ == 1 - - # Create test event - event = IndexRecord( - id=EventId(uuid4()), - backend_name="keyword", - record_srn=RecordSRN( - 
domain=Domain("test.example.com"), - id=LocalId("rec-1"), - version=RecordVersion(1), - ), - metadata={"title": "Record 1"}, - ) - - # Process - await handler.handle(event) - - # Verify - assert len(backend.ingested) == 1 - - @pytest.mark.asyncio - async def test_handler_raises_on_backend_failure(self): - """Handlers should raise when backend fails (Worker handles retry).""" - # Backend that fails - failing_backend = AsyncMock() - failing_backend.name = "vector" - failing_backend.ingest_batch = AsyncMock(side_effect=Exception("Backend error")) - - registry = IndexRegistry({"vector": failing_backend}) - handler = VectorIndexHandler(indexes=registry) - - event = IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId("rec-1"), - version=RecordVersion(1), - ), - metadata={"title": "Record 1"}, - ) - - # Process - should raise (Worker handles retry) - with pytest.raises(Exception, match="Backend error"): - await handler.handle_batch([event]) diff --git a/server/tests/integration/persistence/test_discovery_pagination.py b/server/tests/integration/persistence/test_discovery_pagination.py deleted file mode 100644 index 49b2fb8..0000000 --- a/server/tests/integration/persistence/test_discovery_pagination.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Integration tests for discovery keyset pagination against real Postgres. - -Regression coverage for a production bug where paginating past page 1 raised -``operator does not exist: timestamp with time zone < character varying`` -because the cursor's sort value round-tripped through JSON as a plain string -and was bound as ``VARCHAR`` against the typed ``records.published_at`` column. 
-""" - -from __future__ import annotations - -from datetime import UTC, datetime - -import pytest -from sqlalchemy.ext.asyncio import AsyncSession - -from osa.domain.discovery.model.value import SortOrder, decode_cursor, encode_cursor -from osa.infrastructure.persistence.adapter.discovery import PostgresDiscoveryReadStore -from osa.infrastructure.persistence.tables import records_table - - -async def _insert_record(session: AsyncSession, srn: str, published_at: datetime) -> None: - await session.execute( - records_table.insert().values( - srn=srn, - convention_srn="urn:osa:localhost:conv:test@1.0.0", - schema_id="test", - schema_version="1.0.0", - source={"type": "test", "id": srn}, - metadata={}, - published_at=published_at, - ) - ) - await session.commit() - - -@pytest.mark.asyncio -class TestDiscoveryPaginationPublishedAt: - async def test_second_page_with_published_at_cursor(self, pg_session: AsyncSession) -> None: - """Fetching page 2 with a cursor must not trip the timestamptz/varchar - mismatch — the bug manifested only on requests that supplied a cursor.""" - store = PostgresDiscoveryReadStore(pg_session) - base = datetime(2026, 4, 7, 9, 0, 0, tzinfo=UTC) - records = [(f"urn:osa:localhost:rec:page-{i}@1", base.replace(second=i)) for i in range(3)] - for srn, ts in records: - await _insert_record(pg_session, srn, ts) - - first_page = await store.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=2, - ) - assert len(first_page) == 2 - - # Encode + decode the cursor the same way the service does — this is the - # round-trip that previously produced a VARCHAR bind. 
- last = first_page[-1] - cursor_str = encode_cursor(last.published_at.isoformat(), str(last.srn)) - decoded = decode_cursor(cursor_str) - - second_page = await store.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=decoded, - limit=2, - ) - - assert len(second_page) == 1 - returned = {str(r.srn) for r in first_page} | {str(r.srn) for r in second_page} - assert returned == {srn for srn, _ in records} diff --git a/server/osa/domain/export/event/__init__.py b/server/tests/integration/semantics/__init__.py similarity index 100% rename from server/osa/domain/export/event/__init__.py rename to server/tests/integration/semantics/__init__.py diff --git a/server/tests/integration/semantics/test_reserved_name_registration.py b/server/tests/integration/semantics/test_reserved_name_registration.py new file mode 100644 index 0000000..d1d307e --- /dev/null +++ b/server/tests/integration/semantics/test_reserved_name_registration.py @@ -0,0 +1,46 @@ +"""Integration test for reserved-name enforcement at schema registration (T066). + +Exercises the real schema-registration service path against Postgres: a schema +whose id is a reserved ``/data/`` path slot (``records`` / ``datasets``) is +rejected with ``ReservedNameError`` (``code="reserved_name"``), which the API +layer maps to HTTP 400. The aggregate invariant is the single source of truth; +this confirms it fires through the persistence-wired service, not just the unit. + +Skips automatically unless OSA_DATABASE__URL points at PostgreSQL. 
+""" + +import pytest +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.semantics.service.schema import SchemaService +from osa.domain.shared.error import ReservedNameError +from osa.domain.shared.model.srn import Domain, SchemaIdentifier +from osa.infrastructure.persistence.repository.ontology import PostgresOntologyRepository +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) + + +def _service(session: AsyncSession) -> SchemaService: + return SchemaService( + schema_repo=PostgresSemanticsSchemaRepository(session), + ontology_repo=PostgresOntologyRepository(session), + node_domain=Domain("localhost"), + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("reserved", ["records", "datasets"]) +async def test_registering_reserved_schema_id_rejected( + pg_engine: AsyncEngine, pg_session: AsyncSession, reserved: str +): + service = _service(pg_session) + with pytest.raises(ReservedNameError) as exc: + await service.create_schema( + id=SchemaIdentifier(reserved), + title="should be rejected", + version="1.0.0", + fields=[], + ) + assert exc.value.code == "reserved_name" + assert exc.value.name == reserved diff --git a/server/tests/integration/test_crash_recovery.py b/server/tests/integration/test_crash_recovery.py deleted file mode 100644 index 18ab740..0000000 --- a/server/tests/integration/test_crash_recovery.py +++ /dev/null @@ -1,211 +0,0 @@ -"""Integration tests for crash recovery in event processing. - -Tests that events are not lost when the worker is interrupted mid-processing. 
-""" - -from typing import Any -from uuid import uuid4 - -import pytest - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.handler.vector_index_handler import VectorIndexHandler -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.event import EventId -from osa.domain.shared.model.srn import Domain, LocalId, RecordSRN, RecordVersion - - -class TrackingBackend: - """Backend that tracks all operations for verification.""" - - def __init__(self, name: str, fail_at: int | None = None): - self._name = name - self._fail_at = fail_at - self._call_count = 0 - self.indexed_records: list[tuple[str, dict]] = [] - - @property - def name(self) -> str: - return self._name - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - """Track records and optionally fail at specific call.""" - self._call_count += 1 - if self._fail_at is not None and self._call_count >= self._fail_at: - raise Exception("Simulated crash") - self.indexed_records.extend(records) - - -def make_index_record(backend_name: str, record_id: str) -> IndexRecord: - """Create an IndexRecord for testing.""" - return IndexRecord( - id=EventId(uuid4()), - backend_name=backend_name, - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(record_id), - version=RecordVersion(1), - ), - metadata={"id": record_id, "title": f"Record {record_id}"}, - ) - - -class FakeOutbox: - """Simulates outbox with pending events.""" - - def __init__(self, events: list[IndexRecord]): - self.pending = list(events) - self.delivered: list[EventId] = [] - self.failed: dict[EventId, str] = {} - - async def fetch_pending(self, limit: int) -> list[IndexRecord]: - """Return pending events.""" - batch = self.pending[:limit] - return batch - - async def mark_delivered(self, event_id: EventId) -> None: - """Mark event as delivered.""" - self.delivered.append(event_id) - self.pending = [e for e in self.pending if e.id != event_id] - - 
async def mark_failed(self, event_id: EventId, error: str) -> None: - """Mark event as failed.""" - self.failed[event_id] = error - - -class TestCrashRecoveryScenarios: - """Tests for crash recovery scenarios.""" - - @pytest.mark.asyncio - async def test_events_remain_pending_on_crash(self): - """Events should remain pending if processing crashes before commit.""" - # Arrange - events = [make_index_record("vector", f"record-{i}") for i in range(10)] - outbox = FakeOutbox(events) - - # Simulate a backend that fails mid-batch - backend = TrackingBackend("vector", fail_at=1) # Fail on first call - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Act - Try to process, expecting failure - with pytest.raises(Exception, match="Simulated crash"): - await handler.handle_batch(events) - - # Assert - Events should remain pending (outbox unchanged) - assert len(outbox.pending) == 10 # All still pending - assert len(outbox.delivered) == 0 # None delivered - assert len(backend.indexed_records) == 0 # None indexed - - @pytest.mark.asyncio - async def test_recovery_processes_all_pending_events(self): - """After recovery, all pending events should be processed.""" - # Arrange - Create events and simulate they were fetched but not committed - events = [make_index_record("vector", f"record-{i}") for i in range(10)] - outbox = FakeOutbox(events) - - # First attempt: backend fails - failing_backend = TrackingBackend("vector", fail_at=1) - failing_registry = IndexRegistry({"vector": failing_backend}) - failing_handler = VectorIndexHandler(indexes=failing_registry) - - with pytest.raises(Exception): - await failing_handler.handle_batch(events) - - # Events still pending - assert len(outbox.pending) == 10 - - # Second attempt (recovery): backend works - working_backend = TrackingBackend("vector") - working_registry = IndexRegistry({"vector": working_backend}) - working_handler = VectorIndexHandler(indexes=working_registry) - - # Act - 
Retry processing - await working_handler.handle_batch(outbox.pending) - - # Assert - All events processed - assert len(working_backend.indexed_records) == 10 - - @pytest.mark.asyncio - async def test_partial_batch_failure_is_atomic(self): - """If batch fails partway, none of the events should be committed.""" - # Arrange - events = [make_index_record("vector", f"record-{i}") for i in range(5)] - - # Backend that succeeds on first record but throws exception - class PartialFailBackend: - def __init__(self): - self.name = "vector" - self.processed: list[tuple[str, dict]] = [] - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - # Process some, then fail - for i, (srn, meta) in enumerate(records): - if i >= 2: - raise Exception("Partial failure at record 2") - self.processed.append((srn, meta)) - - backend = PartialFailBackend() - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Act - with pytest.raises(Exception, match="Partial failure"): - await handler.handle_batch(events) - - # Assert - Some records were processed but batch should be atomic - # In production, the outbox marks all events as failed together - # The handler correctly propagates the error for retry - assert len(backend.processed) == 2 # Backend saw 2 before crash - - @pytest.mark.asyncio - async def test_idempotent_reprocessing(self): - """Reprocessing the same events should be safe (idempotent).""" - # Arrange - events = [make_index_record("vector", f"record-{i}") for i in range(5)] - - backend = TrackingBackend("vector") - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Act - Process twice - await handler.handle_batch(events) - await handler.handle_batch(events) # Reprocess same events - - # Assert - Records should be in backend (upsert semantics handle duplicates) - # The backend receives all records from both batches - assert len(backend.indexed_records) == 10 
# 5 + 5 - - # In production, ChromaDB's upsert handles this - same ID overwrites - # The important thing is no data corruption - - @pytest.mark.asyncio - async def test_crash_safe_no_in_memory_buffer(self): - """Verify there's no in-memory buffer that could lose data.""" - # Arrange - events = [make_index_record("vector", f"record-{i}") for i in range(3)] - - class StatelessBackend: - """Backend with no internal state.""" - - def __init__(self): - self.name = "vector" - self.batch_calls: list[list[tuple[str, dict]]] = [] - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - # Immediately persist (no buffering) - self.batch_calls.append(list(records)) - - backend = StatelessBackend() - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Act - await handler.handle_batch(events) - - # Assert - All records in single batch call, immediately persisted - assert len(backend.batch_calls) == 1 - assert len(backend.batch_calls[0]) == 3 - - # No flush needed - data is persisted immediately - # If we crashed right after ingest_batch, all 3 records are safe diff --git a/server/tests/integration/test_data_features_postgres.py b/server/tests/integration/test_data_features_postgres.py new file mode 100644 index 0000000..ffae4e3 --- /dev/null +++ b/server/tests/integration/test_data_features_postgres.py @@ -0,0 +1,219 @@ +"""Integration tests for PostgresDataReadStore feature-table streaming (US5). + +Exercises the FEATURE branch of the unified ``/data/`` engine against a real +Postgres: a schema with a registered hook produces a ``features.`` table; +the engine streams it, filters on its columns, and surfaces it in the catalog +and manifest scoped to the owning schema. + +Skips automatically unless OSA_DATABASE__URL points at PostgreSQL. 
+""" + +from datetime import UTC, datetime + +import pytest +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.data.model.filter import FilterOperator, FeatureFieldRef, Predicate +from osa.domain.data.model.query_plan import ( + PaginationParams, + QueryPlan, + TableKind, +) +from osa.domain.data.model.query_plan import TableKind as TK +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.model.hook import ColumnDef +from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId +from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore +from osa.infrastructure.persistence.feature_store import PostgresFeatureStore +from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) +from osa.infrastructure.persistence.tables import conventions_table + +from tests.integration.conftest import seed_record + +SCHEMA = SchemaId.parse("compound@1.0.0") +HOOK = "chem_features" + + +def _fields() -> list[FieldDefinition]: + return [ + FieldDefinition( + name="species", + type=FieldType.TEXT, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + + +def _feature_columns() -> list[ColumnDef]: + return [ + ColumnDef(name="score", json_type="number", required=True), + ColumnDef(name="label", json_type="string", required=False), + ] + + +async def _setup_schema(engine: AsyncEngine, session: AsyncSession) -> PostgresMetadataStore: + store = PostgresMetadataStore(engine, session) + await store.ensure_table(SCHEMA, _fields()) + await PostgresSemanticsSchemaRepository(session).save( + Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + ) + return store + + +async def _register_hook(engine: AsyncEngine, 
session: AsyncSession, hook_name: str = HOOK) -> None: + """Link the schema → hook via a convention row, then create its feature table. + + The read store only reads ``hooks[*].name`` from the convention, so a + name-only hooks payload is sufficient to scope the feature to the schema. + """ + await session.execute( + conventions_table.insert().values( + srn=f"urn:osa:localhost:conv:{hook_name}@1.0.0", + title="compound conv", + description=None, + schema_id=SCHEMA.id.root, + schema_version=SCHEMA.version.root, + file_requirements={}, + hooks=[{"name": hook_name}], + source=None, + created_at=datetime.now(UTC), + ) + ) + await session.commit() + feature_store = PostgresFeatureStore(engine, session) + await feature_store.create_table(hook_name, _feature_columns()) + + +async def _publish(engine: AsyncEngine, store: PostgresMetadataStore, rid: str) -> RecordSRN: + srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") + await seed_record( + engine, + srn=str(srn), + schema_id=SCHEMA.id.root, + schema_version=SCHEMA.version.root, + metadata={"species": "Homo sapiens"}, + published_at=datetime(2026, 1, 1, tzinfo=UTC), + ) + await store.insert(SCHEMA, srn, {"species": "Homo sapiens"}) + return srn + + +def _feature_plan(filter_expr=None, limit=50) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name=HOOK, + filter=filter_expr, + pagination=PaginationParams(limit=limit), + ) + + +async def _drain(store: PostgresDataReadStore, plan: QueryPlan) -> list[dict]: + return [dict(row) async for row in store.stream_rows(plan)] + + +@pytest.mark.asyncio +class TestStreamFeatures: + async def test_streams_feature_rows_with_data_columns( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + srn = await _publish(pg_engine, store, "rec1") + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + 
await feature_store.insert_features( + HOOK, str(srn), [{"score": 0.9, "label": "high"}, {"score": 0.1, "label": "low"}] + ) + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rows = await _drain(rs, _feature_plan()) + + assert len(rows) == 2 + # Default FEATURE sort is id asc → first inserted row first. + first = rows[0] + assert first["record_srn"] == str(srn) + assert first["score"] == 0.9 + assert first["label"] == "high" + assert "id" in first and "created_at" in first + + async def test_feature_filter_narrows_results( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + srn = await _publish(pg_engine, store, "rec1") + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + await feature_store.insert_features( + HOOK, str(srn), [{"score": 0.9, "label": "high"}, {"score": 0.1, "label": "low"}] + ) + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + plan = _feature_plan( + filter_expr=Predicate( + field=FeatureFieldRef(hook=HOOK, column="score"), + op=FilterOperator.GTE, + value=0.5, + ) + ) + rows = await _drain(rs, plan) + assert [r["label"] for r in rows] == ["high"] + + async def test_unknown_feature_raises_not_found( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name="nonexistent", + ) + with pytest.raises(NotFoundError): + await _drain(rs, plan) + + +@pytest.mark.asyncio +class TestFeatureCatalogAndManifest: + async def test_manifest_includes_feature_resource( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + srn = await _publish(pg_engine, store, "rec1") + await pg_session.commit() + await 
_register_hook(pg_engine, pg_session) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + await feature_store.insert_features(HOOK, str(srn), [{"score": 0.9, "label": "high"}]) + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + manifest = await rs.get_schema_manifest(SCHEMA) + assert manifest is not None + feature_res = next(t for t in manifest.table_resources if t.name == HOOK) + assert feature_res.kind == TK.FEATURE + assert feature_res.row_count == 1 + col_names = [c.name for c in feature_res.columns] + assert col_names[:3] == ["id", "record_srn", "created_at"] + assert "score" in col_names and "label" in col_names + assert feature_res.formats == ["", "csv", "csv.gz"] + + async def test_catalog_lists_feature_resource( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + catalog = await rs.get_node_catalog() + entry = next(e for e in catalog.schemas if e.id == "compound") + names = {(t.name, t.kind) for t in entry.table_resources} + assert ("records", TK.RECORDS) in names + assert (HOOK, TK.FEATURE) in names diff --git a/server/tests/integration/test_data_read_store_postgres.py b/server/tests/integration/test_data_read_store_postgres.py new file mode 100644 index 0000000..3882eaa --- /dev/null +++ b/server/tests/integration/test_data_read_store_postgres.py @@ -0,0 +1,265 @@ +"""Integration tests for PostgresDataReadStore against a real Postgres. + +Exercises the unified /data/ engine end-to-end: records streaming via the +server-side cursor, filtered streaming, single-record-by-id, node catalog, and +schema manifest. Mirrors the discovery integration tests' seeding pattern +(schema row + dynamic metadata table + seeded records). + +Skips automatically unless OSA_DATABASE__URL points at PostgreSQL. 
+""" + +from datetime import UTC, datetime + +import pytest +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.data.model.filter import FilterOperator, MetadataFieldRef, Predicate +from osa.domain.data.model.query_plan import ( + PaginationParams, + QueryPlan, + SortDirection, + SortSpec, + TableKind, +) +from osa.domain.data.model.query_plan import TableKind as TK +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType +from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId +from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore +from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) + +from tests.integration.conftest import seed_record + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def _fields() -> list[FieldDefinition]: + return [ + FieldDefinition( + name="species", + type=FieldType.TEXT, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + FieldDefinition( + name="mw", + type=FieldType.NUMBER, + required=False, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + + +async def _setup_schema(engine: AsyncEngine, session: AsyncSession) -> PostgresMetadataStore: + store = PostgresMetadataStore(engine, session) + await store.ensure_table(SCHEMA, _fields()) + repo = PostgresSemanticsSchemaRepository(session) + await repo.save( + Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + ) + return store + + +async def _publish( + engine: AsyncEngine, + store: PostgresMetadataStore, + rid: str, + species: str, + mw: float, + published_at: datetime, +) -> None: + srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") + await seed_record( + engine, + srn=str(srn), + schema_id=SCHEMA.id.root, + 
schema_version=SCHEMA.version.root, + metadata={"species": species, "mw": mw}, + published_at=published_at, + ) + await store.insert(SCHEMA, srn, {"species": species, "mw": mw}) + + +def _records_plan(filter_expr=None, limit=50, cursor=None) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + filter=filter_expr, + pagination=PaginationParams(limit=limit, cursor=cursor), + ) + + +async def _drain(store: PostgresDataReadStore, plan: QueryPlan) -> list[dict]: + return [dict(row) async for row in store.stream_rows(plan)] + + +@pytest.mark.asyncio +class TestStreamRecords: + async def test_streams_all_rows_flattened( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + await _publish( + pg_engine, store, "rec1", "Homo sapiens", 1.5, datetime(2026, 1, 3, tzinfo=UTC) + ) + await _publish( + pg_engine, store, "rec2", "Mus musculus", 2.0, datetime(2026, 1, 1, tzinfo=UTC) + ) + await pg_session.commit() + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rows = await _drain(rs, _records_plan()) + + assert len(rows) == 2 + # Default RECORDS sort is created_at desc → r1 first. 
+ first = rows[0] + assert first["id"] == "rec1" + assert first["srn"] == "urn:osa:localhost:rec:rec1@1" + assert first["schema_id"] == "compound@1.0.0" + assert first["version"] == 1 + assert first["species"] == "Homo sapiens" + assert first["mw"] == 1.5 + assert "created_at" in first + + async def test_metadata_filter_narrows_results( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + await _publish( + pg_engine, store, "rec1", "Homo sapiens", 1.5, datetime(2026, 1, 3, tzinfo=UTC) + ) + await _publish( + pg_engine, store, "rec2", "Mus musculus", 2.0, datetime(2026, 1, 1, tzinfo=UTC) + ) + await pg_session.commit() + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + plan = _records_plan( + filter_expr=Predicate( + field=MetadataFieldRef(field="species"), + op=FilterOperator.EQ, + value="Homo sapiens", + ) + ) + rows = await _drain(rs, plan) + assert [r["id"] for r in rows] == ["rec1"] + + async def test_empty_result_streams_nothing( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + assert await _drain(rs, _records_plan()) == [] + + +@pytest.mark.asyncio +class TestGetRecordById: + async def test_resolves_by_bare_id_with_srn( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + await _publish( + pg_engine, store, "abc", "Homo sapiens", 1.5, datetime(2026, 1, 1, tzinfo=UTC) + ) + await pg_session.commit() + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rec = await rs.get_record_by_id(RecordId("abc"), None) + assert rec is not None + assert str(rec.id) == "abc" + assert str(rec.srn) == "urn:osa:localhost:rec:abc@1" + assert rec.schema_id.render() == "compound@1.0.0" + assert rec.version == 1 + assert rec.metadata["species"] == "Homo sapiens" + + async def 
test_unknown_id_returns_none(self, pg_engine: AsyncEngine, pg_session: AsyncSession): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + assert await rs.get_record_by_id(RecordId("missing"), None) is None + + +@pytest.mark.asyncio +class TestCatalogAndManifest: + async def test_node_catalog_lists_schema( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + catalog = await rs.get_node_catalog() + assert catalog.node_domain == "localhost" + ids = {(e.id, e.version) for e in catalog.schemas} + assert ("compound", "1.0.0") in ids + + async def test_manifest_shape(self, pg_engine: AsyncEngine, pg_session: AsyncSession): + store = await _setup_schema(pg_engine, pg_session) + await _publish( + pg_engine, store, "rec1", "Homo sapiens", 1.5, datetime(2026, 1, 1, tzinfo=UTC) + ) + await pg_session.commit() + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + manifest = await rs.get_schema_manifest(SCHEMA) + assert manifest is not None + assert manifest.id == "compound" + assert manifest.version == "1.0.0" + assert {f.name for f in manifest.fields} == {"species", "mw"} + records_res = next(t for t in manifest.table_resources if t.name == "records") + assert records_res.kind == TK.RECORDS + assert records_res.row_count == 1 + col_names = [c.name for c in records_res.columns] + # implicit columns precede the declared fields + assert col_names[:5] == ["id", "srn", "schema_id", "version", "created_at"] + assert "species" in col_names and "mw" in col_names + assert records_res.formats == ["", "csv", "csv.gz"] + + async def test_get_latest_schema_id(self, pg_engine: AsyncEngine, pg_session: AsyncSession): + await _setup_schema(pg_engine, pg_session) + await pg_session.commit() + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + 
resolved = await rs.get_latest_schema_id("compound") + assert resolved is not None + assert resolved.render() == "compound@1.0.0" + assert await rs.get_latest_schema_id("nonexistent") is None + + +@pytest.mark.asyncio +class TestStreamPaginationOrder: + async def test_cursor_advances_without_overlap( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + for i in range(5): + await _publish( + pg_engine, + store, + f"rec{i}", + "Homo sapiens", + float(i), + datetime(2026, 1, 1 + i, tzinfo=UTC), + ) + await pg_session.commit() + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + + # Sort by created_at desc (default). Page 1: take 2, derive cursor from row 2. + from osa.domain.data.model.query_plan import PaginationCursor, encode_cursor + + all_rows = await _drain(rs, _records_plan()) + assert [r["id"] for r in all_rows] == ["rec4", "rec3", "rec2", "rec1", "rec0"] + + # Cursor after the 2nd row (r3) → next page should start at r2. + cursor = encode_cursor(all_rows[1]["created_at"], all_rows[1]["srn"]) + plan2 = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + pagination=PaginationParams(limit=50, cursor=PaginationCursor(value=cursor)), + sort=[SortSpec(column="created_at", direction=SortDirection.DESC)], + ) + page2 = await _drain(rs, plan2) + assert [r["id"] for r in page2] == ["rec2", "rec1", "rec0"] diff --git a/server/tests/integration/test_data_routes_e2e_postgres.py b/server/tests/integration/test_data_routes_e2e_postgres.py new file mode 100644 index 0000000..d2e14b3 --- /dev/null +++ b/server/tests/integration/test_data_routes_e2e_postgres.py @@ -0,0 +1,527 @@ +"""End-to-end HTTP tests for the /data/ routes against a real Postgres. + +Drives the full stack — FastAPI routing, DishkaRoute DI, the streaming +response, the serializers, and the SQL — through an ASGI client. 
Critically +validates that the request-scoped DB session stays alive through +``StreamingResponse`` iteration (the DishkaRoute + streaming interaction). + +Lifespan is intentionally NOT run (no worker pool needed): ``create_app`` +attaches the Dishka container to app state, so DI resolves without it. + +Skips unless OSA_DATABASE__URL points at PostgreSQL. +""" + +import gzip +import os +from datetime import UTC, datetime + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType +from osa.domain.shared.model.hook import ColumnDef +from osa.domain.shared.model.srn import RecordSRN, SchemaId +from osa.infrastructure.persistence.feature_store import PostgresFeatureStore +from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) +from osa.infrastructure.persistence.tables import conventions_table + +from tests.integration.conftest import seed_record + +# create_app() reads Config() at import/call time; localhost domain needs a base URL. 
+os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") +os.environ.setdefault("OSA_AUTH__JWT__SECRET", "test-secret-for-integration-tests-minimum-32-chars") + +if "postgresql" not in os.environ.get("OSA_DATABASE__URL", ""): + pytest.skip("OSA_DATABASE__URL not set to PostgreSQL", allow_module_level=True) + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def _fields() -> list[FieldDefinition]: + return [ + FieldDefinition( + name="species", + type=FieldType.TEXT, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + FieldDefinition( + name="mw", + type=FieldType.NUMBER, + required=False, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + + +async def _seed(engine: AsyncEngine, session: AsyncSession, n: int) -> None: + store = PostgresMetadataStore(engine, session) + await store.ensure_table(SCHEMA, _fields()) + await PostgresSemanticsSchemaRepository(session).save( + Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + ) + for i in range(n): + srn = RecordSRN.parse(f"urn:osa:localhost:rec:rec{i:03d}@1") + await seed_record( + engine, + srn=str(srn), + schema_id=SCHEMA.id.root, + schema_version=SCHEMA.version.root, + metadata={"species": "Homo sapiens", "mw": float(i)}, + published_at=datetime(2026, 1, 1 + i, tzinfo=UTC), + ) + await store.insert(SCHEMA, srn, {"species": "Homo sapiens", "mw": float(i)}) + await session.commit() + + +HOOK = "chem_features" + + +async def _seed_feature( + engine: AsyncEngine, session: AsyncSession, record_srn: RecordSRN, n: int +) -> None: + """Register a hook on the schema and populate ``features.chem_features``.""" + await session.execute( + conventions_table.insert().values( + srn=f"urn:osa:localhost:conv:{HOOK}@1.0.0", + title="compound conv", + description=None, + schema_id=SCHEMA.id.root, + schema_version=SCHEMA.version.root, + file_requirements={}, + hooks=[{"name": HOOK}], + source=None, + created_at=datetime.now(UTC), + ) + ) + await session.commit() + feature_store = 
PostgresFeatureStore(engine, session) + await feature_store.create_table( + HOOK, + [ + ColumnDef(name="score", json_type="number", required=True), + ColumnDef(name="label", json_type="string", required=False), + ], + ) + await feature_store.insert_features( + HOOK, + str(record_srn), + [{"score": float(i), "label": f"l{i}"} for i in range(n)], + ) + + +@pytest.fixture +def client() -> AsyncClient: + from osa.application.api.rest.app import create_app + + app = create_app() + return AsyncClient(transport=ASGITransport(app=app), base_url="http://test") + + +@pytest.mark.asyncio +class TestDataRoutesE2E: + async def test_records_csv_gz_streams_full_table( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 3) + async with client: + resp = await client.get("/api/v1/data/compound@1.0.0/records.csv.gz") + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("application/gzip") + text = gzip.decompress(resp.content).decode() + lines = [ln for ln in text.splitlines() if ln] + assert lines[0].split(",")[:5] == ["id", "srn", "schema_id", "version", "created_at"] + assert len(lines) == 1 + 3 # header + 3 rows + + async def test_records_json_default_page( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 3) + async with client: + resp = await client.get("/api/v1/data/compound@1.0.0/records") + assert resp.status_code == 200 + body = resp.json() + assert len(body["rows"]) == 3 + assert body["next_cursor"] is None + assert body["rows"][0]["species"] == "Homo sapiens" + + async def test_records_json_cursor_pagination( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 5) + async with client: + page1 = (await client.get("/api/v1/data/compound@1.0.0/records?limit=2")).json() + assert len(page1["rows"]) == 2 + assert page1["next_cursor"] is not None + 
page2 = ( + await client.get( + f"/api/v1/data/compound@1.0.0/records?limit=2&cursor={page1['next_cursor']}" + ) + ).json() + ids1 = {r["id"] for r in page1["rows"]} + ids2 = {r["id"] for r in page2["rows"]} + assert ids1.isdisjoint(ids2) # no overlap across pages + + async def test_post_csv_with_filter( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 4) + body = { + "filter": { + "kind": "predicate", + "field": "metadata.mw", + "op": "gte", + "value": 2.0, + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv", json=body) + assert resp.status_code == 200 + lines = [ln for ln in resp.text.splitlines() if ln] + assert len(lines) == 1 + 2 # mw in {2.0, 3.0} + + async def test_catalog_and_manifest( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + catalog = (await client.get("/api/v1/data")).json() + assert any(s["id"] == "compound" for s in catalog["schemas"]) + manifest = (await client.get("/api/v1/data/compound@1.0.0")).json() + assert manifest["id"] == "compound" + names = [t["name"] for t in manifest["table_resources"]] + assert "records" in names + + async def test_record_by_id_has_id_and_srn( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + resp = await client.get("/api/v1/data/records/rec000") + assert resp.status_code == 200 + body = resp.json() + assert body["id"] == "rec000" + assert body["srn"] == "urn:osa:localhost:rec:rec000@1" + + async def test_reserved_and_unknown_schema_404( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + assert (await client.get("/api/v1/data/records")).status_code == 404 + assert (await 
client.get("/api/v1/data/datasets")).status_code == 404 + assert (await client.get("/api/v1/data/nope@9.9.9")).status_code == 404 + + async def test_record_by_id_version_pin( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + store = PostgresMetadataStore(pg_engine, pg_session) + await store.ensure_table(SCHEMA, _fields()) + await PostgresSemanticsSchemaRepository(pg_session).save( + Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + ) + for version, mw in ((1, 10.0), (2, 20.0)): + srn = RecordSRN.parse(f"urn:osa:localhost:rec:dup@{version}") + await seed_record( + pg_engine, + srn=str(srn), + schema_id=SCHEMA.id.root, + schema_version=SCHEMA.version.root, + metadata={"species": "Homo sapiens", "mw": mw}, + published_at=datetime(2026, 1, version, tzinfo=UTC), + ) + await store.insert(SCHEMA, srn, {"species": "Homo sapiens", "mw": mw}) + await pg_session.commit() + + async with client: + v1 = (await client.get("/api/v1/data/records/dup@1")).json() + v2 = (await client.get("/api/v1/data/records/dup@2")).json() + latest = (await client.get("/api/v1/data/records/dup")).json() + assert v1["srn"] == "urn:osa:localhost:rec:dup@1" + assert v2["srn"] == "urn:osa:localhost:rec:dup@2" + # No version → latest published (@2, published 2026-01-02). 
+ assert latest["srn"] == "urn:osa:localhost:rec:dup@2" + + +@pytest.mark.asyncio +class TestRecordsStreamingEdges: + async def test_post_csv_gz_happy_path( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 4) + body = { + "filter": { + "kind": "predicate", + "field": "metadata.mw", + "op": "gte", + "value": 2.0, + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv.gz", json=body) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("application/gzip") + lines = [ln for ln in gzip.decompress(resp.content).decode().splitlines() if ln] + assert lines[0].split(",")[:5] == ["id", "srn", "schema_id", "version", "created_at"] + assert len(lines) == 1 + 2 # header + mw in {2.0, 3.0} + + async def test_empty_result_returns_gzipped_header_only( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 2) + body = { + "filter": { + "kind": "predicate", + "field": "metadata.species", + "op": "eq", + "value": "Nonexistent species", + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv.gz", json=body) + assert resp.status_code == 200 + lines = [ln for ln in gzip.decompress(resp.content).decode().splitlines() if ln] + assert len(lines) == 1 # header only, no data rows + assert lines[0].split(",")[:5] == ["id", "srn", "schema_id", "version", "created_at"] + + async def test_invalid_filter_field_4xx_before_bytes( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 2) + body = { + "filter": { + "kind": "predicate", + "field": "metadata.nonexistent_field", + "op": "eq", + "value": "x", + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv.gz", json=body) + # The error must be surfaced before any streamed bytes (pre-flight). 
+ assert 400 <= resp.status_code < 500 + assert not resp.headers.get("content-type", "").startswith("application/gzip") + + async def test_unknown_schema_404_before_bytes( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + resp = await client.get("/api/v1/data/nope@9.9.9/records.csv.gz") + assert resp.status_code == 404 + assert not resp.headers.get("content-type", "").startswith("application/gzip") + + +@pytest.mark.asyncio +class TestRecordsJsonEdges: + async def test_limit_over_max_is_clamped( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 3) + async with client: + resp = await client.get("/api/v1/data/compound@1.0.0/records?limit=5000") + assert resp.status_code == 200 # clamped, not rejected + assert len(resp.json()["rows"]) == 3 + + async def test_sort_param_orders_rows( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 3) + async with client: + resp = await client.get( + "/api/v1/data/compound@1.0.0/records?sort=created_at:asc,id:asc" + ) + rows = resp.json()["rows"] + # created_at ascending → rec000 (earliest) first. 
+ assert rows[0]["id"] == "rec000" + assert [r["id"] for r in rows] == sorted(r["id"] for r in rows) + + async def test_post_json_with_filter( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 4) + body = {"filter": {"kind": "predicate", "field": "metadata.mw", "op": "lt", "value": 2.0}} + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records", json=body) + assert resp.status_code == 200 + rows = resp.json()["rows"] + assert {r["id"] for r in rows} == {"rec000", "rec001"} # mw 0.0, 1.0 + + async def test_boolean_filter_composition( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + # Exercises AND / OR / NOT SQL compilation (incl. the NOT→coalesce path) + # that the relocated filter engine carries over from discovery. + await _seed(pg_engine, pg_session, 5) # mw ∈ {0,1,2,3,4} + body = { + "filter": { + "kind": "and", + "operands": [ + { + "kind": "or", + "operands": [ + { + "kind": "predicate", + "field": "metadata.mw", + "op": "gte", + "value": 3.0, + }, + {"kind": "predicate", "field": "metadata.mw", "op": "lt", "value": 1.0}, + ], + }, + { + "kind": "not", + "operand": { + "kind": "predicate", + "field": "metadata.mw", + "op": "eq", + "value": 4.0, + }, + }, + ], + } + } + async with client: + resp = await client.post("/api/v1/data/compound@1.0.0/records", json=body) + assert resp.status_code == 200 + rows = resp.json()["rows"] + # OR(mw≥3, mw<1) = {0,3,4}; AND NOT(mw=4) → {0,3}. 
+ assert {r["id"] for r in rows} == {"rec000", "rec003"} + + +@pytest.mark.asyncio +class TestCatalogEdges: + async def test_empty_node_returns_200_empty_schemas(self, client: AsyncClient): + async with client: + resp = await client.get("/api/v1/data") + assert resp.status_code == 200 + body = resp.json() + assert body["schemas"] == [] + assert "node_domain" in body + + +@pytest.mark.asyncio +class TestLegacySurfaceRemoved: + """US6: the legacy read surface is gone — every old path 404s.""" + + async def test_legacy_discovery_route_404(self, client: AsyncClient): + async with client: + resp = await client.get("/api/v1/discovery/records") + assert resp.status_code == 404 + + async def test_legacy_records_srn_route_404(self, client: AsyncClient): + async with client: + resp = await client.get("/api/v1/records/urn:osa:localhost:rec:anything@1") + assert resp.status_code == 404 + + async def test_legacy_search_route_404(self, client: AsyncClient): + async with client: + resp = await client.get("/api/v1/search/records?q=anything") + assert resp.status_code == 404 + + async def test_stats_has_no_index_field( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + async with client: + resp = await client.get("/api/v1/stats") + assert resp.status_code == 200 + body = resp.json() + assert "indexes" not in body + assert body["data_url"] == "/api/v1/data" + + +@pytest.mark.asyncio +class TestPostRateLimit: + async def test_eleventh_post_within_a_minute_is_429( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + from osa.application.api.v1.routes.data._limiter import limiter + + # The limiter is a module singleton shared across tests — reset its + # window so this test's count starts clean and doesn't leak outward. 
+ limiter.reset() + await _seed(pg_engine, pg_session, 1) + statuses = [] + async with client: + for _ in range(11): # limit is 10/minute on POST routes + resp = await client.post("/api/v1/data/compound@1.0.0/records.csv", json={}) + statuses.append(resp.status_code) + limiter.reset() + assert statuses[:10] == [200] * 10 + assert statuses[10] == 429 + + +@pytest.mark.asyncio +class TestFeatureRoutesE2E: + async def test_feature_csv_gz_streams_full_table( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + srn = RecordSRN.parse("urn:osa:localhost:rec:rec000@1") + await _seed_feature(pg_engine, pg_session, srn, 3) + async with client: + resp = await client.get(f"/api/v1/data/compound@1.0.0/{HOOK}.csv.gz") + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("application/gzip") + lines = [ln for ln in gzip.decompress(resp.content).decode().splitlines() if ln] + assert lines[0].split(",")[:3] == ["id", "record_srn", "created_at"] + assert "score" in lines[0] and "label" in lines[0] + assert len(lines) == 1 + 3 # header + 3 rows + + async def test_feature_post_csv_with_filter( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + srn = RecordSRN.parse("urn:osa:localhost:rec:rec000@1") + await _seed_feature(pg_engine, pg_session, srn, 4) + body = { + "filter": { + "kind": "predicate", + "field": f"features.{HOOK}.score", + "op": "gte", + "value": 2.0, + } + } + async with client: + resp = await client.post(f"/api/v1/data/compound@1.0.0/{HOOK}.csv", json=body) + assert resp.status_code == 200 + lines = [ln for ln in resp.text.splitlines() if ln] + assert len(lines) == 1 + 2 # score in {2.0, 3.0} + + async def test_feature_json_default_page( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + srn = 
RecordSRN.parse("urn:osa:localhost:rec:rec000@1") + await _seed_feature(pg_engine, pg_session, srn, 3) + async with client: + resp = await client.get(f"/api/v1/data/compound@1.0.0/{HOOK}") + assert resp.status_code == 200 + body = resp.json() + assert len(body["rows"]) == 3 + assert body["rows"][0]["label"] == "l0" + + async def test_unknown_feature_404( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + async with client: + resp = await client.get("/api/v1/data/compound@1.0.0/no_such_feature.csv.gz") + assert resp.status_code == 404 + + async def test_manifest_lists_feature_table( + self, pg_engine: AsyncEngine, pg_session: AsyncSession, client: AsyncClient + ): + await _seed(pg_engine, pg_session, 1) + srn = RecordSRN.parse("urn:osa:localhost:rec:rec000@1") + await _seed_feature(pg_engine, pg_session, srn, 2) + async with client: + manifest = (await client.get("/api/v1/data/compound@1.0.0")).json() + feature = next(t for t in manifest["table_resources"] if t["name"] == HOOK) + assert feature["kind"] == "feature" + assert feature["row_count"] == 2 diff --git a/server/tests/integration/test_data_streaming_guarantees_postgres.py b/server/tests/integration/test_data_streaming_guarantees_postgres.py new file mode 100644 index 0000000..9b48e9e --- /dev/null +++ b/server/tests/integration/test_data_streaming_guarantees_postgres.py @@ -0,0 +1,153 @@ +"""Integration tests for the streaming engine's resource guarantees (T041, T042). + +These exercise the two non-functional promises of the ``/data/`` streaming path +directly against the read store + a real Postgres server-side cursor: + +* **Bounded memory (T041)** — streaming a large table must not materialize all + rows in the Python heap. Asserted with ``tracemalloc`` peak while draining the + dump without accumulating, against a bulk-seeded table. 
(The literal spec's + absolute "<50 MB RSS" is not meaningful in-process — the test runner's RSS is + already well above that from imports — so we assert *Python-heap peak stays + small*, which is the guarantee the streaming primitive actually provides.) + +* **Cursor release on early termination (T042)** — closing the async generator + mid-stream (the client-disconnect path: a ``GeneratorExit`` thrown into the + body) must run the ``try/finally`` that closes the server-side cursor, leaving + the connection immediately reusable. (Simulating a real TCP disconnect + + ``pg_stat_activity`` poll is not reliable through the in-process ASGI client, + so we drive the generator's lifecycle directly, which is where the guarantee + lives.) + +Skips automatically unless OSA_DATABASE__URL points at PostgreSQL. +""" + +import tracemalloc +from datetime import UTC, datetime + +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + +from osa.domain.data.model.query_plan import PaginationParams, QueryPlan, TableKind +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType +from osa.domain.shared.model.srn import Domain, SchemaId +from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore +from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore +from osa.infrastructure.persistence.repository.schema import ( + PostgresSemanticsSchemaRepository, +) + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def _fields() -> list[FieldDefinition]: + return [ + FieldDefinition( + name="species", + type=FieldType.TEXT, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + FieldDefinition( + name="mw", + type=FieldType.NUMBER, + required=False, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + + +async def _setup_schema(engine: AsyncEngine, session: AsyncSession) -> None: + store = 
PostgresMetadataStore(engine, session) + await store.ensure_table(SCHEMA, _fields()) + await PostgresSemanticsSchemaRepository(session).save( + Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + ) + await session.commit() + + +async def _bulk_seed(engine: AsyncEngine, n: int) -> None: + """Insert *n* records + metadata rows in two set-based statements. + + ``source.id`` is per-row unique to satisfy the records source uniqueness + index; otherwise a single bulk INSERT would collide on the second row. + """ + async with engine.begin() as conn: + await conn.execute( + text( + """ + INSERT INTO records + (srn, convention_srn, schema_id, schema_version, source, metadata, published_at) + SELECT 'urn:osa:localhost:rec:bulk' || g || '@1', + 'urn:osa:localhost:conv:test@1.0.0', 'compound', '1.0.0', + jsonb_build_object('type', 'deposition', 'id', 'bulk' || g), + jsonb_build_object('species', 'Homo sapiens', 'mw', g), + now() + FROM generate_series(1, :n) g + """ + ), + {"n": n}, + ) + await conn.execute( + text( + """ + INSERT INTO metadata.compound_v1 (record_srn, species, mw) + SELECT 'urn:osa:localhost:rec:bulk' || g || '@1', 'Homo sapiens', g + FROM generate_series(1, :n) g + """ + ), + {"n": n}, + ) + + +def _records_plan(limit: int = 1000) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + pagination=PaginationParams(limit=limit), + ) + + +@pytest.mark.asyncio +class TestStreamingGuarantees: + async def test_large_dump_does_not_materialize_in_python_heap( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + n = 20_000 + await _setup_schema(pg_engine, pg_session) + await _bulk_seed(pg_engine, n) + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + + tracemalloc.start() + count = 0 + # Drain WITHOUT accumulating — the engine must not hold all rows at once. 
+ async for _row in rs.stream_rows(_records_plan()): + count += 1 + _current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + assert count == n + # If the engine buffered all 20K flattened dicts, peak would be many MB. + # A streaming server-side cursor keeps the Python-heap peak small. + assert peak < 8 * 1024 * 1024, f"peak heap {peak} bytes — streaming may be buffering" + + async def test_early_generator_close_releases_server_cursor( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await _setup_schema(pg_engine, pg_session) + await _bulk_seed(pg_engine, 500) + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + + gen = rs.stream_rows(_records_plan()) + first = await gen.__anext__() # opens the server-side cursor + assert first["schema_id"] == "compound@1.0.0" + # Simulate client disconnect: closing the generator throws GeneratorExit + # into the body, which must run the finally that closes the cursor. + await gen.aclose() + + # If the cursor leaked (portal still open), a follow-up query on the same + # connection would raise. A clean full drain proves the connection was + # returned to a usable state. 
+ rows = [r async for r in rs.stream_rows(_records_plan())] + assert len(rows) == 500 diff --git a/server/tests/integration/test_discovery_compound_postgres.py b/server/tests/integration/test_discovery_compound_postgres.py deleted file mode 100644 index f752c0f..0000000 --- a/server/tests/integration/test_discovery_compound_postgres.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Integration tests for compound OR/NOT discovery filters against real PG.""" - -import pytest -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession - -from osa.domain.discovery.model.refs import MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - FilterOperator, - Not, - Or, - Predicate, - SortOrder, -) -from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType -from osa.domain.shared.model.srn import RecordSRN, SchemaId -from osa.infrastructure.persistence.adapter.discovery import PostgresDiscoveryReadStore -from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore - -from tests.integration.conftest import seed_record - -SCHEMA_V1 = SchemaId.parse("bio-sample@1.0.0") -FIELD_TYPES = {"species": FieldType.TEXT, "resolution": FieldType.NUMBER} - - -def _fields() -> list[FieldDefinition]: - return [ - FieldDefinition( - name="species", - type=FieldType.TEXT, - required=True, - cardinality=Cardinality.EXACTLY_ONE, - ), - FieldDefinition( - name="resolution", - type=FieldType.NUMBER, - required=False, - cardinality=Cardinality.EXACTLY_ONE, - ), - ] - - -@pytest.fixture -async def seeded_store(pg_engine: AsyncEngine, pg_session: AsyncSession) -> PostgresMetadataStore: - from datetime import UTC, datetime - - from osa.domain.semantics.model.schema import Schema - from osa.infrastructure.persistence.repository.schema import ( - PostgresSemanticsSchemaRepository, - ) - - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - - repo = PostgresSemanticsSchemaRepository(pg_session) - 
await repo.save( - Schema(id=SCHEMA_V1, title="bio_sample", fields=_fields(), created_at=datetime.now(UTC)) - ) - - rows = [ - ("rec-a1", "Homo sapiens", 3.5), - ("rec-b1", "Homo sapiens", 1.0), - ("rec-c1", "Mus musculus", 3.5), - ("rec-d1", "Drosophila", 0.5), - ] - for rid, sp, res in rows: - srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") - await seed_record( - pg_engine, - srn=str(srn), - schema_id=SCHEMA_V1.id.root, - schema_version=SCHEMA_V1.version.root, - ) - await store.insert(SCHEMA_V1, srn, {"species": sp, "resolution": res}) - - await pg_session.commit() - return store - - -def _pred(field: str, op: FilterOperator, value: object) -> Predicate: - return Predicate(field=MetadataFieldRef(field=field), op=op, value=value) - - -@pytest.mark.asyncio -class TestCompound: - async def test_or_tree(self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_store): - read_store = PostgresDiscoveryReadStore(pg_session) - # species = Homo sapiens OR resolution < 1.0 - tree = Or( - operands=[ - _pred("species", FilterOperator.EQ, "Homo sapiens"), - _pred("resolution", FilterOperator.LT, 1.0), - ] - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - # a, b (species match) + d (resolution 0.5) — not c - assert srns == { - "urn:osa:localhost:rec:rec-a1@1", - "urn:osa:localhost:rec:rec-b1@1", - "urn:osa:localhost:rec:rec-d1@1", - } - - async def test_not_tree(self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_store): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = Not(operand=_pred("species", FilterOperator.EQ, "Homo sapiens")) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - 
order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - assert srns == {"urn:osa:localhost:rec:rec-c1@1", "urn:osa:localhost:rec:rec-d1@1"} - - async def test_nested_and_or( - self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_store - ): - read_store = PostgresDiscoveryReadStore(pg_session) - # resolution >= 3.0 AND (species = Homo sapiens OR species = Mus musculus) - tree = And( - operands=[ - _pred("resolution", FilterOperator.GTE, 3.0), - Or( - operands=[ - _pred("species", FilterOperator.EQ, "Homo sapiens"), - _pred("species", FilterOperator.EQ, "Mus musculus"), - ] - ), - ] - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - assert srns == {"urn:osa:localhost:rec:rec-a1@1", "urn:osa:localhost:rec:rec-c1@1"} diff --git a/server/tests/integration/test_discovery_cross_join_postgres.py b/server/tests/integration/test_discovery_cross_join_postgres.py deleted file mode 100644 index c2ed9ec..0000000 --- a/server/tests/integration/test_discovery_cross_join_postgres.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Integration tests for cross-domain JOINs between records ⋈ features in discovery.""" - -import pytest -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession - -from osa.domain.discovery.model.refs import FeatureFieldRef, MetadataFieldRef -from osa.domain.discovery.model.value import And, FilterOperator, Not, Predicate, SortOrder -from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.hook import ColumnDef -from osa.domain.shared.model.srn import RecordSRN, SchemaId -from osa.infrastructure.persistence.adapter.discovery import 
PostgresDiscoveryReadStore -from osa.infrastructure.persistence.feature_store import PostgresFeatureStore -from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore - -from tests.integration.conftest import seed_record - -SCHEMA_V1 = SchemaId.parse("bio-sample@1.0.0") -FIELD_TYPES = {"species": FieldType.TEXT} - - -def _metadata_fields() -> list[FieldDefinition]: - return [ - FieldDefinition( - name="species", - type=FieldType.TEXT, - required=True, - cardinality=Cardinality.EXACTLY_ONE, - ), - ] - - -def _feature_columns() -> list[ColumnDef]: - return [ - ColumnDef(name="confidence", json_type="number", required=True), - ] - - -@pytest.fixture -async def seeded_both(pg_engine: AsyncEngine, pg_session: AsyncSession): - from datetime import UTC, datetime - - from osa.domain.semantics.model.schema import Schema - from osa.infrastructure.persistence.repository.schema import ( - PostgresSemanticsSchemaRepository, - ) - - mstore = PostgresMetadataStore(pg_engine, pg_session) - await mstore.ensure_table(SCHEMA_V1, _metadata_fields()) - - fstore = PostgresFeatureStore(pg_engine, pg_session) - await fstore.create_table("cell_classifier", _feature_columns()) - - repo = PostgresSemanticsSchemaRepository(pg_session) - await repo.save( - Schema( - id=SCHEMA_V1, - title="bio_sample", - fields=_metadata_fields(), - created_at=datetime.now(UTC), - ) - ) - - # r1: Homo sapiens + confidence 0.95 - # r2: Homo sapiens + confidence 0.5 - # r3: Mus musculus + confidence 0.95 - for rid, sp, conf in [ - ("rec-r1", "Homo sapiens", 0.95), - ("rec-r2", "Homo sapiens", 0.5), - ("rec-r3", "Mus musculus", 0.95), - ]: - srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") - await seed_record( - pg_engine, - srn=str(srn), - schema_id=SCHEMA_V1.id.root, - schema_version=SCHEMA_V1.version.root, - ) - await mstore.insert(SCHEMA_V1, srn, {"species": sp}) - await fstore.insert_features("cell_classifier", str(srn), [{"confidence": conf}]) - - await pg_session.commit() - return 
mstore, fstore - - -@pytest.mark.asyncio -class TestCrossDomainJoin: - async def test_joined_intersection( - self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_both - ): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = And( - operands=[ - Predicate( - field=MetadataFieldRef(field="species"), - op=FilterOperator.EQ, - value="Homo sapiens", - ), - Predicate( - field=FeatureFieldRef(hook="cell_classifier", column="confidence"), - op=FilterOperator.GT, - value=0.9, - ), - ] - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - assert srns == {"urn:osa:localhost:rec:rec-r1@1"} - - async def test_unknown_hook_raises( - self, pg_engine: AsyncEngine, pg_session: AsyncSession, seeded_both - ): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = Predicate( - field=FeatureFieldRef(hook="does_not_exist", column="anything"), - op=FilterOperator.EQ, - value=1, - ) - with pytest.raises(ValidationError, match="Unknown feature hook"): - await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - - -@pytest.fixture -async def seeded_with_missing_feature_row(pg_engine: AsyncEngine, pg_session: AsyncSession): - """Seed a record with a metadata row but NO feature row, so the outer - join produces NULL feature columns for that record.""" - from datetime import UTC, datetime - - from osa.domain.semantics.model.schema import Schema - from osa.infrastructure.persistence.repository.schema import ( - PostgresSemanticsSchemaRepository, - ) - - mstore = PostgresMetadataStore(pg_engine, pg_session) - await mstore.ensure_table(SCHEMA_V1, _metadata_fields()) 
- - fstore = PostgresFeatureStore(pg_engine, pg_session) - await fstore.create_table("cell_classifier", _feature_columns()) - - repo = PostgresSemanticsSchemaRepository(pg_session) - await repo.save( - Schema( - id=SCHEMA_V1, - title="bio_sample", - fields=_metadata_fields(), - created_at=datetime.now(UTC), - ) - ) - - # rec-has-feature: has a feature row with confidence 0.95. - # rec-no-feature: no feature row at all (outer join will produce NULLs). - for rid, sp in [("rec-has-feature", "Homo sapiens"), ("rec-no-feature", "Mus musculus")]: - srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") - await seed_record( - pg_engine, - srn=str(srn), - schema_id=SCHEMA_V1.id.root, - schema_version=SCHEMA_V1.version.root, - ) - await mstore.insert(SCHEMA_V1, srn, {"species": sp}) - - has_feature_srn = "urn:osa:localhost:rec:rec-has-feature@1" - await fstore.insert_features("cell_classifier", has_feature_srn, [{"confidence": 0.95}]) - - await pg_session.commit() - - -@pytest.mark.asyncio -class TestOuterJoinNullHandling: - """Records without a feature row must not be silently dropped from NEQ/NOT - predicates on feature columns — the outer join produces NULL, and naive - SQL three-valued logic would exclude them.""" - - async def test_neq_on_feature_column_includes_missing_rows( - self, - pg_engine: AsyncEngine, - pg_session: AsyncSession, - seeded_with_missing_feature_row, - ): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = Predicate( - field=FeatureFieldRef(hook="cell_classifier", column="confidence"), - op=FilterOperator.NEQ, - value=0.95, - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - # rec-no-feature has no feature row → confidence is NULL → "!= 0.95" - # should include it. 
rec-has-feature has confidence 0.95 → excluded. - assert srns == {"urn:osa:localhost:rec:rec-no-feature@1"} - - async def test_not_on_feature_column_includes_missing_rows( - self, - pg_engine: AsyncEngine, - pg_session: AsyncSession, - seeded_with_missing_feature_row, - ): - read_store = PostgresDiscoveryReadStore(pg_session) - tree = Not( - operand=Predicate( - field=FeatureFieldRef(hook="cell_classifier", column="confidence"), - op=FilterOperator.EQ, - value=0.95, - ) - ) - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types=FIELD_TYPES, - ) - srns = {str(r.srn) for r in results} - # Same invariant as NEQ: NOT(confidence = 0.95) must surface the - # record with a missing feature row. - assert srns == {"urn:osa:localhost:rec:rec-no-feature@1"} diff --git a/server/tests/integration/test_discovery_records_typed_and.py b/server/tests/integration/test_discovery_records_typed_and.py deleted file mode 100644 index b9487c5..0000000 --- a/server/tests/integration/test_discovery_records_typed_and.py +++ /dev/null @@ -1,284 +0,0 @@ -"""Integration tests for /discovery/records with typed-table AND filters.""" - -import pytest -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession - -from osa.domain.discovery.model.refs import MetadataFieldRef -from osa.domain.discovery.model.value import And, FilterOperator, Predicate, SortOrder -from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType -from osa.domain.shared.model.srn import RecordSRN, SchemaId -from osa.infrastructure.persistence.adapter.discovery import ( - PostgresDiscoveryReadStore, - PostgresFieldDefinitionReader, -) -from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore - -from tests.integration.conftest import seed_record - -SCHEMA_V1 = SchemaId.parse("bio-sample@1.0.0") - - -def 
_fields() -> list[FieldDefinition]: - return [ - FieldDefinition( - name="species", - type=FieldType.TEXT, - required=True, - cardinality=Cardinality.EXACTLY_ONE, - ), - FieldDefinition( - name="resolution", - type=FieldType.NUMBER, - required=False, - cardinality=Cardinality.EXACTLY_ONE, - ), - FieldDefinition( - name="method", - type=FieldType.TEXT, - required=False, - cardinality=Cardinality.EXACTLY_ONE, - ), - ] - - -async def _seed_schema_row(session: AsyncSession) -> None: - """Seed the `schemas` row so the discovery field reader can resolve types.""" - from datetime import UTC, datetime - - from osa.domain.semantics.model.schema import Schema - from osa.infrastructure.persistence.repository.schema import ( - PostgresSemanticsSchemaRepository, - ) - - repo = PostgresSemanticsSchemaRepository(session) - await repo.save( - Schema(id=SCHEMA_V1, title="bio_sample", fields=_fields(), created_at=datetime.now(UTC)) - ) - - -async def _publish( - engine: AsyncEngine, - session: AsyncSession, - store: PostgresMetadataStore, - record_srn: RecordSRN, - species: str, - resolution: float, - method: str, -) -> None: - await seed_record( - engine, - srn=str(record_srn), - schema_id=SCHEMA_V1.id.root, - schema_version=SCHEMA_V1.version.root, - metadata={"species": species, "resolution": resolution, "method": method}, - ) - await store.insert( - SCHEMA_V1, - record_srn, - {"species": species, "resolution": resolution, "method": method}, - ) - - -@pytest.mark.asyncio -class TestDiscoveryTypedAnd: - async def test_and_filter_returns_matching_records( - self, pg_engine: AsyncEngine, pg_session: AsyncSession - ): - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - await _seed_schema_row(pg_session) - - rows = [ - ("rec-r1", "Homo sapiens", 3.5, "cryo-EM"), - ("rec-r2", "Homo sapiens", 1.8, "X-ray"), - ("rec-r3", "Mus musculus", 3.0, "cryo-EM"), - ] - for rid, sp, res, meth in rows: - await _publish( - pg_engine, - pg_session, 
- store, - RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1"), - sp, - res, - meth, - ) - await pg_session.commit() - - read_store = PostgresDiscoveryReadStore(pg_session) - tree = And( - operands=[ - Predicate( - field=MetadataFieldRef(field="species"), - op=FilterOperator.EQ, - value="Homo sapiens", - ), - Predicate( - field=MetadataFieldRef(field="resolution"), - op=FilterOperator.GTE, - value=2.0, - ), - ] - ) - - results = await read_store.search_records( - filter_expr=tree, - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types={ - "species": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": FieldType.TEXT, - }, - ) - - srns = {str(r.srn) for r in results} - assert srns == {"urn:osa:localhost:rec:rec-r1@1"} - - async def test_scalar_op_succeeds_on_unindexed_column( - self, pg_engine: AsyncEngine, pg_session: AsyncSession - ): - """FR-020: scalar ops must NOT be rejected for lack of index.""" - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - await _seed_schema_row(pg_session) - - await _publish( - pg_engine, - pg_session, - store, - RecordSRN.parse("urn:osa:localhost:rec:rec-ra@1"), - "Homo sapiens", - 3.5, - "cryo-EM", - ) - await pg_session.commit() - - read_store = PostgresDiscoveryReadStore(pg_session) - results = await read_store.search_records( - filter_expr=Predicate( - field=MetadataFieldRef(field="method"), - op=FilterOperator.CONTAINS, - value="cryo", - ), - schema_id=SCHEMA_V1, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types={ - "species": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": FieldType.TEXT, - }, - ) - assert len(results) == 1 - - -@pytest.mark.asyncio -class TestUnscopedListing: - """Plain listings without a filter return canonical JSONB metadata. 
- Metadata-filtered queries require a pinned schema — the typed table is - the only filter path.""" - - async def test_unscoped_predicate_filter_raises_without_schema( - self, pg_engine: AsyncEngine, pg_session: AsyncSession - ): - """Filtering by a metadata field without schema_id must raise — - the JSONB fallback compile path was removed.""" - from osa.domain.discovery.model.refs import MetadataFieldRef - from osa.domain.discovery.model.value import FilterOperator, Predicate - from osa.domain.shared.error import ValidationError - - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - await _seed_schema_row(pg_session) - await _publish( - pg_engine, - pg_session, - store, - RecordSRN.parse("urn:osa:localhost:rec:rec-9x1w@1"), - "Homo sapiens", - 3.5, - "cryo-EM", - ) - await pg_session.commit() - - read_store = PostgresDiscoveryReadStore(pg_session) - with pytest.raises(ValidationError) as exc: - await read_store.search_records( - filter_expr=Predicate( - field=MetadataFieldRef(field="species"), - op=FilterOperator.EQ, - value="Homo sapiens", - ), - schema_id=None, - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types={"species": FieldType.TEXT}, - ) - assert exc.value.code == "schema_required_for_metadata_query" - - async def test_unscoped_listing_returns_jsonb_metadata( - self, pg_engine: AsyncEngine, pg_session: AsyncSession - ): - store = PostgresMetadataStore(pg_engine, pg_session) - await store.ensure_table(SCHEMA_V1, _fields()) - await _seed_schema_row(pg_session) - - await _publish( - pg_engine, - pg_session, - store, - RecordSRN.parse("urn:osa:localhost:rec:rec-unscoped@1"), - "Homo sapiens", - 3.5, - "cryo-EM", - ) - await pg_session.commit() - - read_store = PostgresDiscoveryReadStore(pg_session) - results = await read_store.search_records( - filter_expr=None, - schema_id=None, # deliberately unscoped — exercises the 
JSONB path - convention_srn=None, - text_fields=[], - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - field_types={}, - ) - assert len(results) == 1 - assert results[0].metadata == { - "species": "Homo sapiens", - "resolution": 3.5, - "method": "cryo-EM", - } - - -@pytest.mark.asyncio -class TestFieldDefinitionReader: - async def test_get_fields_for_schema(self, pg_engine: AsyncEngine, pg_session: AsyncSession): - await _seed_schema_row(pg_session) - await pg_session.commit() - - reader = PostgresFieldDefinitionReader(pg_session) - fields = await reader.get_fields_for_schema(SCHEMA_V1) - assert fields["species"] == FieldType.TEXT - assert fields["resolution"] == FieldType.NUMBER diff --git a/server/tests/integration/test_event_batch_processing.py b/server/tests/integration/test_event_batch_processing.py deleted file mode 100644 index e1f6754..0000000 --- a/server/tests/integration/test_event_batch_processing.py +++ /dev/null @@ -1,229 +0,0 @@ -"""Integration tests for batch event processing flow. - -Tests the end-to-end flow from RecordPublished -> IndexRecord fan-out -> batch processing. 
-""" - -from typing import Any -from unittest.mock import AsyncMock, MagicMock -from uuid import uuid4 - -import pytest - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.handler.fanout_to_index_backends import FanOutToIndexBackends -from osa.domain.index.handler.vector_index_handler import VectorIndexHandler -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.record.event.record_published import RecordPublished -from osa.domain.shared.event import EventId -from osa.domain.shared.model.srn import ( - ConventionSRN, - DepositionSRN, - Domain, - LocalId, - RecordSRN, - RecordVersion, -) - - -class FakeBackend: - """Fake storage backend that tracks batch calls.""" - - def __init__(self, name: str): - self._name = name - self.batch_calls: list[list[tuple[str, dict]]] = [] - self.total_records: int = 0 - - @property - def name(self) -> str: - return self._name - - async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - """Track batch calls for verification.""" - self.batch_calls.append(list(records)) - self.total_records += len(records) - - -class FakeOutbox: - """Fake outbox that collects emitted events.""" - - def __init__(self): - self.events: list[Any] = [] - - async def append(self, event: Any) -> None: - self.events.append(event) - - -def make_record_published( - record_id: str | None = None, - metadata: dict | None = None, -) -> RecordPublished: - """Create a RecordPublished event for testing.""" - from osa.domain.shared.model.source import DepositionSource - - dep_srn = DepositionSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - ) - from osa.domain.shared.model.srn import SchemaId - - return RecordPublished( - id=EventId(uuid4()), - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(record_id or str(uuid4())), - version=RecordVersion(1), - ), - source=DepositionSource(id=str(dep_srn)), - 
convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0"), - schema_id=SchemaId.parse("test@1.0.0"), - metadata=metadata or {"title": "Test Record"}, - ) - - -class TestBatchEventProcessingFlow: - """Integration tests for the batch event processing flow.""" - - @pytest.mark.asyncio - async def test_fanout_creates_index_records_per_backend(self): - """FanOutToIndexBackends should create one IndexRecord per backend.""" - # Arrange - backends = { - "vector": FakeBackend("vector"), - "keyword": FakeBackend("keyword"), - } - registry = IndexRegistry(backends) - outbox = FakeOutbox() - - fanout = FanOutToIndexBackends(indexes=registry, outbox=outbox) - event = make_record_published() - - # Act - await fanout.handle(event) - - # Assert - assert len(outbox.events) == 2 - backend_names = {e.backend_name for e in outbox.events} - assert backend_names == {"vector", "keyword"} - - @pytest.mark.asyncio - async def test_handler_processes_batch(self): - """VectorIndexHandler should process events in batches.""" - # Arrange - vector_backend = FakeBackend("vector") - registry = IndexRegistry({"vector": vector_backend}) - handler = VectorIndexHandler(indexes=registry) - - # Create events for vector backend - events = [ - IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ), - metadata={"id": i}, - ) - for i in range(5) - ] - - # Act - await handler.handle_batch(events) - - # Assert - vector backend received 5 records in single batch call - assert len(vector_backend.batch_calls) == 1 - assert len(vector_backend.batch_calls[0]) == 5 - assert vector_backend.total_records == 5 - - @pytest.mark.asyncio - async def test_end_to_end_fanout_to_batch(self): - """End-to-end: RecordPublished -> FanOut -> BatchProcess.""" - # Arrange - vector_backend = FakeBackend("vector") - registry = IndexRegistry({"vector": vector_backend}) - outbox = FakeOutbox() - - 
fanout = FanOutToIndexBackends(indexes=registry, outbox=outbox) - handler = VectorIndexHandler(indexes=registry) - - # Create multiple RecordPublished events - num_records = 10 - published_events = [make_record_published() for _ in range(num_records)] - - # Act - Step 1: FanOut each RecordPublished - for event in published_events: - await fanout.handle(event) - - # Verify IndexRecord events were created - assert len(outbox.events) == num_records - assert all(isinstance(e, IndexRecord) for e in outbox.events) - assert all(e.backend_name == "vector" for e in outbox.events) - - # Act - Step 2: Batch process all IndexRecord events - await handler.handle_batch(outbox.events) - - # Assert - All records indexed in single batch call - assert len(vector_backend.batch_calls) == 1 - assert vector_backend.total_records == num_records - - @pytest.mark.asyncio - async def test_batch_efficiency_large_batch(self): - """Verify batch processing is efficient with large batches.""" - # Arrange - backend = FakeBackend("vector") - registry = IndexRegistry({"vector": backend}) - handler = VectorIndexHandler(indexes=registry) - - # Create 1000+ events - num_events = 1000 - events = [ - IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ), - metadata={"index": i, "title": f"Record {i}"}, - ) - for i in range(num_events) - ] - - # Act - await handler.handle_batch(events) - - # Assert - All records in single batch call (not 1000 individual calls) - assert len(backend.batch_calls) == 1, "Should use single batch call, not individual calls" - assert backend.total_records == num_events - assert len(backend.batch_calls[0]) == num_events - - @pytest.mark.asyncio - async def test_failure_propagates_from_handler(self): - """Verify failures propagate from handler (Worker handles retry).""" - # Arrange - VectorIndexHandler looks up "vector" backend specifically - 
vector_backend = MagicMock() - vector_backend.name = "vector" - vector_backend.ingest_batch = AsyncMock(side_effect=Exception("Backend failure")) - - registry = IndexRegistry({"vector": vector_backend}) - handler = VectorIndexHandler(indexes=registry) - - events = [ - IndexRecord( - id=EventId(uuid4()), - backend_name="vector", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ), - metadata={"id": i}, - ) - for i in range(2) - ] - - # Process events - should raise (Worker handles retry) - with pytest.raises(Exception, match="Backend failure"): - await handler.handle_batch(events) diff --git a/server/osa/domain/export/model/__init__.py b/server/tests/unit/application/api/v1/routes/__init__.py similarity index 100% rename from server/osa/domain/export/model/__init__.py rename to server/tests/unit/application/api/v1/routes/__init__.py diff --git a/server/osa/domain/export/port/__init__.py b/server/tests/unit/application/api/v1/routes/data/__init__.py similarity index 100% rename from server/osa/domain/export/port/__init__.py rename to server/tests/unit/application/api/v1/routes/data/__init__.py diff --git a/server/tests/unit/application/api/v1/routes/data/test_params.py b/server/tests/unit/application/api/v1/routes/data/test_params.py new file mode 100644 index 0000000..7895a92 --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/data/test_params.py @@ -0,0 +1,64 @@ +"""Unit tests for table-route request parsing (sort spec + plan build).""" + +import pytest + +from osa.application.api.v1.routes.data._params import build_plan, parse_sort +from osa.domain.data.model.query_plan import SortDirection, TableKind +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.srn import SchemaId + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def test_parse_sort_empty_is_empty_list() -> None: + assert parse_sort(None) == [] + assert parse_sort("") == [] + + +def 
test_parse_sort_multi_with_directions() -> None: + specs = parse_sort("created_at:desc,id:desc") + assert [(s.column, s.direction) for s in specs] == [ + ("created_at", SortDirection.DESC), + ("id", SortDirection.DESC), + ] + + +def test_parse_sort_bare_column_defaults_asc() -> None: + specs = parse_sort("mw") + assert specs[0].column == "mw" + assert specs[0].direction == SortDirection.ASC + + +def test_parse_sort_invalid_direction_raises() -> None: + with pytest.raises(ValidationError): + parse_sort("mw:sideways") + + +def test_build_plan_wraps_cursor_and_limit() -> None: + plan = build_plan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + feature_name=None, + filter_expr=None, + cursor="CUR", + limit=10, + sort="mw:asc", + ) + assert plan.pagination.limit == 10 + assert str(plan.pagination.cursor) == "CUR" + assert plan.sort[0].column == "mw" + + +def test_build_plan_no_cursor_is_none() -> None: + plan = build_plan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + feature_name=None, + filter_expr=None, + cursor=None, + limit=50, + sort=None, + ) + assert plan.pagination.cursor is None + # default RECORDS sort applied + assert plan.sort[0].column == "created_at" diff --git a/server/tests/unit/application/api/v1/routes/data/test_streaming.py b/server/tests/unit/application/api/v1/routes/data/test_streaming.py new file mode 100644 index 0000000..b1e2c97 --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/data/test_streaming.py @@ -0,0 +1,103 @@ +"""Unit tests for build_table_response — pre-flight, streaming, pagination, cursor.""" + +import base64 +import csv +import gzip +import io +import json +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + +from osa.application.api.v1.routes.data._streaming import build_table_response +from osa.domain.data.model.format import FORMATS +from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.model.query_plan import QueryPlan, TableKind +from 
osa.domain.semantics.model.value import FieldType +from osa.domain.shared.model.srn import SchemaId + +SCHEMA = SchemaId.parse("compound@1.0.0") +COLUMNS = [ColumnSpec(name="id", type=FieldType.TEXT), ColumnSpec(name="srn", type=FieldType.TEXT)] +JSON_FMT = next(f for f in FORMATS if f.suffix == "") +CSV_FMT = next(f for f in FORMATS if f.suffix == "csv") +GZ_FMT = next(f for f in FORMATS if f.suffix == "csv.gz") + + +async def _aiter(rows: list[Mapping[str, Any]]) -> AsyncIterator[Mapping[str, Any]]: + for r in rows: + yield r + + +async def _raising() -> AsyncIterator[Mapping[str, Any]]: + raise ValueError("boom before first row") + yield # pragma: no cover + + +async def _body(resp) -> bytes: + out = b"" + async for chunk in resp.body_iterator: + out += chunk.encode() if isinstance(chunk, str) else chunk + return out + + +def _plan(limit: int = 50) -> QueryPlan: + return QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, pagination={"limit": limit}) + + +@pytest.mark.asyncio +async def test_streaming_preflight_raises_before_bytes() -> None: + # An error on the first __anext__ must propagate (→ 4xx), not corrupt a stream. 
+ with pytest.raises(ValueError, match="boom"): + await build_table_response(_raising(), CSV_FMT, COLUMNS, _plan()) + + +@pytest.mark.asyncio +async def test_streaming_csv_full_table() -> None: + rows = [{"id": "a", "srn": "s1"}, {"id": "b", "srn": "s2"}] + resp = await build_table_response(_aiter(rows), CSV_FMT, COLUMNS, _plan()) + parsed = list(csv.reader(io.StringIO((await _body(resp)).decode()))) + assert parsed[0] == ["id", "srn"] + assert parsed[1:] == [["a", "s1"], ["b", "s2"]] + + +@pytest.mark.asyncio +async def test_streaming_gzip_roundtrip() -> None: + rows = [{"id": "a", "srn": "s1"}] + resp = await build_table_response(_aiter(rows), GZ_FMT, COLUMNS, _plan()) + text = gzip.decompress(await _body(resp)).decode() + assert "id,srn" in text and "a,s1" in text + + +@pytest.mark.asyncio +async def test_streaming_empty_is_header_only() -> None: + resp = await build_table_response(_aiter([]), CSV_FMT, COLUMNS, _plan()) + parsed = list(csv.reader(io.StringIO((await _body(resp)).decode()))) + assert parsed == [["id", "srn"]] + + +@pytest.mark.asyncio +async def test_paginated_no_more_rows_null_cursor() -> None: + rows = [{"id": "a", "srn": "s1"}, {"id": "b", "srn": "s2"}] + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, _plan(limit=50)) + parsed = json.loads(await _body(resp)) + assert parsed["next_cursor"] is None + assert parsed["rows"] == [{"id": "a", "srn": "s1"}, {"id": "b", "srn": "s2"}] + + +@pytest.mark.asyncio +async def test_paginated_has_more_emits_cursor() -> None: + # limit=2, but 3 rows available → has_more, cursor derived from 2nd row's srn. 
+ rows = [ + {"id": "a", "srn": "s1", "created_at": "2026-01-03"}, + {"id": "b", "srn": "s2", "created_at": "2026-01-02"}, + {"id": "c", "srn": "s3", "created_at": "2026-01-01"}, + ] + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, _plan(limit=2)) + parsed = json.loads(await _body(resp)) + assert len(parsed["rows"]) == 2 + assert parsed["next_cursor"] is not None + decoded = json.loads(base64.urlsafe_b64decode(parsed["next_cursor"])) + # default RECORDS sort is created_at desc, tiebreaker srn + assert decoded["s"] == "2026-01-02" + assert decoded["id"] == "s2" diff --git a/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py b/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py new file mode 100644 index 0000000..2d2e2f5 --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py @@ -0,0 +1,87 @@ +"""T033 — register_table_routes registers exactly 6 routes with stable op IDs.""" + +import pytest +from fastapi import APIRouter + +from osa.application.api.v1.routes.data.tables import ( + format_key, + path_for, + register_table_routes, +) +from osa.domain.data.model.format import FORMATS + + +def _noop_builder(_fmt): + async def endpoint() -> dict: # pragma: no cover - never called in this test + return {} + + return endpoint + + +def _routes(router: APIRouter): + return [r for r in router.routes if getattr(r, "operation_id", None)] + + +def test_registers_six_routes() -> None: + router = APIRouter() + register_table_routes(router, "/{schema}/records", _noop_builder, _noop_builder, "records") + assert len(_routes(router)) == 6 + + +def test_stable_operation_ids_for_records() -> None: + router = APIRouter() + register_table_routes(router, "/{schema}/records", _noop_builder, _noop_builder, "records") + ids = {r.operation_id for r in _routes(router)} + assert ids == { + "records_get_json", + "records_post_json", + "records_get_csv", + "records_post_csv", + "records_get_csv_gz", + 
"records_post_csv_gz", + } + + +def test_stable_operation_ids_for_feature() -> None: + router = APIRouter() + register_table_routes(router, "/{schema}/{feature}", _noop_builder, _noop_builder, "feature") + ids = {r.operation_id for r in _routes(router)} + assert ids == { + "feature_get_json", + "feature_post_json", + "feature_get_csv", + "feature_post_csv", + "feature_get_csv_gz", + "feature_post_csv_gz", + } + + +def test_duplicate_resource_on_same_router_raises() -> None: + # T114: a second registration with the same resource_name on the same router + # would mint colliding operation IDs — caught loudly, not silently shipped. + router = APIRouter() + register_table_routes(router, "/{schema}/records", _noop_builder, _noop_builder, "records") + with pytest.raises(ValueError, match="Duplicate operation_id"): + register_table_routes(router, "/{schema}/other", _noop_builder, _noop_builder, "records") + + +def test_distinct_resources_on_same_router_coexist() -> None: + # records + feature share one router in the real app — no false positive. 
+ router = APIRouter() + register_table_routes(router, "/{schema}/records", _noop_builder, _noop_builder, "records") + register_table_routes(router, "/{schema}/{feature}", _noop_builder, _noop_builder, "feature") + assert len(_routes(router)) == 12 + + +def test_paths_use_suffix() -> None: + json_fmt = next(f for f in FORMATS if f.suffix == "") + csv_fmt = next(f for f in FORMATS if f.suffix == "csv") + gz_fmt = next(f for f in FORMATS if f.suffix == "csv.gz") + assert path_for("/{schema}/records", json_fmt) == "/{schema}/records" + assert path_for("/{schema}/records", csv_fmt) == "/{schema}/records.csv" + assert path_for("/{schema}/records", gz_fmt) == "/{schema}/records.csv.gz" + + +def test_format_key_mapping() -> None: + keys = {format_key(f) for f in FORMATS} + assert keys == {"json", "csv", "csv_gz"} diff --git a/server/osa/domain/export/query/__init__.py b/server/tests/unit/domain/data/__init__.py similarity index 100% rename from server/osa/domain/export/query/__init__.py rename to server/tests/unit/domain/data/__init__.py diff --git a/server/osa/domain/export/service/__init__.py b/server/tests/unit/domain/data/serializer/__init__.py similarity index 100% rename from server/osa/domain/export/service/__init__.py rename to server/tests/unit/domain/data/serializer/__init__.py diff --git a/server/tests/unit/domain/data/serializer/test_csv.py b/server/tests/unit/domain/data/serializer/test_csv.py new file mode 100644 index 0000000..b3fdafb --- /dev/null +++ b/server/tests/unit/domain/data/serializer/test_csv.py @@ -0,0 +1,64 @@ +"""T031 — CsvSerializer header row, quoting, empty result.""" + +import csv +import io +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + +from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.serializer.csv import CsvSerializer +from osa.domain.semantics.model.value import FieldType + +COLUMNS = [ + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="name", 
type=FieldType.TEXT), +] + + +async def _aiter(rows: list[Mapping[str, Any]]) -> AsyncIterator[Mapping[str, Any]]: + for r in rows: + yield r + + +async def _collect(gen: AsyncIterator[bytes]) -> str: + out = b"" + async for chunk in gen: + out += chunk + return out.decode() + + +@pytest.mark.asyncio +async def test_csv_header_and_rows() -> None: + rows = [{"id": "a", "name": "alpha"}, {"id": "b", "name": "beta"}] + text = await _collect(CsvSerializer().stream(_aiter(rows), COLUMNS)) + parsed = list(csv.reader(io.StringIO(text))) + assert parsed[0] == ["id", "name"] + assert parsed[1] == ["a", "alpha"] + assert parsed[2] == ["b", "beta"] + + +@pytest.mark.asyncio +async def test_csv_empty_result_is_header_only() -> None: + text = await _collect(CsvSerializer().stream(_aiter([]), COLUMNS)) + parsed = list(csv.reader(io.StringIO(text))) + assert parsed == [["id", "name"]] + + +@pytest.mark.asyncio +async def test_csv_quotes_values_with_commas() -> None: + rows = [{"id": "a", "name": "beta, gamma"}] + text = await _collect(CsvSerializer().stream(_aiter(rows), COLUMNS)) + # QUOTE_MINIMAL wraps the comma-bearing value + assert '"beta, gamma"' in text + parsed = list(csv.reader(io.StringIO(text))) + assert parsed[1] == ["a", "beta, gamma"] + + +@pytest.mark.asyncio +async def test_csv_missing_column_is_empty() -> None: + rows = [{"id": "a"}] # no "name" + text = await _collect(CsvSerializer().stream(_aiter(rows), COLUMNS)) + parsed = list(csv.reader(io.StringIO(text))) + assert parsed[1] == ["a", ""] diff --git a/server/tests/unit/domain/data/serializer/test_csv_gzip.py b/server/tests/unit/domain/data/serializer/test_csv_gzip.py new file mode 100644 index 0000000..77f983f --- /dev/null +++ b/server/tests/unit/domain/data/serializer/test_csv_gzip.py @@ -0,0 +1,56 @@ +"""T032 — CsvGzipSerializer gzip-stream round-trip and magic bytes.""" + +import csv +import gzip +import io +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + 
+from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.serializer.csv_gzip import CsvGzipSerializer +from osa.domain.semantics.model.value import FieldType + +COLUMNS = [ + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="name", type=FieldType.TEXT), +] + + +async def _aiter(rows: list[Mapping[str, Any]]) -> AsyncIterator[Mapping[str, Any]]: + for r in rows: + yield r + + +async def _collect(gen: AsyncIterator[bytes]) -> bytes: + out = b"" + async for chunk in gen: + out += chunk + return out + + +@pytest.mark.asyncio +async def test_gzip_magic_bytes() -> None: + body = await _collect(CsvGzipSerializer().stream(_aiter([{"id": "a", "name": "x"}]), COLUMNS)) + # gzip stream header: 0x1f 0x8b 0x08 + assert body[:3] == b"\x1f\x8b\x08" + + +@pytest.mark.asyncio +async def test_gzip_roundtrip_through_gunzip() -> None: + rows = [{"id": "a", "name": "alpha"}, {"id": "b", "name": "beta"}] + body = await _collect(CsvGzipSerializer().stream(_aiter(rows), COLUMNS)) + text = gzip.decompress(body).decode() + parsed = list(csv.reader(io.StringIO(text))) + assert parsed[0] == ["id", "name"] + assert parsed[1] == ["a", "alpha"] + assert parsed[2] == ["b", "beta"] + + +@pytest.mark.asyncio +async def test_gzip_empty_result_roundtrips_to_header() -> None: + body = await _collect(CsvGzipSerializer().stream(_aiter([]), COLUMNS)) + text = gzip.decompress(body).decode() + parsed = list(csv.reader(io.StringIO(text))) + assert parsed == [["id", "name"]] diff --git a/server/tests/unit/domain/data/serializer/test_json.py b/server/tests/unit/domain/data/serializer/test_json.py new file mode 100644 index 0000000..33eed01 --- /dev/null +++ b/server/tests/unit/domain/data/serializer/test_json.py @@ -0,0 +1,53 @@ +"""T030 — JsonSerializer paginated envelope, empty result, cursor embedding.""" + +import json +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + +from osa.domain.data.model.manifest import ColumnSpec +from 
osa.domain.data.serializer.json import JsonSerializer +from osa.domain.semantics.model.value import FieldType + +COLUMNS = [ + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="mw", type=FieldType.NUMBER), +] + + +async def _aiter(rows: list[Mapping[str, Any]]) -> AsyncIterator[Mapping[str, Any]]: + for r in rows: + yield r + + +async def _collect(gen: AsyncIterator[bytes]) -> bytes: + out = b"" + async for chunk in gen: + out += chunk + return out + + +@pytest.mark.asyncio +async def test_json_envelope_with_rows_and_cursor() -> None: + rows = [{"id": "a", "mw": 1.5}, {"id": "b", "mw": 2.0}] + body = await _collect(JsonSerializer().stream(_aiter(rows), COLUMNS, next_cursor="CURSOR")) + parsed = json.loads(body) + assert parsed["next_cursor"] == "CURSOR" + assert parsed["rows"] == [{"id": "a", "mw": 1.5}, {"id": "b", "mw": 2.0}] + + +@pytest.mark.asyncio +async def test_json_empty_result_is_valid() -> None: + body = await _collect(JsonSerializer().stream(_aiter([]), COLUMNS, next_cursor=None)) + parsed = json.loads(body) + assert parsed == {"rows": [], "next_cursor": None} + + +@pytest.mark.asyncio +async def test_json_projects_only_declared_columns() -> None: + rows = [{"id": "a", "mw": 1.5, "secret": "x"}] + body = await _collect(JsonSerializer().stream(_aiter(rows), COLUMNS)) + parsed = json.loads(body) + assert parsed["rows"] == [{"id": "a", "mw": 1.5}] + assert parsed["next_cursor"] is None diff --git a/server/tests/unit/domain/data/test_data_query_service.py b/server/tests/unit/domain/data/test_data_query_service.py new file mode 100644 index 0000000..ba35c05 --- /dev/null +++ b/server/tests/unit/domain/data/test_data_query_service.py @@ -0,0 +1,112 @@ +"""Unit tests for DataQueryService — filter-bounds validation + delegation. + +Uses a fake DataReadStore (no DB), mirroring the discovery service tests. 
+""" + +from collections.abc import AsyncIterator, Mapping +from dataclasses import dataclass +from typing import Any + +import pytest + +from osa.domain.data.model.filter import And, FilterOperator, Predicate +from osa.domain.data.model.query_plan import QueryPlan, TableKind +from osa.domain.data.service.data_query import DataQueryService +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.srn import SchemaId + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +@dataclass +class FakeConfig: + """Only the filter-bound knobs DataQueryService reads (avoids full env setup).""" + + discovery_max_filter_depth: int = 10 + discovery_max_predicates: int = 200 + discovery_max_cross_domain_joins: int = 10 + + +class FakeReadStore: + def __init__(self, rows: list[Mapping[str, Any]]) -> None: + self._rows = rows + self.received_plan: QueryPlan | None = None + + async def stream_rows(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + self.received_plan = plan + for row in self._rows: + yield row + + async def get_record_by_id(self, id, version): # pragma: no cover + return None + + async def get_node_catalog(self): # pragma: no cover + ... 
+ + async def get_schema_manifest(self, schema_id): # pragma: no cover + return None + + async def get_latest_schema_id(self, schema_short_id): # pragma: no cover + return None + + +def _service(rows=None) -> tuple[DataQueryService, FakeReadStore]: + store = FakeReadStore(rows or []) + return DataQueryService(read_store=store, config=FakeConfig()), store + + +def _pred(field: str = "metadata.mw") -> Predicate: + return Predicate(field=field, op=FilterOperator.EQ, value=1) + + +async def _drain(gen: AsyncIterator[Mapping[str, Any]]) -> list[Mapping[str, Any]]: + return [row async for row in gen] + + +@pytest.mark.asyncio +async def test_stream_records_delegates_to_store() -> None: + service, store = _service([{"id": "a"}, {"id": "b"}]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + rows = await _drain(service.stream_records(plan)) + assert rows == [{"id": "a"}, {"id": "b"}] + assert store.received_plan is plan + + +@pytest.mark.asyncio +async def test_stream_records_rejects_feature_plan() -> None: + service, _ = _service() + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="f") + with pytest.raises(ValidationError): + await _drain(service.stream_records(plan)) + + +@pytest.mark.asyncio +async def test_filter_depth_bound_enforced() -> None: + service, _ = _service() + service.config.discovery_max_filter_depth = 1 + deep = And(operands=[_pred(), And(operands=[_pred(), _pred()])]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, filter=deep) + with pytest.raises(ValidationError) as exc: + await _drain(service.stream_records(plan)) + assert exc.value.code == "filter_depth_exceeded" + + +@pytest.mark.asyncio +async def test_filter_predicate_count_bound_enforced() -> None: + service, _ = _service() + service.config.discovery_max_predicates = 1 + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + filter=And(operands=[_pred(), _pred()]), + ) + with pytest.raises(ValidationError) as 
exc: + await _drain(service.stream_records(plan)) + assert exc.value.code == "filter_predicates_exceeded" + + +@pytest.mark.asyncio +async def test_no_filter_streams_cleanly() -> None: + service, _ = _service([{"id": "x"}]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + assert await _drain(service.stream_records(plan)) == [{"id": "x"}] diff --git a/server/tests/unit/domain/data/test_query_plan.py b/server/tests/unit/domain/data/test_query_plan.py new file mode 100644 index 0000000..0690d30 --- /dev/null +++ b/server/tests/unit/domain/data/test_query_plan.py @@ -0,0 +1,61 @@ +"""T029 — QueryPlan validation rules and per-kind default sorts.""" + +import pytest +from pydantic import ValidationError as PydanticValidationError + +from osa.domain.data.model.query_plan import ( + PaginationParams, + QueryPlan, + SortDirection, + TableKind, + decode_cursor, + encode_cursor, +) +from osa.domain.shared.model.srn import SchemaId + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def test_records_plan_defaults_to_created_at_id_desc() -> None: + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + assert [(s.column, s.direction) for s in plan.sort] == [ + ("created_at", SortDirection.DESC), + ("id", SortDirection.DESC), + ] + + +def test_feature_plan_defaults_to_id_asc() -> None: + plan = QueryPlan( + schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="chemical_features" + ) + assert [(s.column, s.direction) for s in plan.sort] == [("id", SortDirection.ASC)] + + +def test_feature_kind_requires_feature_name() -> None: + with pytest.raises(PydanticValidationError): + QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE) + + +def test_records_kind_rejects_feature_name() -> None: + with pytest.raises(PydanticValidationError): + QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, feature_name="x") + + +def test_limit_capped_at_1000() -> None: + with pytest.raises(PydanticValidationError): + PaginationParams(limit=5000) + + +def 
test_limit_default_is_50() -> None: + assert PaginationParams().limit == 50 + + +def test_cursor_roundtrip() -> None: + enc = encode_cursor("2026-01-01T00:00:00", "abc") + decoded = decode_cursor(enc) + assert decoded == {"s": "2026-01-01T00:00:00", "id": "abc"} + + +def test_decode_malformed_cursor_raises() -> None: + with pytest.raises(ValueError): + decode_cursor("not-valid-base64-json!!!") diff --git a/server/tests/unit/domain/deposition/test_hook_reserved_name.py b/server/tests/unit/domain/deposition/test_hook_reserved_name.py new file mode 100644 index 0000000..b20650c --- /dev/null +++ b/server/tests/unit/domain/deposition/test_hook_reserved_name.py @@ -0,0 +1,36 @@ +"""T065 — HookDefinition rejects reserved hook names (records, datasets).""" + +import pytest + +from osa.domain.shared.error import ReservedNameError +from osa.domain.shared.model.hook import ( + ColumnDef, + HookDefinition, + OciConfig, + TableFeatureSpec, +) + + +def _hook(name: str) -> HookDefinition: + return HookDefinition( + name=name, + runtime=OciConfig(image="ghcr.io/example/hook", digest="sha256:abc123"), + feature=TableFeatureSpec( + cardinality="one", + columns=[ColumnDef(name="score", json_type="number", required=True)], + ), + ) + + +@pytest.mark.parametrize("reserved", ["records", "datasets"]) +def test_hook_rejects_reserved_name(reserved: str) -> None: + with pytest.raises(ReservedNameError) as exc: + _hook(reserved) + assert exc.value.code == "reserved_name" + assert exc.value.kind == "hook" + assert exc.value.name == reserved + + +def test_hook_allows_non_reserved_name() -> None: + hook = _hook("chemical_features") + assert hook.name == "chemical_features" diff --git a/server/tests/unit/domain/discovery/__init__.py b/server/tests/unit/domain/discovery/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/tests/unit/domain/discovery/test_discovery_service.py b/server/tests/unit/domain/discovery/test_discovery_service.py deleted file mode 100644 index 
b68dbba..0000000 --- a/server/tests/unit/domain/discovery/test_discovery_service.py +++ /dev/null @@ -1,476 +0,0 @@ -"""Tests for DiscoveryService — FilterExpr validation, operator validation, delegation.""" - -from datetime import UTC, datetime -from unittest.mock import AsyncMock - -import pytest - -from osa.config import Config -from osa.domain.discovery.model.refs import MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - ColumnInfo, - FeatureCatalogEntry, - FeatureRow, - FilterOperator, - Predicate, - RecordSummary, - SortOrder, - decode_cursor, - encode_cursor, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import NotFoundError, ValidationError -from osa.domain.shared.model.srn import RecordSRN, SchemaId - - -SCHEMA_SRN = SchemaId.parse("bio-sample@1.0.0") - - -def _config() -> Config: - # Build a Config with minimal auth — tests don't hit JWT paths - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) # Test-only secret - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - return Config() # type: ignore[call-arg] - - -@pytest.fixture -def mock_read_store() -> AsyncMock: - store = AsyncMock() - store.search_records.return_value = [] - return store - - -@pytest.fixture -def mock_field_reader() -> AsyncMock: - reader = AsyncMock() - reader.get_all_field_types.return_value = { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": FieldType.TERM, - "published_date": FieldType.DATE, - "is_public": FieldType.BOOLEAN, - "homepage": FieldType.URL, - } - reader.get_fields_for_schema.return_value = { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": FieldType.TERM, - "published_date": FieldType.DATE, - "is_public": FieldType.BOOLEAN, - "homepage": FieldType.URL, - } - return reader - - -@pytest.fixture -def service(mock_read_store: AsyncMock, mock_field_reader: 
AsyncMock) -> DiscoveryService: - return DiscoveryService( - read_store=mock_read_store, - field_reader=mock_field_reader, - config=_config(), - ) - - -def _eq(field: str, value: object) -> Predicate: - return Predicate(field=MetadataFieldRef(field=field), op=FilterOperator.EQ, value=value) - - -class TestSearchRecordsValidation: - async def test_rejects_unknown_filter_field(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="Unknown metadata field 'bogus'"): - await service.search_records( - filter_expr=_eq("bogus", "x"), - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_rejects_invalid_operator_for_type(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="not valid"): - await service.search_records( - filter_expr=Predicate( - field=MetadataFieldRef(field="resolution"), - op=FilterOperator.CONTAINS, - value="x", - ), - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_rejects_unknown_sort_field(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="Unknown sort field"): - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="nonexistent", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_accepts_published_at_sort(self, service: DiscoveryService) -> None: - result = await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - assert result.results == [] - - async def test_accepts_metadata_field_sort(self, service: DiscoveryService) -> None: - result = await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, 
- sort="resolution", - order=SortOrder.ASC, - cursor=None, - limit=20, - ) - assert result.results == [] - - async def test_rejects_limit_too_low(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="limit"): - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=0, - ) - - async def test_rejects_limit_too_high(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="limit"): - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=101, - ) - - async def test_raises_not_found_for_unknown_schema(self, mock_read_store: AsyncMock) -> None: - """Pinning an unregistered schema must raise NotFoundError, not silently - fall through to an unscoped query that returns cross-schema records.""" - empty_reader = AsyncMock() - empty_reader.get_fields_for_schema.return_value = {} - svc = DiscoveryService( - read_store=mock_read_store, - field_reader=empty_reader, - config=_config(), - ) - - with pytest.raises(NotFoundError, match="Schema not found"): - await svc.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - mock_read_store.search_records.assert_not_called() - - async def test_search_features_raises_not_found_for_unknown_schema( - self, mock_read_store: AsyncMock - ) -> None: - """search_features must also guard against unknown schema pins.""" - mock_read_store.get_feature_table_schema.return_value = FeatureCatalogEntry( - hook_name="detect_pockets", - columns=[ColumnInfo(name="score", type="number", required=False)], - record_count=0, - ) - empty_reader = AsyncMock() - empty_reader.get_fields_for_schema.return_value = {} - svc = DiscoveryService( - 
read_store=mock_read_store, - field_reader=empty_reader, - config=_config(), - ) - - with pytest.raises(NotFoundError, match="Schema not found"): - await svc.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=SCHEMA_SRN, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - mock_read_store.search_features.assert_not_called() - - async def test_rejects_q_when_no_text_fields(self, mock_read_store: AsyncMock) -> None: - no_text_reader = AsyncMock() - no_text_reader.get_all_field_types.return_value = {"resolution": FieldType.NUMBER} - no_text_reader.get_fields_for_schema.return_value = {"resolution": FieldType.NUMBER} - svc = DiscoveryService( - read_store=mock_read_store, - field_reader=no_text_reader, - config=_config(), - ) - - with pytest.raises(ValidationError, match="Free-text search is unavailable"): - await svc.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q="kinase", - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - -class TestSearchRecordsDelegation: - async def test_delegates_to_read_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - await service.search_records( - filter_expr=_eq("method", "X-ray"), - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - mock_read_store.search_records.assert_called_once() - call_kwargs = mock_read_store.search_records.call_args - assert call_kwargs.kwargs["filter_expr"] is not None - assert call_kwargs.kwargs["q"] is None - assert call_kwargs.kwargs["sort"] == "published_at" - assert call_kwargs.kwargs["limit"] == 21 # N+1 - - async def test_decodes_cursor( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - cursor = encode_cursor("2026-01-01", "urn:osa:localhost:rec:abc@1") - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - 
convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=cursor, - limit=20, - ) - - call_kwargs = mock_read_store.search_records.call_args - decoded = call_kwargs.kwargs["cursor"] - assert decoded["s"] == "2026-01-01" - assert decoded["id"] == "urn:osa:localhost:rec:abc@1" - - async def test_invalid_cursor_raises(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="cursor"): - await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor="not-a-cursor!!!", - limit=20, - ) - - async def test_encodes_next_cursor_from_results( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - ts = datetime(2026, 1, 1, tzinfo=UTC) - mock_read_store.search_records.return_value = [ - RecordSummary(srn=srn, published_at=ts, metadata={"title": f"r{i}"}) for i in range(2) - ] - - result = await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=1, - ) - - assert result.has_more is True - assert result.cursor is not None - assert len(result.results) == 1 - decoded = decode_cursor(result.cursor) - assert decoded["id"] == str(srn) - - async def test_no_cursor_when_no_more_results( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - mock_read_store.search_records.return_value = [] - - result = await service.search_records( - filter_expr=None, - schema_id=SCHEMA_SRN, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - assert result.cursor is None - assert result.has_more is False - - -class TestSchemaRequiredGuards: - """With the JSONB filter fallback removed, any query that resolves against - metadata fields must pin a schema.""" - - async def 
test_metadata_predicate_without_schema_raises( - self, service: DiscoveryService - ) -> None: - with pytest.raises(ValidationError) as exc: - await service.search_records( - filter_expr=_eq("title", "x"), - schema_id=None, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - assert exc.value.code == "schema_required_for_metadata_query" - - async def test_non_default_sort_without_schema_raises(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError) as exc: - await service.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - q=None, - sort="resolution", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - assert exc.value.code == "schema_required_for_metadata_sort" - - async def test_q_without_schema_raises(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError) as exc: - await service.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - q="kinase", - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - assert exc.value.code == "schema_required_for_free_text_search" - - async def test_plain_listing_without_schema_succeeds(self, service: DiscoveryService) -> None: - """No filter, default sort, no q → unscoped listing is allowed.""" - result = await service.search_records( - filter_expr=None, - schema_id=None, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - assert result.results == [] - - -class TestFilterBounds: - async def test_depth_exceeded_raises(self, service: DiscoveryService) -> None: - # Build a nest of AND that exceeds the default depth (10) - leaf = _eq("title", "r") - tree = leaf - for _ in range(11): - tree = And(operands=[tree, leaf]) - - with pytest.raises(ValidationError, match="filter_depth_exceeded|depth"): - await service.search_records( - filter_expr=tree, - schema_id=SCHEMA_SRN, - 
convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - -class TestFeatureCursorEncoding: - async def test_cursor_encodes_row_id( - self, mock_read_store: AsyncMock, mock_field_reader: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - mock_read_store.get_feature_table_schema.return_value = FeatureCatalogEntry( - hook_name="detect_pockets", - columns=[ColumnInfo(name="score", type="number", required=True)], - record_count=0, - ) - mock_read_store.search_features.return_value = [ - FeatureRow(row_id=42, record_srn=srn, data={"score": 7.66}), - FeatureRow(row_id=43, record_srn=srn, data={"score": 6.0}), - ] - - service = DiscoveryService( - read_store=mock_read_store, - field_reader=mock_field_reader, - config=_config(), - ) - result = await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=1, - ) - - assert result.has_more is True - assert result.cursor is not None - assert len(result.rows) == 1 - decoded = decode_cursor(result.cursor) - assert decoded["id"] == 42 - assert decoded["s"] == 7.66 diff --git a/server/tests/unit/domain/discovery/test_get_feature_catalog.py b/server/tests/unit/domain/discovery/test_get_feature_catalog.py deleted file mode 100644 index f6197e3..0000000 --- a/server/tests/unit/domain/discovery/test_get_feature_catalog.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Tests for GetFeatureCatalogHandler and DiscoveryService.get_feature_catalog().""" - -from unittest.mock import AsyncMock - -import pytest - -from osa.domain.discovery.model.value import ColumnInfo, FeatureCatalogEntry -from osa.domain.discovery.query.get_feature_catalog import ( - GetFeatureCatalog, - GetFeatureCatalogHandler, - GetFeatureCatalogResult, -) -from osa.config import Config -from osa.domain.discovery.service.discovery import DiscoveryService - - -def _config() -> 
Config: - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - return Config() # type: ignore[call-arg] - - -@pytest.fixture -def mock_read_store() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def mock_field_reader() -> AsyncMock: - reader = AsyncMock() - reader.get_all_field_types.return_value = {} - reader.get_fields_for_schema.return_value = {} - return reader - - -@pytest.fixture -def service(mock_read_store: AsyncMock, mock_field_reader: AsyncMock) -> DiscoveryService: - return DiscoveryService( - read_store=mock_read_store, - field_reader=mock_field_reader, - config=_config(), - ) - - -class TestGetFeatureCatalogHandler: - async def test_public_auth_gate(self) -> None: - from osa.domain.shared.authorization.gate import Public - - assert isinstance(GetFeatureCatalogHandler.__auth__, Public) - - async def test_delegates_to_service(self) -> None: - mock_service = AsyncMock() - from osa.domain.discovery.model.value import FeatureCatalog - - mock_service.get_feature_catalog.return_value = FeatureCatalog(tables=[]) - - handler = GetFeatureCatalogHandler(discovery_service=mock_service) - result: GetFeatureCatalogResult = await handler.run(GetFeatureCatalog()) - - assert result.tables == [] - mock_service.get_feature_catalog.assert_called_once() - - async def test_returns_correct_structure(self) -> None: - mock_service = AsyncMock() - from osa.domain.discovery.model.value import FeatureCatalog - - mock_service.get_feature_catalog.return_value = FeatureCatalog( - tables=[ - FeatureCatalogEntry( - hook_name="detect_pockets", - columns=[ - ColumnInfo(name="score", type="number", required=True), - ColumnInfo(name="volume", type="number", required=True), - ], - record_count=142, - ) - ] - ) - - handler = GetFeatureCatalogHandler(discovery_service=mock_service) - result: GetFeatureCatalogResult = await handler.run(GetFeatureCatalog()) - - assert len(result.tables) == 1 - assert 
result.tables[0]["hook_name"] == "detect_pockets" - assert result.tables[0]["record_count"] == 142 - assert len(result.tables[0]["columns"]) == 2 - - -class TestDiscoveryServiceGetFeatureCatalog: - async def test_delegates_to_read_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - mock_read_store.get_feature_catalog.return_value = [] - result = await service.get_feature_catalog() - assert result.tables == [] - mock_read_store.get_feature_catalog.assert_called_once() - - async def test_returns_entries_from_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - entries = [ - FeatureCatalogEntry( - hook_name="test_hook", - columns=[ColumnInfo(name="x", type="number", required=True)], - record_count=10, - ) - ] - mock_read_store.get_feature_catalog.return_value = entries - - result = await service.get_feature_catalog() - assert len(result.tables) == 1 - assert result.tables[0].hook_name == "test_hook" diff --git a/server/tests/unit/domain/discovery/test_search_features.py b/server/tests/unit/domain/discovery/test_search_features.py deleted file mode 100644 index b89494b..0000000 --- a/server/tests/unit/domain/discovery/test_search_features.py +++ /dev/null @@ -1,275 +0,0 @@ -"""Tests for SearchFeaturesHandler and DiscoveryService.search_features().""" - -from unittest.mock import AsyncMock - -import pytest - -from osa.config import Config -from osa.domain.discovery.model.refs import FeatureFieldRef -from osa.domain.discovery.model.value import ( - ColumnInfo, - FeatureCatalogEntry, - FeatureRow, - FeatureSearchResult, - FilterOperator, - Predicate, - SortOrder, -) -from osa.domain.discovery.query.search_features import ( - SearchFeatures, - SearchFeaturesHandler, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.shared.error import NotFoundError, ValidationError -from osa.domain.shared.model.srn import RecordSRN - - -def _make_catalog_entry() -> FeatureCatalogEntry: - 
return FeatureCatalogEntry( - hook_name="detect_pockets", - columns=[ - ColumnInfo(name="score", type="number", required=True), - ColumnInfo(name="volume", type="number", required=True), - ColumnInfo(name="label", type="string", required=False), - ColumnInfo(name="is_active", type="boolean", required=False), - ], - record_count=10, - ) - - -def _config() -> Config: - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - return Config() # type: ignore[call-arg] - - -def _predicate(hook: str, column: str, op: FilterOperator, value: object) -> Predicate: - return Predicate(field=FeatureFieldRef(hook=hook, column=column), op=op, value=value) - - -@pytest.fixture -def mock_read_store() -> AsyncMock: - store = AsyncMock() - store.get_feature_table_schema.return_value = _make_catalog_entry() - store.search_features.return_value = [] - return store - - -@pytest.fixture -def mock_field_reader() -> AsyncMock: - reader = AsyncMock() - reader.get_all_field_types.return_value = {} - reader.get_fields_for_schema.return_value = {} - return reader - - -@pytest.fixture -def service(mock_read_store: AsyncMock, mock_field_reader: AsyncMock) -> DiscoveryService: - return DiscoveryService( - read_store=mock_read_store, - field_reader=mock_field_reader, - config=_config(), - ) - - -class TestSearchFeaturesHandler: - async def test_public_auth_gate(self) -> None: - from osa.domain.shared.authorization.gate import Public - - assert isinstance(SearchFeaturesHandler.__auth__, Public) - - async def test_delegates_to_service(self) -> None: - mock_service = AsyncMock() - mock_service.search_features.return_value = FeatureSearchResult( - rows=[], cursor=None, has_more=False - ) - - handler = SearchFeaturesHandler(discovery_service=mock_service) - await handler.run(SearchFeatures(hook_name="detect_pockets")) - mock_service.search_features.assert_called_once() - - async def 
test_invalid_record_srn_raises_validation_error(self) -> None: - mock_service = AsyncMock() - handler = SearchFeaturesHandler(discovery_service=mock_service) - - with pytest.raises(ValidationError, match="not an OSA SRN") as exc_info: - await handler.run(SearchFeatures(hook_name="detect_pockets", record_srn="not-a-srn")) - - assert exc_info.value.field == "record_srn" - mock_service.search_features.assert_not_called() - - async def test_maps_rows_with_record_srn(self) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - mock_service = AsyncMock() - mock_service.search_features.return_value = FeatureSearchResult( - rows=[FeatureRow(row_id=1, record_srn=srn, data={"score": 7.66})], - cursor=None, - has_more=False, - ) - - handler = SearchFeaturesHandler(discovery_service=mock_service) - result = await handler.run(SearchFeatures(hook_name="detect_pockets")) - - assert result.rows[0]["record_srn"] == str(srn) - assert result.rows[0]["score"] == 7.66 - - -class TestDiscoveryServiceSearchFeatures: - async def test_raises_not_found_for_unknown_hook( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - mock_read_store.get_feature_table_schema.return_value = None - - with pytest.raises(NotFoundError): - await service.search_features( - hook_name="unknown_hook", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - async def test_rejects_unknown_column(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="bogus"): - await service.search_features( - hook_name="detect_pockets", - filter_expr=_predicate("detect_pockets", "bogus", FilterOperator.EQ, 1), - schema_id=None, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - async def test_validates_operator_for_number_column(self, service: DiscoveryService) -> None: - with pytest.raises(ValidationError, match="contains"): - await 
service.search_features( - hook_name="detect_pockets", - filter_expr=_predicate("detect_pockets", "score", FilterOperator.CONTAINS, "x"), - schema_id=None, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - async def test_accepts_string_contains_operator(self, service: DiscoveryService) -> None: - await service.search_features( - hook_name="detect_pockets", - filter_expr=_predicate("detect_pockets", "label", FilterOperator.CONTAINS, "test"), - schema_id=None, - record_srn=None, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - async def test_passes_record_srn_filter( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=srn, - sort="id", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - call_kwargs = mock_read_store.search_features.call_args - assert call_kwargs.kwargs["record_srn"] == srn - - async def test_delegates_to_read_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - await service.search_features( - hook_name="detect_pockets", - filter_expr=_predicate("detect_pockets", "score", FilterOperator.GTE, 6.0), - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - mock_read_store.search_features.assert_called_once() - call_kwargs = mock_read_store.search_features.call_args - assert call_kwargs.kwargs["hook_name"] == "detect_pockets" - assert call_kwargs.kwargs["filter_expr"] is not None - - -class TestSearchFeaturesPagination: - async def test_has_more_false_when_exactly_limit_rows( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - mock_read_store.search_features.return_value = [ - FeatureRow(row_id=i, record_srn=srn, data={"score": 
float(i)}) for i in range(3) - ] - - result = await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=3, - ) - - assert result.has_more is False - assert result.cursor is None - assert len(result.rows) == 3 - - async def test_has_more_true_when_more_than_limit_rows( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - mock_read_store.search_features.return_value = [ - FeatureRow(row_id=i, record_srn=srn, data={"score": float(i)}) for i in range(4) - ] - - result = await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=3, - ) - - assert result.has_more is True - assert result.cursor is not None - assert len(result.rows) == 3 - - async def test_passes_limit_plus_one_to_read_store( - self, service: DiscoveryService, mock_read_store: AsyncMock - ) -> None: - await service.search_features( - hook_name="detect_pockets", - filter_expr=None, - schema_id=None, - record_srn=None, - sort="score", - order=SortOrder.DESC, - cursor=None, - limit=50, - ) - - call_kwargs = mock_read_store.search_features.call_args - assert call_kwargs.kwargs["limit"] == 51 diff --git a/server/tests/unit/domain/discovery/test_search_records.py b/server/tests/unit/domain/discovery/test_search_records.py deleted file mode 100644 index b9a7fbd..0000000 --- a/server/tests/unit/domain/discovery/test_search_records.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Tests for SearchRecordsHandler — auth gate, delegation, result mapping.""" - -from datetime import UTC, datetime -from unittest.mock import AsyncMock - -import pytest - -from osa.domain.discovery.model.value import RecordSearchResult, RecordSummary, SortOrder -from osa.domain.discovery.query.search_records import ( - SearchRecords, 
- SearchRecordsHandler, - SearchRecordsResult, -) -from osa.domain.shared.model.srn import RecordSRN - - -@pytest.fixture -def mock_service() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def handler(mock_service: AsyncMock) -> SearchRecordsHandler: - return SearchRecordsHandler(discovery_service=mock_service) - - -class TestSearchRecordsHandler: - async def test_public_auth_gate(self, handler: SearchRecordsHandler) -> None: - from osa.domain.shared.authorization.gate import Public - - assert isinstance(handler.__auth__, Public) - - async def test_delegates_to_service( - self, handler: SearchRecordsHandler, mock_service: AsyncMock - ) -> None: - mock_service.search_records.return_value = RecordSearchResult( - results=[], cursor=None, has_more=False - ) - cmd = SearchRecords() - await handler.run(cmd) - - mock_service.search_records.assert_called_once_with( - filter_expr=None, - schema_id=None, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_maps_results( - self, handler: SearchRecordsHandler, mock_service: AsyncMock - ) -> None: - srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") - ts = datetime(2026, 1, 1, tzinfo=UTC) - mock_service.search_records.return_value = RecordSearchResult( - results=[RecordSummary(srn=srn, published_at=ts, metadata={"title": "Test"})], - cursor="abc123", - has_more=False, - ) - - result: SearchRecordsResult = await handler.run(SearchRecords()) - - assert result.cursor == "abc123" - assert result.has_more is False - assert result.results[0]["srn"] == str(srn) - assert result.results[0]["metadata"] == {"title": "Test"} diff --git a/server/tests/unit/domain/discovery/test_value.py b/server/tests/unit/domain/discovery/test_value.py deleted file mode 100644 index 18518f5..0000000 --- a/server/tests/unit/domain/discovery/test_value.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Tests for discovery domain value objects — cursor helpers and VALID_OPERATORS.""" 
- -import pytest - -from osa.domain.discovery.model.value import ( - VALID_OPERATORS, - FilterOperator, - decode_cursor, - encode_cursor, -) -from osa.domain.semantics.model.value import FieldType - - -class TestCursorRoundTrip: - def test_round_trip_string_values(self) -> None: - cursor = encode_cursor("2026-01-01", "urn:osa:localhost:rec:abc@1") - decoded = decode_cursor(cursor) - assert decoded["s"] == "2026-01-01" - assert decoded["id"] == "urn:osa:localhost:rec:abc@1" - - def test_round_trip_numeric_values(self) -> None: - cursor = encode_cursor(7.66, 123) - decoded = decode_cursor(cursor) - assert decoded["s"] == 7.66 - assert decoded["id"] == 123 - - def test_round_trip_none_sort_value(self) -> None: - cursor = encode_cursor(None, "urn:osa:localhost:rec:abc@1") - decoded = decode_cursor(cursor) - assert decoded["s"] is None - assert decoded["id"] == "urn:osa:localhost:rec:abc@1" - - -class TestDecodeCursorErrors: - def test_malformed_base64(self) -> None: - with pytest.raises(ValueError, match="Malformed cursor"): - decode_cursor("not-valid-base64!!!") - - def test_invalid_json(self) -> None: - import base64 - - bad = base64.urlsafe_b64encode(b"not json").decode() - with pytest.raises(ValueError, match="Malformed cursor"): - decode_cursor(bad) - - def test_missing_s_key(self) -> None: - import base64 - import json - - bad = base64.urlsafe_b64encode(json.dumps({"id": "x"}).encode()).decode() - with pytest.raises(ValueError, match="'s' and 'id'"): - decode_cursor(bad) - - def test_missing_id_key(self) -> None: - import base64 - import json - - bad = base64.urlsafe_b64encode(json.dumps({"s": 1}).encode()).decode() - with pytest.raises(ValueError, match="'s' and 'id'"): - decode_cursor(bad) - - def test_non_dict_payload(self) -> None: - import base64 - import json - - bad = base64.urlsafe_b64encode(json.dumps([1, 2]).encode()).decode() - with pytest.raises(ValueError, match="'s' and 'id'"): - decode_cursor(bad) - - -class TestValidOperators: - def 
test_text_operators_include_basics(self) -> None: - ops = VALID_OPERATORS[FieldType.TEXT] - assert FilterOperator.EQ in ops - assert FilterOperator.CONTAINS in ops - assert FilterOperator.IN in ops - assert FilterOperator.NEQ in ops - - def test_url_operators_include_basics(self) -> None: - ops = VALID_OPERATORS[FieldType.URL] - assert FilterOperator.EQ in ops - assert FilterOperator.CONTAINS in ops - assert FilterOperator.IN in ops - - def test_number_operators_support_ordering(self) -> None: - ops = VALID_OPERATORS[FieldType.NUMBER] - assert FilterOperator.EQ in ops - assert FilterOperator.GT in ops - assert FilterOperator.GTE in ops - assert FilterOperator.LT in ops - assert FilterOperator.LTE in ops - - def test_date_operators_support_ordering(self) -> None: - ops = VALID_OPERATORS[FieldType.DATE] - assert FilterOperator.GTE in ops - assert FilterOperator.LTE in ops - - def test_boolean_operators(self) -> None: - assert FilterOperator.EQ in VALID_OPERATORS[FieldType.BOOLEAN] - assert FilterOperator.IS_NULL in VALID_OPERATORS[FieldType.BOOLEAN] - - def test_term_operators(self) -> None: - assert FilterOperator.EQ in VALID_OPERATORS[FieldType.TERM] - - def test_all_field_types_have_operators(self) -> None: - for ft in FieldType: - assert ft in VALID_OPERATORS, f"Missing operators for {ft}" diff --git a/server/tests/unit/domain/index/__init__.py b/server/tests/unit/domain/index/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/tests/unit/domain/index/test_fanout_listener.py b/server/tests/unit/domain/index/test_fanout_listener.py deleted file mode 100644 index 2aa0988..0000000 --- a/server/tests/unit/domain/index/test_fanout_listener.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Unit tests for FanOutToIndexBackends handler.""" - -from typing import Any -from unittest.mock import AsyncMock -from uuid import uuid4 - -import pytest - -from osa.domain.index.event.index_record import IndexRecord -from 
osa.domain.index.handler.fanout_to_index_backends import FanOutToIndexBackends -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.record.event.record_published import RecordPublished -from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import DepositionSource -from osa.domain.shared.model.srn import ( - ConventionSRN, - DepositionSRN, - Domain, - LocalId, - RecordSRN, - RecordVersion, -) - - -class FakeBackend: - """Fake storage backend for testing.""" - - def __init__(self, name: str): - self._name = name - - @property - def name(self) -> str: - return self._name - - -class FakeOutbox: - """Fake outbox for testing event emission.""" - - def __init__(self): - self.events: list[Any] = [] - self.append = AsyncMock(side_effect=self._append) - - async def _append(self, event: Any) -> None: - self.events.append(event) - - -@pytest.fixture -def sample_record_srn() -> RecordSRN: - """Create a sample record SRN.""" - return RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ) - - -@pytest.fixture -def sample_deposition_srn() -> DepositionSRN: - """Create a sample deposition SRN.""" - return DepositionSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - ) - - -@pytest.fixture -def sample_metadata() -> dict: - """Create sample metadata for testing.""" - return { - "title": "Test Record", - "organism": "human", - "platform": "GPL570", - } - - -class TestFanOutToIndexBackends: - """Tests for FanOutToIndexBackends handler.""" - - @pytest.mark.asyncio - async def test_creates_index_record_per_backend( - self, - sample_record_srn: RecordSRN, - sample_deposition_srn: DepositionSRN, - sample_metadata: dict, - ): - """Handler should create one IndexRecord event per registered backend.""" - # Arrange - backend1 = FakeBackend("vector") - backend2 = FakeBackend("keyword") - registry = IndexRegistry({"vector": backend1, "keyword": backend2}) - outbox = 
FakeOutbox() - - handler = FanOutToIndexBackends(indexes=registry, outbox=outbox) - - event = RecordPublished( - id=EventId(uuid4()), - record_srn=sample_record_srn, - source=DepositionSource(id=str(sample_deposition_srn)), - convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0"), - schema_id=__import__( - "osa.domain.shared.model.srn", fromlist=["SchemaId"] - ).SchemaId.parse("test@1.0.0"), - metadata=sample_metadata, - ) - - # Act - await handler.handle(event) - - # Assert - assert len(outbox.events) == 2 - backend_names = {e.backend_name for e in outbox.events} - assert backend_names == {"vector", "keyword"} - - for index_event in outbox.events: - assert isinstance(index_event, IndexRecord) - assert index_event.record_srn == sample_record_srn - assert index_event.metadata == sample_metadata - - @pytest.mark.asyncio - async def test_creates_unique_event_ids( - self, - sample_record_srn: RecordSRN, - sample_deposition_srn: DepositionSRN, - sample_metadata: dict, - ): - """Each IndexRecord should have a unique event ID.""" - # Arrange - registry = IndexRegistry( - { - "backend1": FakeBackend("backend1"), - "backend2": FakeBackend("backend2"), - } - ) - outbox = FakeOutbox() - - handler = FanOutToIndexBackends(indexes=registry, outbox=outbox) - - event = RecordPublished( - id=EventId(uuid4()), - record_srn=sample_record_srn, - source=DepositionSource(id=str(sample_deposition_srn)), - convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0"), - schema_id=__import__( - "osa.domain.shared.model.srn", fromlist=["SchemaId"] - ).SchemaId.parse("test@1.0.0"), - metadata=sample_metadata, - ) - - # Act - await handler.handle(event) - - # Assert - event_ids = [e.id for e in outbox.events] - assert len(event_ids) == len(set(event_ids)), "Event IDs should be unique" - - @pytest.mark.asyncio - async def test_empty_registry_creates_no_events( - self, - sample_record_srn: RecordSRN, - sample_deposition_srn: DepositionSRN, - sample_metadata: dict, - 
): - """No IndexRecord events should be created if registry is empty.""" - # Arrange - registry = IndexRegistry({}) - outbox = FakeOutbox() - - handler = FanOutToIndexBackends(indexes=registry, outbox=outbox) - - event = RecordPublished( - id=EventId(uuid4()), - record_srn=sample_record_srn, - source=DepositionSource(id=str(sample_deposition_srn)), - convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0"), - schema_id=__import__( - "osa.domain.shared.model.srn", fromlist=["SchemaId"] - ).SchemaId.parse("test@1.0.0"), - metadata=sample_metadata, - ) - - # Act - await handler.handle(event) - - # Assert - assert len(outbox.events) == 0 diff --git a/server/tests/unit/domain/index/test_index_service.py b/server/tests/unit/domain/index/test_index_service.py deleted file mode 100644 index 9769f97..0000000 --- a/server/tests/unit/domain/index/test_index_service.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Unit tests for IndexService.""" - -from unittest.mock import AsyncMock - -import pytest - -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.index.service.index import IndexService - - -class FakeBackend: - """Fake storage backend for testing.""" - - def __init__(self, name: str, count: int = 0, healthy: bool = True): - self._name = name - self._count = count - self._healthy = healthy - self.count = AsyncMock(return_value=count) - self.health = AsyncMock(return_value=healthy) - - @property - def name(self) -> str: - return self._name - - -class TestIndexService: - """Tests for IndexService.""" - - @pytest.mark.asyncio - async def test_get_count_returns_backend_count(self): - """get_count should return the backend's document count.""" - # Arrange - backend = FakeBackend("vector", count=42) - registry = IndexRegistry({"vector": backend}) - service = IndexService(indexes=registry) - - # Act - result = await service.get_count("vector") - - # Assert - assert result == 42 - backend.count.assert_called_once() - - @pytest.mark.asyncio - async def 
test_get_count_returns_none_for_unknown_backend(self): - """get_count should return None for unknown backend.""" - # Arrange - registry = IndexRegistry({}) - service = IndexService(indexes=registry) - - # Act - result = await service.get_count("unknown") - - # Assert - assert result is None - - @pytest.mark.asyncio - async def test_check_health_returns_backend_health(self): - """check_health should return the backend's health status.""" - # Arrange - backend = FakeBackend("vector", healthy=True) - registry = IndexRegistry({"vector": backend}) - service = IndexService(indexes=registry) - - # Act - result = await service.check_health("vector") - - # Assert - assert result is True - backend.health.assert_called_once() - - @pytest.mark.asyncio - async def test_check_health_returns_none_for_unknown_backend(self): - """check_health should return None for unknown backend.""" - # Arrange - registry = IndexRegistry({}) - service = IndexService(indexes=registry) - - # Act - result = await service.check_health("unknown") - - # Assert - assert result is None - - @pytest.mark.asyncio - async def test_check_health_returns_false_for_unhealthy_backend(self): - """check_health should return False for unhealthy backend.""" - # Arrange - backend = FakeBackend("vector", healthy=False) - registry = IndexRegistry({"vector": backend}) - service = IndexService(indexes=registry) - - # Act - result = await service.check_health("vector") - - # Assert - assert result is False diff --git a/server/tests/unit/domain/semantics/test_schema_reserved_name.py b/server/tests/unit/domain/semantics/test_schema_reserved_name.py new file mode 100644 index 0000000..335c47f --- /dev/null +++ b/server/tests/unit/domain/semantics/test_schema_reserved_name.py @@ -0,0 +1,40 @@ +"""T064 — Schema.__init__ rejects reserved schema IDs (records, datasets).""" + +from datetime import UTC, datetime + +import pytest + +from osa.domain.semantics.model.schema import Schema +from osa.domain.semantics.model.value import 
Cardinality, FieldDefinition, FieldType +from osa.domain.shared.error import ReservedNameError +from osa.domain.shared.model.srn import SchemaId + + +def _text_field(name: str = "title") -> FieldDefinition: + return FieldDefinition( + name=name, type=FieldType.TEXT, required=True, cardinality=Cardinality.EXACTLY_ONE + ) + + +@pytest.mark.parametrize("reserved", ["records", "datasets"]) +def test_schema_rejects_reserved_id(reserved: str) -> None: + with pytest.raises(ReservedNameError) as exc: + Schema( + id=SchemaId.parse(f"{reserved}@1.0.0"), + title="x", + fields=[_text_field()], + created_at=datetime.now(UTC), + ) + assert exc.value.code == "reserved_name" + assert exc.value.kind == "schema" + assert exc.value.name == reserved + + +def test_schema_allows_non_reserved_id() -> None: + schema = Schema( + id=SchemaId.parse("compound@1.0.0"), + title="Compound", + fields=[_text_field()], + created_at=datetime.now(UTC), + ) + assert schema.id.id.root == "compound" diff --git a/server/osa/domain/search/adapter/__init__.py b/server/tests/unit/domain/shared/__init__.py similarity index 100% rename from server/osa/domain/search/adapter/__init__.py rename to server/tests/unit/domain/shared/__init__.py diff --git a/server/tests/unit/domain/shared/test_reserved.py b/server/tests/unit/domain/shared/test_reserved.py new file mode 100644 index 0000000..7d7ef21 --- /dev/null +++ b/server/tests/unit/domain/shared/test_reserved.py @@ -0,0 +1,33 @@ +"""T028 — RESERVED_NAMES constant and ReservedNameError.""" + +from osa.domain.shared.error import DomainError, ReservedNameError +from osa.domain.shared.model.reserved import RESERVED_NAMES + + +def test_reserved_names_contains_records_and_datasets() -> None: + assert RESERVED_NAMES == frozenset({"records", "datasets"}) + + +def test_reserved_names_is_frozen() -> None: + assert isinstance(RESERVED_NAMES, frozenset) + + +def test_reserved_name_error_is_domain_error() -> None: + err = ReservedNameError("records", "schema") + assert 
isinstance(err, DomainError) + + +def test_reserved_name_error_code_and_message() -> None: + err = ReservedNameError("records", "schema") + assert err.code == "reserved_name" + assert err.name == "records" + assert err.kind == "schema" + assert "records" in err.message + assert "/data/" in err.message + + +def test_reserved_name_error_lists_reserved_names() -> None: + err = ReservedNameError("datasets", "hook") + # sorted reserved set is rendered in the message + assert "datasets" in err.message + assert "records" in err.message diff --git a/server/tests/unit/infrastructure/persistence/test_field_definition_reader.py b/server/tests/unit/infrastructure/persistence/test_field_definition_reader.py deleted file mode 100644 index 25974c2..0000000 --- a/server/tests/unit/infrastructure/persistence/test_field_definition_reader.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Tests for PostgresFieldDefinitionReader — field type map construction.""" - -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import ValidationError -from osa.infrastructure.persistence.adapter.discovery import PostgresFieldDefinitionReader - - -def _make_schema_row(srn: str, fields: list[dict]) -> dict: - return {"srn": srn, "fields": fields} - - -@pytest.fixture -def mock_session() -> AsyncMock: - return AsyncMock() - - -def _setup_session_result(session: AsyncMock, rows: list[dict]) -> None: - """Configure mock session to return rows from a SELECT on schemas_table.""" - result_mock = MagicMock() - result_mock.mappings.return_value.all.return_value = rows - session.execute.return_value = result_mock - - -class TestGetAllFieldTypes: - async def test_builds_field_map_from_multiple_schemas(self, mock_session: AsyncMock) -> None: - rows = [ - _make_schema_row( - "urn:osa:localhost:schema:a@1", - [ - { - "name": "title", - "type": "text", - "required": True, - "cardinality": "exactly_one", - }, - { - "name": 
"resolution", - "type": "number", - "required": False, - "cardinality": "exactly_one", - }, - ], - ), - _make_schema_row( - "urn:osa:localhost:schema:b@1", - [ - { - "name": "method", - "type": "term", - "required": True, - "cardinality": "exactly_one", - "constraints": { - "type": "term", - "ontology_srn": "urn:osa:localhost:onto:methods@1", - }, - }, - ], - ), - ] - _setup_session_result(mock_session, rows) - - reader = PostgresFieldDefinitionReader(session=mock_session) - result = await reader.get_all_field_types() - - assert result == { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - "method": FieldType.TERM, - } - - async def test_raises_on_conflicting_types(self, mock_session: AsyncMock) -> None: - rows = [ - _make_schema_row( - "urn:osa:localhost:schema:a@1", - [{"name": "value", "type": "text", "required": True, "cardinality": "exactly_one"}], - ), - _make_schema_row( - "urn:osa:localhost:schema:b@1", - [ - { - "name": "value", - "type": "number", - "required": True, - "cardinality": "exactly_one", - } - ], - ), - ] - _setup_session_result(mock_session, rows) - - reader = PostgresFieldDefinitionReader(session=mock_session) - with pytest.raises(ValidationError, match="value"): - await reader.get_all_field_types() - - async def test_returns_empty_map_when_no_schemas(self, mock_session: AsyncMock) -> None: - _setup_session_result(mock_session, []) - - reader = PostgresFieldDefinitionReader(session=mock_session) - result = await reader.get_all_field_types() - - assert result == {} - - async def test_same_field_same_type_across_schemas_ok(self, mock_session: AsyncMock) -> None: - rows = [ - _make_schema_row( - "urn:osa:localhost:schema:a@1", - [{"name": "title", "type": "text", "required": True, "cardinality": "exactly_one"}], - ), - _make_schema_row( - "urn:osa:localhost:schema:b@1", - [ - { - "name": "title", - "type": "text", - "required": False, - "cardinality": "exactly_one", - } - ], - ), - ] - _setup_session_result(mock_session, rows) - - 
reader = PostgresFieldDefinitionReader(session=mock_session) - result = await reader.get_all_field_types() - - assert result == {"title": FieldType.TEXT} diff --git a/server/tests/unit/test_field_ref_resolution.py b/server/tests/unit/test_field_ref_resolution.py deleted file mode 100644 index d647c3a..0000000 --- a/server/tests/unit/test_field_ref_resolution.py +++ /dev/null @@ -1,42 +0,0 @@ -"""US3 tests: typed field-reference parsing and tree validation.""" - -import pytest - -from osa.domain.discovery.model.refs import ( - FeatureFieldRef, - MetadataFieldRef, - parse_field_ref, -) - - -class TestParseFieldRef: - def test_parses_metadata_ref(self): - ref = parse_field_ref("metadata.species") - assert isinstance(ref, MetadataFieldRef) - assert ref.field == "species" - - def test_parses_feature_ref(self): - ref = parse_field_ref("features.cell_classifier.confidence") - assert isinstance(ref, FeatureFieldRef) - assert ref.hook == "cell_classifier" - assert ref.column == "confidence" - - def test_rejects_unknown_prefix(self): - with pytest.raises(ValueError, match="prefix"): - parse_field_ref("other.foo") - - def test_rejects_malformed_metadata(self): - with pytest.raises(ValueError): - parse_field_ref("metadata.a.b") - - def test_rejects_malformed_feature(self): - with pytest.raises(ValueError): - parse_field_ref("features.hook") - - def test_rejects_invalid_identifier(self): - with pytest.raises(ValueError): - parse_field_ref("metadata.Has-Dash") - - def test_dotted_round_trip(self): - assert parse_field_ref("metadata.species").dotted() == "metadata.species" - assert parse_field_ref("features.hook.col").dotted() == "features.hook.col" diff --git a/server/tests/unit/test_filter_expr_and_compile.py b/server/tests/unit/test_filter_expr_and_compile.py deleted file mode 100644 index 41a2ba9..0000000 --- a/server/tests/unit/test_filter_expr_and_compile.py +++ /dev/null @@ -1,197 +0,0 @@ -"""US1 tests: FilterExpr AND-tree validation via DiscoveryService bounds.""" - -from 
unittest.mock import AsyncMock - -import pytest - -from osa.config import Config -from osa.domain.discovery.model.refs import FeatureFieldRef, MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - FilterOperator, - Predicate, - SortOrder, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.srn import SchemaId - - -SCHEMA = SchemaId.parse("bio-sample@1.0.0") - - -def _config(overrides: dict | None = None) -> Config: - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - cfg = Config() # type: ignore[call-arg] - if overrides: - for k, v in overrides.items(): - setattr(cfg, k, v) - return cfg - - -def _svc( - *, - field_map: dict[str, FieldType] | None = None, - max_depth: int | None = None, - max_preds: int | None = None, - max_joins: int | None = None, -) -> DiscoveryService: - read_store = AsyncMock() - read_store.search_records.return_value = [] - read_store.get_feature_catalog.return_value = [] - - reader = AsyncMock() - fm = field_map or { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - } - reader.get_all_field_types.return_value = fm - reader.get_fields_for_schema.return_value = fm - - overrides = {} - if max_depth is not None: - overrides["discovery_max_filter_depth"] = max_depth - if max_preds is not None: - overrides["discovery_max_predicates"] = max_preds - if max_joins is not None: - overrides["discovery_max_cross_domain_joins"] = max_joins - - return DiscoveryService(read_store=read_store, field_reader=reader, config=_config(overrides)) - - -def _pred(field: str, op: FilterOperator, value: object) -> Predicate: - return Predicate(field=MetadataFieldRef(field=field), op=op, value=value) - - -class TestAndOnlyTrees: - async def test_accepts_and_of_predicates(self) -> None: - svc = 
_svc() - tree = And( - operands=[ - _pred("title", FilterOperator.EQ, "x"), - _pred("resolution", FilterOperator.GTE, 3.0), - ] - ) - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - -class TestBoundsEnforced: - async def test_depth_exceeded(self) -> None: - svc = _svc(max_depth=3) - leaf = _pred("title", FilterOperator.EQ, "x") - tree = leaf - for _ in range(4): - tree = And(operands=[tree, leaf]) - - with pytest.raises(ValidationError, match="depth"): - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_predicates_exceeded(self) -> None: - svc = _svc(max_preds=2) - tree = And( - operands=[ - _pred("title", FilterOperator.EQ, "x"), - _pred("title", FilterOperator.EQ, "y"), - _pred("resolution", FilterOperator.GTE, 3.0), - ] - ) - with pytest.raises(ValidationError, match="predicate leaves"): - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - async def test_joins_exceeded(self) -> None: - svc = _svc(max_joins=1) - # Simulate catalog advertising multiple hooks with a column named score - svc.read_store.get_feature_catalog.return_value = [ # type: ignore[attr-defined] - type( - "E", - (), - { - "hook_name": "hook_a", - "columns": [ - type("C", (), {"name": "score", "type": "number", "required": True}) - ], - }, - ), - type( - "E", - (), - { - "hook_name": "hook_b", - "columns": [ - type("C", (), {"name": "score", "type": "number", "required": True}) - ], - }, - ), - ] - tree = And( - operands=[ - Predicate( - field=FeatureFieldRef(hook="hook_a", column="score"), - op=FilterOperator.GT, - value=0.0, - ), - Predicate( - field=FeatureFieldRef(hook="hook_b", 
column="score"), - op=FilterOperator.GT, - value=0.0, - ), - ] - ) - with pytest.raises(ValidationError, match="distinct feature hooks"): - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) - - -class TestUnknownField: - async def test_unknown_metadata_field_rejected(self) -> None: - svc = _svc() - with pytest.raises(ValidationError, match="Unknown metadata field"): - await svc.search_records( - filter_expr=_pred("bogus", FilterOperator.EQ, "x"), - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=20, - ) diff --git a/server/tests/unit/test_filter_expr_or_not.py b/server/tests/unit/test_filter_expr_or_not.py deleted file mode 100644 index 191bc74..0000000 --- a/server/tests/unit/test_filter_expr_or_not.py +++ /dev/null @@ -1,131 +0,0 @@ -"""US2 tests: FilterExpr accepts OR/NOT trees and validation walks them correctly.""" - -from unittest.mock import AsyncMock - -import pytest - -from osa.config import Config -from osa.domain.discovery.model.refs import MetadataFieldRef -from osa.domain.discovery.model.value import ( - And, - FilterOperator, - Not, - Or, - Predicate, - SortOrder, -) -from osa.domain.discovery.service.discovery import DiscoveryService -from osa.domain.semantics.model.value import FieldType -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.srn import SchemaId - - -SCHEMA = SchemaId.parse("bio-sample@1.0.0") - - -def _config() -> Config: - import os - - os.environ.setdefault("OSA_AUTH__JWT__SECRET", "a" * 64) - os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") - return Config() # type: ignore[call-arg] - - -def _svc() -> DiscoveryService: - read_store = AsyncMock() - read_store.search_records.return_value = [] - reader = AsyncMock() - fm = { - "title": FieldType.TEXT, - "resolution": FieldType.NUMBER, - } - 
reader.get_all_field_types.return_value = fm - reader.get_fields_for_schema.return_value = fm - return DiscoveryService(read_store=read_store, field_reader=reader, config=_config()) - - -def _pred(field: str, op: FilterOperator, value: object) -> Predicate: - return Predicate(field=MetadataFieldRef(field=field), op=op, value=value) - - -class TestOrNot: - async def test_or_tree_accepted(self): - svc = _svc() - tree = Or( - operands=[ - _pred("title", FilterOperator.EQ, "A"), - _pred("title", FilterOperator.EQ, "B"), - ] - ) - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - ) - - async def test_not_tree_accepted(self): - svc = _svc() - tree = Not(operand=_pred("title", FilterOperator.EQ, "X")) - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - ) - - async def test_nested_mixed_tree(self): - svc = _svc() - tree = And( - operands=[ - _pred("title", FilterOperator.EQ, "X"), - Or( - operands=[ - _pred("resolution", FilterOperator.GTE, 3.0), - _pred("resolution", FilterOperator.LT, 1.0), - ] - ), - Not(operand=_pred("title", FilterOperator.EQ, "Bad")), - ] - ) - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - ) - - -class TestCompoundDisabledFlag: - async def test_or_rejected_when_compound_disabled(self): - svc = _svc() - tree = Or( - operands=[ - _pred("title", FilterOperator.EQ, "A"), - _pred("title", FilterOperator.EQ, "B"), - ] - ) - with pytest.raises(ValidationError, match="compound_disabled|Compound"): - await svc.search_records( - filter_expr=tree, - schema_id=SCHEMA, - convention_srn=None, - q=None, - sort="published_at", - order=SortOrder.DESC, - cursor=None, - limit=10, - 
allow_compound=False, - ) diff --git a/server/tests/unit/test_no_index_deps.py b/server/tests/unit/test_no_index_deps.py new file mode 100644 index 0000000..37164ca --- /dev/null +++ b/server/tests/unit/test_no_index_deps.py @@ -0,0 +1,33 @@ +"""US6 footprint guard: the vector-index stack is gone from the image. + +`chromadb` and `sentence-transformers` were the heaviest runtime deps; the +unified `/data/` surface replaced the vector/keyword index domain, so neither +should be importable. If a future change reintroduces them as a transitive +dependency, these tests fail loudly rather than silently re-bloating the image. +""" + +import importlib + +import pytest + + +@pytest.mark.parametrize("module", ["chromadb", "sentence_transformers"]) +def test_index_dependency_not_importable(module: str): + with pytest.raises(ModuleNotFoundError): + importlib.import_module(module) + + +@pytest.mark.parametrize( + "module", + [ + "osa.domain.index", + "osa.domain.search", + "osa.domain.export", + "osa.domain.discovery", + "osa.infrastructure.index", + "osa.sdk.index", + ], +) +def test_deleted_domain_not_importable(module: str): + with pytest.raises(ModuleNotFoundError): + importlib.import_module(module) diff --git a/server/uv.lock b/server/uv.lock index 5953a12..2b9e5c1 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -395,6 +395,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/db/d291e30fdf7ea617a335531e72294e0c723356d7fdde8fba00610a76bda9/coverage-7.13.2-py3-none-any.whl", hash = "sha256:40ce1ea1e25125556d8e76bd0b61500839a07944cc287ac21d5626f3e620cad5", size = 210943, upload-time = "2026-01-25T13:00:02.388Z" }, ] +[[package]] +name = "deprecated" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wrapt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] 
+sdist = { url = "https://files.pythonhosted.org/packages/49/85/12f0a49a7c4ffb70572b6c2ef13c90c88fd190debda93b23f026b25f9634/deprecated-1.3.1.tar.gz", hash = "sha256:b1b50e0ff0c1fddaa5708a2c6b0a6588bb09b892825ab2b214ac9ea9d92a5223", size = 2932523, upload-time = "2025-10-30T08:19:02.757Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/d0/205d54408c08b13550c733c4b85429e7ead111c7f0014309637425520a9a/deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f", size = 11298, upload-time = "2025-10-30T08:19:00.758Z" }, +] + [[package]] name = "dishka" version = "1.9.1" @@ -677,6 +689,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/b3/a8917d253763095fb8dcaaefc6a135ed31abbd13f681e78752e226e252fe/kubernetes_asyncio-35.0.1-py3-none-any.whl", hash = "sha256:244ef45943e89c5c5104276a646bfcbf1a9dc3d060876c2094aa601e932f1c03", size = 2868606, upload-time = "2026-02-25T20:40:41.191Z" }, ] +[[package]] +name = "limits" +version = "5.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "deprecated", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "packaging", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/69/826a5d1f45426c68d8f6539f8d275c0e4fcaa57f0c017ec3100986558a41/limits-5.8.0.tar.gz", hash = 
"sha256:c9e0d74aed837e8f6f50d1fcebcf5fd8130957287206bc3799adaee5092655da", size = 226104, upload-time = "2026-02-05T07:17:35.859Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b9/98/cb5ca20618d205a09d5bec7591fbc4130369c7e6308d9a676a28ff3ab22c/limits-5.8.0-py3-none-any.whl", hash = "sha256:ae1b008a43eb43073c3c579398bd4eb4c795de60952532dc24720ab45e1ac6b8", size = 60954, upload-time = "2026-02-05T07:17:34.425Z" }, +] + [[package]] name = "logfire" version = "4.21.0" @@ -1006,7 +1032,7 @@ wheels = [ [[package]] name = "osa" -version = "0.0.1" +version = "0.0.3" source = { editable = "." } dependencies = [ { name = "aiodocker", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1025,6 +1051,7 @@ dependencies = [ { name = "pydantic-settings", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyjwt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "python-multipart", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "slowapi", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sqlalchemy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, 
{ name = "uvicorn", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] @@ -1069,6 +1096,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "pyjwt", specifier = ">=2.11.0" }, { name = "python-multipart", specifier = ">=0.0.22" }, + { name = "slowapi", specifier = ">=0.1.9" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, { name = "types-aiobotocore-s3", marker = "extra == 'k8s'", specifier = ">=3.3.0" }, { name = "uvicorn", specifier = ">=0.38.0" }, @@ -1477,6 +1505,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "slowapi" +version = "0.1.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "limits", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a0/99/adfc7f94ca024736f061257d39118e1542bade7a52e86415a4c4ae92d8ff/slowapi-0.1.9.tar.gz", hash = "sha256:639192d0f1ca01b1c6d95bf6c71d794c3a9ee189855337b4821f7f457dddad77", size = 14028, upload-time = "2024-02-05T12:11:52.13Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/bb/f71c4b7d7e7eb3fc1e8c0458a8979b912f40b58002b9fbf37729b8cb464b/slowapi-0.1.9-py3-none-any.whl", hash = "sha256:cfad116cfb84ad9d763ee155c1e5c5cbf00b0d47399a769b227865f5df576e36", size = 14670, upload-time = "2024-02-05T12:11:50.898Z" }, +] + [[package]] name = "sqlalchemy" version = "2.0.46" From f290b6a316c03b7862ad1f570a21b4748f62382b 
Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Tue, 9 Jun 2026 15:37:10 +0100 Subject: [PATCH 02/23] test: add DB-free contract tests for the /data/ surface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit US6 deleted the only file in tests/contract/ (an index-coupled worker test), leaving the directory empty. CI's server-contract job runs `pytest tests/contract`, which errors (exit 4) when the path doesn't exist on a fresh checkout. Add a self-sufficient, DB-free contract test asserting the legacy read routes (/discovery, /records/{srn}, /search) now 404 and the new /data/ operation IDs are registered and unique — the SDK-codegen contract — without touching Postgres (the contract job has no DB service). --- .../contract/test_data_surface_contract.py | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 server/tests/contract/test_data_surface_contract.py diff --git a/server/tests/contract/test_data_surface_contract.py b/server/tests/contract/test_data_surface_contract.py new file mode 100644 index 0000000..7b4982c --- /dev/null +++ b/server/tests/contract/test_data_surface_contract.py @@ -0,0 +1,75 @@ +"""DB-free contract tests for the unified ``/data/`` read surface (US1–US6). + +Runs in CI's ``server-contract`` job, which has **no** Postgres. Everything here +is asserted against routing + the OpenAPI surface without touching the database: + +* the legacy read routes (``/discovery``, ``/records/{srn}``, ``/search``) are + gone — Starlette resolves an unregistered path to 404 before any DI/DB runs; +* the new ``/data/`` operation IDs are registered with stable, unique names + (the SDK-codegen contract). + +DB-backed behaviour of the surface lives in ``tests/integration/test_data_*``. 
+""" + +import os + +import pytest +from httpx import ASGITransport, AsyncClient + +# create_app() reads Config() at call time; provide the env it needs so this +# test is self-sufficient regardless of how the contract job is configured. +os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") +os.environ.setdefault("OSA_AUTH__JWT__SECRET", "test-secret-for-contract-tests-minimum-32-chars") + + +def _app(): + from osa.application.api.rest.app import create_app + + return create_app() + + +@pytest.fixture +def client() -> AsyncClient: + # No lifespan → no worker pool, no DB connection; pure ASGI routing. + return AsyncClient(transport=ASGITransport(app=_app()), base_url="http://test") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "path", + [ + "/api/v1/discovery/records", + "/api/v1/records/urn:osa:localhost:rec:anything@1", + "/api/v1/search/records?q=anything", + ], +) +async def test_legacy_read_routes_are_gone(client: AsyncClient, path: str): + async with client: + resp = await client.get(path) + assert resp.status_code == 404 + + +def test_data_surface_operation_ids_registered(): + # The factory-minted table op IDs must exist and be unique (SDK codegen). + app = _app() + op_ids = [oid for route in app.routes if (oid := getattr(route, "operation_id", None))] + expected = { + "data_get_node_catalog", + "data_get_record_by_id", + "records_get_json", + "records_post_json", + "records_get_csv", + "records_post_csv", + "records_get_csv_gz", + "records_post_csv_gz", + "feature_get_json", + "feature_post_json", + "feature_get_csv", + "feature_post_csv", + "feature_get_csv_gz", + "feature_post_csv_gz", + } + assert expected <= set(op_ids) + # No duplicates among factory-minted IDs. 
+ factory_ids = [o for o in op_ids if o in expected] + assert len(factory_ids) == len(set(factory_ids)) From d9f50a415144605e3180b0c0fe0d455bc7fdb0b3 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Tue, 9 Jun 2026 18:59:24 +0100 Subject: [PATCH 03/23] fix: map malformed /data/ cursor to 400 and add DataConfig sub-config Greptile review follow-ups on the unified /data/ surface (PR #139): - Malformed pagination cursors (corrupt base64, missing s/id keys, or a non-integer feature id) raised bare ValueError from decode_cursor / int(decoded["id"]) inside the stream's pre-flight pull. ValueError is not an OSAError, so the global handler returned 500. Wrap decode+coerce in both _records_sort and _features_sort and re-raise as ValidationError(field="cursor", code="invalid_cursor") -> 400. - Promote the three flat discovery_max_* config fields into a DataConfig sub-config (config.data.max_filter_depth / max_predicates / max_feature_joins), shedding the dead 'discovery' domain name now that the domain is removed. Overridable via OSA_DATA__MAX_* (nested env). - Break a latent circular import: persistence/__init__ no longer eagerly imports the DI provider, so importing a persistence leaf module (e.g. feature_table) doesn't drag in the whole DI graph. application/di.py imports PersistenceProvider from its module directly. Feature-table schema-scoping concern tracked separately in #140. 
--- server/osa/application/di.py | 2 +- server/osa/config.py | 19 +++- server/osa/domain/data/service/data_query.py | 13 ++- .../data/postgres_data_read_store.py | 33 +++++-- .../infrastructure/persistence/__init__.py | 9 +- .../domain/data/test_data_query_service.py | 21 ++-- .../unit/infrastructure/data/__init__.py | 0 .../data/test_cursor_validation.py | 95 +++++++++++++++++++ 8 files changed, 164 insertions(+), 28 deletions(-) create mode 100644 server/tests/unit/infrastructure/data/__init__.py create mode 100644 server/tests/unit/infrastructure/data/test_cursor_validation.py diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 6c22eab..a85074a 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -16,7 +16,7 @@ from osa.infrastructure.event.di import EventProvider from osa.infrastructure.http.di import HttpProvider from osa.infrastructure.k8s.di import RunnerProvider -from osa.infrastructure.persistence import PersistenceProvider +from osa.infrastructure.persistence.di import PersistenceProvider from osa.infrastructure.ingest.di import IngestProvider from osa.util.di.scope import Scope from osa.util.paths import OSAPaths diff --git a/server/osa/config.py b/server/osa/config.py index c267c0a..fc44cba 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -260,6 +260,19 @@ def _normalize_pg_url(url: str) -> str: return url +class DataConfig(BaseModel): + """Bounds for the unified ``/data/`` read surface (nested in Config). + + Caps the cost of a filter tree before it is compiled to SQL: maximum tree + depth, total predicate count, and distinct feature-hook joins. Overridable + via ``OSA_DATA__MAX_FILTER_DEPTH`` etc. (FR-012, features 076/137). + """ + + max_filter_depth: int = 10 + max_predicates: int = 200 + max_feature_joins: int = 10 # distinct features. 
references in one filter + + class Config(BaseSettings): # Archive identity (promoted from Server model) name: str = "Open Science Archive" @@ -283,13 +296,9 @@ class Config(BaseSettings): worker: WorkerConfig = WorkerConfig() # Background worker settings auth: AuthConfig = AuthConfig() # Defaults are dev-safe; boot check enforces prod-correctness runner: RunnerConfig = RunnerConfig() + data: DataConfig = DataConfig() # /data/ read-surface filter-tree bounds host_data_dir: str | None = None # Host path for OSA_DATA_DIR (sibling container mounts) - # Discovery filter-tree bounds (feature 076) - discovery_max_filter_depth: int = 10 - discovery_max_predicates: int = 200 - discovery_max_cross_domain_joins: int = 10 - model_config = { "env_prefix": "OSA_", "env_file": ".env", diff --git a/server/osa/domain/data/service/data_query.py b/server/osa/domain/data/service/data_query.py index 520a562..e5c346d 100644 --- a/server/osa/domain/data/service/data_query.py +++ b/server/osa/domain/data/service/data_query.py @@ -51,26 +51,25 @@ def _validate_filter_bounds(self, expr: FilterExpr | None) -> None: if expr is None: return depth = _tree_depth(expr) - if depth > self.config.discovery_max_filter_depth: + if depth > self.config.data.max_filter_depth: raise ValidationError( - f"Filter tree depth {depth} exceeds maximum " - f"{self.config.discovery_max_filter_depth}.", + f"Filter tree depth {depth} exceeds maximum {self.config.data.max_filter_depth}.", field="filter", code="filter_depth_exceeded", ) predicates = list(_iter_predicates(expr)) - if len(predicates) > self.config.discovery_max_predicates: + if len(predicates) > self.config.data.max_predicates: raise ValidationError( f"Filter tree has {len(predicates)} predicates, exceeds maximum " - f"{self.config.discovery_max_predicates}.", + f"{self.config.data.max_predicates}.", field="filter", code="filter_predicates_exceeded", ) distinct_hooks = {p.field.hook for p in predicates if isinstance(p.field, FeatureFieldRef)} - if 
len(distinct_hooks) > self.config.discovery_max_cross_domain_joins: + if len(distinct_hooks) > self.config.data.max_feature_joins: raise ValidationError( f"Filter joins {len(distinct_hooks)} feature hooks, exceeds maximum " - f"{self.config.discovery_max_cross_domain_joins}.", + f"{self.config.data.max_feature_joins}.", field="filter", code="filter_joins_exceeded", ) diff --git a/server/osa/infrastructure/data/postgres_data_read_store.py b/server/osa/infrastructure/data/postgres_data_read_store.py index 7a4786e..374c257 100644 --- a/server/osa/infrastructure/data/postgres_data_read_store.py +++ b/server/osa/infrastructure/data/postgres_data_read_store.py @@ -100,6 +100,19 @@ def _escape_like(value: str) -> str: return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") +def _invalid_cursor(exc: Exception) -> ValidationError: + """Map a cursor decode/coerce ``ValueError`` to a 400, not a 500. + + Decoding a corrupt cursor (bad base64, missing ``s``/``id`` keys) or coercing + a non-conforming value (e.g. a non-integer feature id) raises ``ValueError``, + which is not an ``OSAError`` and would otherwise surface as a generic 500. + The cursor is opaque client-supplied input, so a malformed one is a 400. 
+ """ + return ValidationError( + f"Malformed pagination cursor: {exc}", field="cursor", code="invalid_cursor" + ) + + class PostgresDataReadStore: def __init__(self, session: AsyncSession, node_domain: Domain) -> None: self.session = session @@ -201,9 +214,12 @@ def _records_sort(self, plan: QueryPlan, metadata_table: Any) -> tuple[list[Any] if plan.pagination.cursor is not None: from osa.domain.data.model.query_plan import decode_cursor - decoded = decode_cursor(str(plan.pagination.cursor)) - sort_value = self._coerce_cursor_value(decoded["s"], primary.column) - cursor_after = page.after((sort_value, decoded["id"])) + try: + decoded = decode_cursor(str(plan.pagination.cursor)) + sort_value = self._coerce_cursor_value(decoded["s"], primary.column) + cursor_after = page.after((sort_value, decoded["id"])) + except ValueError as exc: + raise _invalid_cursor(exc) from exc return page.order_by(), cursor_after @staticmethod @@ -319,9 +335,14 @@ def _features_sort( if plan.pagination.cursor is not None: from osa.domain.data.model.query_plan import decode_cursor - decoded = decode_cursor(str(plan.pagination.cursor)) - sort_value = self._coerce_feature_cursor_value(decoded["s"], primary.column, fschema) - cursor_after = page.after((sort_value, int(decoded["id"]))) + try: + decoded = decode_cursor(str(plan.pagination.cursor)) + sort_value = self._coerce_feature_cursor_value( + decoded["s"], primary.column, fschema + ) + cursor_after = page.after((sort_value, int(decoded["id"]))) + except ValueError as exc: + raise _invalid_cursor(exc) from exc return page.order_by(), cursor_after @staticmethod diff --git a/server/osa/infrastructure/persistence/__init__.py b/server/osa/infrastructure/persistence/__init__.py index f2493b6..22af89b 100644 --- a/server/osa/infrastructure/persistence/__init__.py +++ b/server/osa/infrastructure/persistence/__init__.py @@ -1,3 +1,8 @@ -from .di import PersistenceProvider +"""Persistence adapters package. 
-__all__ = ["PersistenceProvider"] +Intentionally does not re-export ``PersistenceProvider`` at the package root: +importing a leaf module (e.g. ``persistence.feature_table``) must not pull in +the whole DI graph, which imports adapters that in turn import persistence leaf +modules — a circular import. Import the provider from its module directly: +``from osa.infrastructure.persistence.di import PersistenceProvider``. +""" diff --git a/server/tests/unit/domain/data/test_data_query_service.py b/server/tests/unit/domain/data/test_data_query_service.py index ba35c05..f0b0cf4 100644 --- a/server/tests/unit/domain/data/test_data_query_service.py +++ b/server/tests/unit/domain/data/test_data_query_service.py @@ -4,7 +4,7 @@ """ from collections.abc import AsyncIterator, Mapping -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any import pytest @@ -18,13 +18,20 @@ SCHEMA = SchemaId.parse("compound@1.0.0") +@dataclass +class FakeDataConfig: + """The /data/ filter-bound knobs DataQueryService reads (config.data.*).""" + + max_filter_depth: int = 10 + max_predicates: int = 200 + max_feature_joins: int = 10 + + @dataclass class FakeConfig: - """Only the filter-bound knobs DataQueryService reads (avoids full env setup).""" + """Minimal stand-in exposing only ``config.data`` (avoids full env setup).""" - discovery_max_filter_depth: int = 10 - discovery_max_predicates: int = 200 - discovery_max_cross_domain_joins: int = 10 + data: FakeDataConfig = field(default_factory=FakeDataConfig) class FakeReadStore: @@ -83,7 +90,7 @@ async def test_stream_records_rejects_feature_plan() -> None: @pytest.mark.asyncio async def test_filter_depth_bound_enforced() -> None: service, _ = _service() - service.config.discovery_max_filter_depth = 1 + service.config.data.max_filter_depth = 1 deep = And(operands=[_pred(), And(operands=[_pred(), _pred()])]) plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, filter=deep) with 
pytest.raises(ValidationError) as exc: @@ -94,7 +101,7 @@ async def test_filter_depth_bound_enforced() -> None: @pytest.mark.asyncio async def test_filter_predicate_count_bound_enforced() -> None: service, _ = _service() - service.config.discovery_max_predicates = 1 + service.config.data.max_predicates = 1 plan = QueryPlan( schema_id=SCHEMA, table_kind=TableKind.RECORDS, diff --git a/server/tests/unit/infrastructure/data/__init__.py b/server/tests/unit/infrastructure/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/infrastructure/data/test_cursor_validation.py b/server/tests/unit/infrastructure/data/test_cursor_validation.py new file mode 100644 index 0000000..f3ef10d --- /dev/null +++ b/server/tests/unit/infrastructure/data/test_cursor_validation.py @@ -0,0 +1,95 @@ +"""A malformed pagination cursor must surface as a 400, not a 500. + +``decode_cursor`` raises bare ``ValueError`` on corrupt base64 / missing keys, +and the feature sort coerces ``id`` with ``int(...)`` which also raises +``ValueError`` on a non-integer. Both run inside ``stream_rows``' pre-flight +pull; if they escape as raw ``ValueError`` (not an ``OSAError``) the global +handler returns 500. These tests pin them to ``ValidationError(field="cursor")`` +so the route maps them to 400 — exercised DB-free via the sort helpers, which +decode the cursor before any DB access. 
+""" + +import base64 +import json + +import pytest + +from osa.domain.data.model.query_plan import ( + PaginationCursor, + PaginationParams, + QueryPlan, + SortDirection, + SortSpec, + TableKind, + encode_cursor, +) +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.hook import ColumnDef +from osa.domain.shared.model.srn import Domain, SchemaId +from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore +from osa.infrastructure.persistence.feature_table import ( + FeatureSchema, + build_feature_table, +) +from osa.infrastructure.persistence.tables import records_table + +SCHEMA = SchemaId.parse("compound@1.0.0") + + +def _store() -> PostgresDataReadStore: + # The sort helpers never touch self.session; decoding happens before any DB + # access, so a placeholder session is sufficient for these unit tests. + return PostgresDataReadStore(None, Domain("localhost")) # type: ignore[arg-type] + + +def _records_plan(cursor: str) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + ) + + +def _feature_plan(cursor: str) -> QueryPlan: + return QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name="chem_features", + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + sort=[SortSpec(column="id", direction=SortDirection.ASC)], + ) + + +def _feature_table(): + fschema = FeatureSchema(columns=[ColumnDef(name="score", json_type="number", required=True)]) + return build_feature_table("chem_features", fschema), fschema + + +class TestRecordsCursorValidation: + def test_corrupt_base64_raises_validation_error(self) -> None: + with pytest.raises(ValidationError) as exc: + _store()._records_sort(_records_plan("!!not-base64!!"), records_table) + assert exc.value.field == "cursor" + + def test_missing_keys_raises_validation_error(self) -> None: + # A well-formed base64 JSON object that 
lacks the required s/id keys. + bad = base64.urlsafe_b64encode(json.dumps({"x": 1}).encode()).decode() + with pytest.raises(ValidationError) as exc: + _store()._records_sort(_records_plan(bad), records_table) + assert exc.value.field == "cursor" + + +class TestFeatureCursorValidation: + def test_corrupt_base64_raises_validation_error(self) -> None: + ft, fschema = _feature_table() + with pytest.raises(ValidationError) as exc: + _store()._features_sort(_feature_plan("!!not-base64!!"), ft, fschema) + assert exc.value.field == "cursor" + + def test_non_integer_id_raises_validation_error(self) -> None: + ft, fschema = _feature_table() + # Well-formed cursor whose id component is not an int → int(...) ValueError. + cursor = encode_cursor(5, "not-an-int") + with pytest.raises(ValidationError) as exc: + _store()._features_sort(_feature_plan(cursor), ft, fschema) + assert exc.value.field == "cursor" From 9c956949ee1ebfde2640101cb64f3c3a8ec03676 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Wed, 10 Jun 2026 12:01:41 +0100 Subject: [PATCH 04/23] fix: scope feature-table streams and counts to the requested schema A features.<hook> table is shared by every convention that registers the hook name, across schemas (CreateFeatureTables swallows the ConflictError on re-registration). _stream_features and _feature_count read the table unscoped, so /data/{schema}/{feature} leaked rows belonging to other schemas and the manifest row_count was inflated. Join records and filter by schema_id + schema_version in both paths. 
--- .../data/postgres_data_read_store.py | 33 +++++-- .../test_data_features_postgres.py | 87 ++++++++++++++++--- 2 files changed, 99 insertions(+), 21 deletions(-) diff --git a/server/osa/infrastructure/data/postgres_data_read_store.py b/server/osa/infrastructure/data/postgres_data_read_store.py index 374c257..3f40c1e 100644 --- a/server/osa/infrastructure/data/postgres_data_read_store.py +++ b/server/osa/infrastructure/data/postgres_data_read_store.py @@ -247,13 +247,20 @@ async def _stream_features(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, if cursor_after is not None: conditions.append(cursor_after) + # A features. table is shared by every convention that registers + # the hook name, across schemas — scope to the requested schema's + # records via the records join. + conditions.extend(self._feature_schema_scope(plan.schema_id)) + # Implicit columns (id, record_srn, created_at) precede the hook's # declared data columns — this is the CSV header order. select_cols = [ft.c.id, ft.c.record_srn, ft.c.created_at, *data_columns(ft)] - stmt = select(*select_cols).select_from(ft) - if conditions: - stmt = stmt.where(and_(*conditions)) - stmt = stmt.order_by(*order_keys) + stmt = ( + select(*select_cols) + .select_from(ft.join(records_table, records_table.c.srn == ft.c.record_srn)) + .where(and_(*conditions)) + .order_by(*order_keys) + ) result = await self.session.stream(stmt) try: @@ -307,8 +314,20 @@ async def _schema_feature_tables(self, schema_id: SchemaId) -> list[tuple[str, F for row in result.mappings() ] - async def _feature_count(self, ft: sa.Table) -> int: - return int((await self.session.execute(select(func.count()).select_from(ft))).scalar_one()) + @staticmethod + def _feature_schema_scope(schema_id: SchemaId) -> list[Any]: + return [ + records_table.c.schema_id == schema_id.id.root, + records_table.c.schema_version == schema_id.version.root, + ] + + async def _feature_count(self, ft: sa.Table, schema_id: SchemaId) -> int: + stmt = ( + 
select(func.count()) + .select_from(ft.join(records_table, records_table.c.srn == ft.c.record_srn)) + .where(and_(*self._feature_schema_scope(schema_id))) + ) + return int((await self.session.execute(stmt)).scalar_one()) def _features_sort( self, plan: QueryPlan, ft: sa.Table, fschema: FeatureSchema @@ -526,7 +545,7 @@ async def _feature_resources(self, schema_id: SchemaId) -> list[TableResource]: # Implicit columns (id, record_srn, created_at) precede the # hook's declared data columns — this is the CSV header order. columns=[*IMPLICIT_FEATURE_COLUMN_SPECS, *_feature_column_specs(fschema)], - row_count=await self._feature_count(ft), + row_count=await self._feature_count(ft, schema_id), formats=list(_ALL_FORMATS), ) ) diff --git a/server/tests/integration/test_data_features_postgres.py b/server/tests/integration/test_data_features_postgres.py index ffae4e3..545d18f 100644 --- a/server/tests/integration/test_data_features_postgres.py +++ b/server/tests/integration/test_data_features_postgres.py @@ -22,7 +22,7 @@ from osa.domain.data.model.query_plan import TableKind as TK from osa.domain.semantics.model.schema import Schema from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType -from osa.domain.shared.error import NotFoundError +from osa.domain.shared.error import ConflictError, NotFoundError from osa.domain.shared.model.hook import ColumnDef from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore @@ -36,6 +36,7 @@ from tests.integration.conftest import seed_record SCHEMA = SchemaId.parse("compound@1.0.0") +SCHEMA_B = SchemaId.parse("protein@1.0.0") HOOK = "chem_features" @@ -57,28 +58,37 @@ def _feature_columns() -> list[ColumnDef]: ] -async def _setup_schema(engine: AsyncEngine, session: AsyncSession) -> PostgresMetadataStore: +async def _setup_schema( + engine: AsyncEngine, session: AsyncSession, schema: SchemaId = SCHEMA +) -> 
PostgresMetadataStore: store = PostgresMetadataStore(engine, session) - await store.ensure_table(SCHEMA, _fields()) + await store.ensure_table(schema, _fields()) await PostgresSemanticsSchemaRepository(session).save( - Schema(id=SCHEMA, title="compound", fields=_fields(), created_at=datetime.now(UTC)) + Schema(id=schema, title=schema.id.root, fields=_fields(), created_at=datetime.now(UTC)) ) return store -async def _register_hook(engine: AsyncEngine, session: AsyncSession, hook_name: str = HOOK) -> None: +async def _register_hook( + engine: AsyncEngine, + session: AsyncSession, + hook_name: str = HOOK, + schema: SchemaId = SCHEMA, +) -> None: """Link the schema → hook via a convention row, then create its feature table. The read store only reads ``hooks[*].name`` from the convention, so a name-only hooks payload is sufficient to scope the feature to the schema. + A pre-existing feature table is reused, mirroring ``CreateFeatureTables`` + (which swallows the ConflictError when two conventions share a hook name). 
""" await session.execute( conventions_table.insert().values( - srn=f"urn:osa:localhost:conv:{hook_name}@1.0.0", - title="compound conv", + srn=f"urn:osa:localhost:conv:{schema.id.root}-{hook_name}@1.0.0", + title=f"{schema.id.root} conv", description=None, - schema_id=SCHEMA.id.root, - schema_version=SCHEMA.version.root, + schema_id=schema.id.root, + schema_version=schema.version.root, file_requirements={}, hooks=[{"name": hook_name}], source=None, @@ -87,20 +97,25 @@ async def _register_hook(engine: AsyncEngine, session: AsyncSession, hook_name: ) await session.commit() feature_store = PostgresFeatureStore(engine, session) - await feature_store.create_table(hook_name, _feature_columns()) + try: + await feature_store.create_table(hook_name, _feature_columns()) + except ConflictError: + pass -async def _publish(engine: AsyncEngine, store: PostgresMetadataStore, rid: str) -> RecordSRN: +async def _publish( + engine: AsyncEngine, store: PostgresMetadataStore, rid: str, schema: SchemaId = SCHEMA +) -> RecordSRN: srn = RecordSRN.parse(f"urn:osa:localhost:rec:{rid}@1") await seed_record( engine, srn=str(srn), - schema_id=SCHEMA.id.root, - schema_version=SCHEMA.version.root, + schema_id=schema.id.root, + schema_version=schema.version.root, metadata={"species": "Homo sapiens"}, published_at=datetime(2026, 1, 1, tzinfo=UTC), ) - await store.insert(SCHEMA, srn, {"species": "Homo sapiens"}) + await store.insert(schema, srn, {"species": "Homo sapiens"}) return srn @@ -181,6 +196,50 @@ async def test_unknown_feature_raises_not_found( await _drain(rs, plan) +@pytest.mark.asyncio +class TestFeatureSchemaScoping: + """Two schemas whose conventions register the same hook name share one + physical ``features.`` table (``CreateFeatureTables`` swallows the + ConflictError). 
Reads at ``/data/{schema}/{feature}`` must still be scoped + to records of the requested schema.""" + + async def _seed_shared_hook( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ) -> tuple[RecordSRN, RecordSRN]: + store_a = await _setup_schema(pg_engine, pg_session) + store_b = await _setup_schema(pg_engine, pg_session, schema=SCHEMA_B) + srn_a = await _publish(pg_engine, store_a, "reca") + srn_b = await _publish(pg_engine, store_b, "recb", schema=SCHEMA_B) + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + await _register_hook(pg_engine, pg_session, schema=SCHEMA_B) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + await feature_store.insert_features(HOOK, str(srn_a), [{"score": 0.9, "label": "a"}]) + await feature_store.insert_features(HOOK, str(srn_b), [{"score": 0.2, "label": "b"}]) + return srn_a, srn_b + + async def test_stream_excludes_rows_of_other_schemas( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + srn_a, _ = await self._seed_shared_hook(pg_engine, pg_session) + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rows = await _drain(rs, _feature_plan()) + + assert [r["record_srn"] for r in rows] == [str(srn_a)] + + async def test_manifest_row_count_scoped_to_schema( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + await self._seed_shared_hook(pg_engine, pg_session) + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + manifest = await rs.get_schema_manifest(SCHEMA) + assert manifest is not None + feature_res = next(t for t in manifest.table_resources if t.name == HOOK) + assert feature_res.row_count == 1 + + @pytest.mark.asyncio class TestFeatureCatalogAndManifest: async def test_manifest_includes_feature_resource( From 1d798049e7d19f5e4d048bade41fc4544df9ef14 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Wed, 10 Jun 2026 12:01:49 +0100 Subject: [PATCH 05/23] refactor: drop dead SortOrder enum from data filter model --- 
server/osa/domain/data/model/filter.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/server/osa/domain/data/model/filter.py b/server/osa/domain/data/model/filter.py index bacea17..02d9708 100644 --- a/server/osa/domain/data/model/filter.py +++ b/server/osa/domain/data/model/filter.py @@ -104,11 +104,6 @@ class FilterOperator(StrEnum): IS_NULL = "is_null" -class SortOrder(StrEnum): - ASC = "asc" - DESC = "desc" - - FieldRef = Annotated[ Union[MetadataFieldRef, FeatureFieldRef], Field(discriminator="path"), From 8c65980dd4ef7f89c7cdc1451336601eebef33a4 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Wed, 10 Jun 2026 12:01:56 +0100 Subject: [PATCH 06/23] build: allow narrowing test-integration-pg to a pytest target --- server/Justfile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server/Justfile b/server/Justfile index cf80594..78ac928 100644 --- a/server/Justfile +++ b/server/Justfile @@ -106,7 +106,8 @@ test-integration: @TEST=1 uv run pytest tests/integration -v --tb=short -x # Run integration tests with PG: ensure DB running → wipe → create → migrate → test → wipe -test-integration-pg: +# Optionally narrow to a pytest target, e.g. 
`just test-integration-pg tests/integration/test_foo.py` +test-integration-pg target="tests/integration": POSTGRES_PORT={{TEST_PG_PORT}} just --justfile ../Justfile db-up @just test-db-drop @just test-db-create @@ -116,7 +117,7 @@ test-integration-pg: OSA_DATABASE__URL="{{TEST_DB_URL}}" \ OSA_AUTH__JWT__SECRET="test-secret-for-integration-tests-minimum-32-chars" \ TEST=1 \ - uv run pytest tests/integration -v --tb=short -x; \ + uv run pytest "{{target}}" -v --tb=short -x; \ just test-db-drop # === Dependencies === From a614c97c6a6e1e4c4725477aa00c953bfcf8efe0 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Wed, 10 Jun 2026 14:32:28 +0100 Subject: [PATCH 07/23] fix: make created_at feature cursors round-trip instead of 500ing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Feature rows are raw DB mappings, so paginating a feature table sorted by created_at handed a Python datetime to encode_cursor, where bare json.dumps raised TypeError -> 500. Encode with default=str (as the JSON serializer already does), and coerce the decoded ISO string back to a datetime in _coerce_feature_cursor_value — created_at is an implicit column with no ColumnDef, so the existing format-driven coercion missed it and asyncpg would reject the str bind against DateTime(timezone=True). 
--- server/osa/domain/data/model/query_plan.py | 9 +++- .../data/postgres_data_read_store.py | 4 ++ .../test_data_features_postgres.py | 48 +++++++++++++++++++ .../tests/unit/domain/data/test_query_plan.py | 11 +++++ .../data/test_cursor_validation.py | 11 +++++ 5 files changed, 81 insertions(+), 2 deletions(-) diff --git a/server/osa/domain/data/model/query_plan.py b/server/osa/domain/data/model/query_plan.py index 5a54adc..d193e82 100644 --- a/server/osa/domain/data/model/query_plan.py +++ b/server/osa/domain/data/model/query_plan.py @@ -88,9 +88,14 @@ def _validate_and_default(self) -> "QueryPlan": def encode_cursor(sort_value: Any, id_value: Any) -> str: - """Encode a cursor as urlsafe base64 of ``{"s": sort_value, "id": id_value}``.""" + """Encode a cursor as urlsafe base64 of ``{"s": sort_value, "id": id_value}``. + + ``default=str`` because feature rows are raw DB mappings — a datetime sort + value (``created_at``, or a date-time hook column) arrives unrendered. The + sort decoders coerce the ISO string back to a datetime before binding. + """ payload = {"s": sort_value, "id": id_value} - return base64.urlsafe_b64encode(json.dumps(payload).encode()).decode() + return base64.urlsafe_b64encode(json.dumps(payload, default=str).encode()).decode() def decode_cursor(cursor: str) -> dict[str, Any]: diff --git a/server/osa/infrastructure/data/postgres_data_read_store.py b/server/osa/infrastructure/data/postgres_data_read_store.py index 3f40c1e..973ff79 100644 --- a/server/osa/infrastructure/data/postgres_data_read_store.py +++ b/server/osa/infrastructure/data/postgres_data_read_store.py @@ -368,6 +368,10 @@ def _features_sort( def _coerce_feature_cursor_value(value: Any, column: str, fschema: FeatureSchema) -> Any: if column == "id": return None if value is None else int(value) + # created_at is an implicit column (no ColumnDef in the hook's schema); + # the cursor carries it as an ISO string but the bind needs a datetime. 
+ if column == "created_at" and isinstance(value, str): + return datetime.fromisoformat(value) col_def = next((c for c in fschema.columns if c.name == column), None) if col_def is not None and col_def.json_type == "string" and isinstance(value, str): if col_def.format == "date-time": diff --git a/server/tests/integration/test_data_features_postgres.py b/server/tests/integration/test_data_features_postgres.py index 545d18f..8cca997 100644 --- a/server/tests/integration/test_data_features_postgres.py +++ b/server/tests/integration/test_data_features_postgres.py @@ -15,9 +15,13 @@ from osa.domain.data.model.filter import FilterOperator, FeatureFieldRef, Predicate from osa.domain.data.model.query_plan import ( + PaginationCursor, PaginationParams, QueryPlan, + SortDirection, + SortSpec, TableKind, + encode_cursor, ) from osa.domain.data.model.query_plan import TableKind as TK from osa.domain.semantics.model.schema import Schema @@ -196,6 +200,50 @@ async def test_unknown_feature_raises_not_found( await _drain(rs, plan) +@pytest.mark.asyncio +class TestFeatureCreatedAtCursor: + """``?sort=created_at`` pagination on a feature table must round-trip. 
+ + Feature rows are raw DB mappings, so the route layer encodes the next-page + cursor from a Python ``datetime`` (encoder must serialize it), and the + second page binds the decoded value against ``DateTime(timezone=True)`` + (decoder must coerce the ISO string back — created_at is an implicit + column with no ColumnDef in the hook's FeatureSchema).""" + + async def test_created_at_sort_cursor_round_trips( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + store = await _setup_schema(pg_engine, pg_session) + srn = await _publish(pg_engine, store, "rec1") + await pg_session.commit() + await _register_hook(pg_engine, pg_session) + feature_store = PostgresFeatureStore(pg_engine, pg_session) + await feature_store.insert_features(HOOK, str(srn), [{"score": s} for s in (0.1, 0.2, 0.3)]) + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + sort = [SortSpec(column="created_at", direction=SortDirection.ASC)] + all_rows = await _drain( + rs, + QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name=HOOK, sort=sort), + ) + assert len(all_rows) == 3 + + # Encode the cursor exactly as routes/data/_streaming.py does: straight + # from the row mapping, where created_at is a datetime. 
+ cursor = encode_cursor(all_rows[1]["created_at"], all_rows[1]["id"]) + page2 = await _drain( + rs, + QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name=HOOK, + sort=sort, + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + ), + ) + assert [r["id"] for r in page2] == [all_rows[2]["id"]] + + @pytest.mark.asyncio class TestFeatureSchemaScoping: """Two schemas whose conventions register the same hook name share one diff --git a/server/tests/unit/domain/data/test_query_plan.py b/server/tests/unit/domain/data/test_query_plan.py index 0690d30..3aabf74 100644 --- a/server/tests/unit/domain/data/test_query_plan.py +++ b/server/tests/unit/domain/data/test_query_plan.py @@ -1,5 +1,7 @@ """T029 — QueryPlan validation rules and per-kind default sorts.""" +from datetime import UTC, datetime + import pytest from pydantic import ValidationError as PydanticValidationError @@ -59,3 +61,12 @@ def test_cursor_roundtrip() -> None: def test_decode_malformed_cursor_raises() -> None: with pytest.raises(ValueError): decode_cursor("not-valid-base64-json!!!") + + +def test_encode_cursor_accepts_datetime_sort_value() -> None: + # Feature rows reach the cursor encoder as raw DB mappings, so a + # created_at sort value is a datetime, not a pre-rendered string. 
+ ts = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + decoded = decode_cursor(encode_cursor(ts, 7)) + assert datetime.fromisoformat(decoded["s"]) == ts + assert decoded["id"] == 7 diff --git a/server/tests/unit/infrastructure/data/test_cursor_validation.py b/server/tests/unit/infrastructure/data/test_cursor_validation.py index f3ef10d..c819f21 100644 --- a/server/tests/unit/infrastructure/data/test_cursor_validation.py +++ b/server/tests/unit/infrastructure/data/test_cursor_validation.py @@ -11,6 +11,7 @@ import base64 import json +from datetime import datetime import pytest @@ -93,3 +94,13 @@ def test_non_integer_id_raises_validation_error(self) -> None: with pytest.raises(ValidationError) as exc: _store()._features_sort(_feature_plan(cursor), ft, fschema) assert exc.value.field == "cursor" + + def test_created_at_cursor_value_coerced_to_datetime(self) -> None: + # created_at is an implicit feature column (no ColumnDef in the hook's + # FeatureSchema), so the decoder must coerce its ISO string itself — + # asyncpg rejects a str bound against DateTime(timezone=True). + _, fschema = _feature_table() + value = PostgresDataReadStore._coerce_feature_cursor_value( + "2026-01-02T03:04:05+00:00", "created_at", fschema + ) + assert isinstance(value, datetime) From 57c4f4a03abfef358bb913fd421491d4abc7353a Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 11:34:52 +0100 Subject: [PATCH 08/23] fix: coerce records cursor values by sort-column type, not column name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _coerce_cursor_value only recognised created_at/published_at by name, so paginating records sorted by a FieldType.DATE (or date-time) metadata column bound the cursor's ISO string against the dynamic table's sa.Date/sa.DateTime column — asyncpg rejects the str bind on page 2+. Dispatch on the resolved sort expression's SQLAlchemy type instead, which also covers the old name-based cases. 
--- .../data/postgres_data_read_store.py | 18 +++-- .../test_data_read_store_postgres.py | 66 ++++++++++++++++++- .../data/test_cursor_validation.py | 18 ++++- 3 files changed, 95 insertions(+), 7 deletions(-) diff --git a/server/osa/infrastructure/data/postgres_data_read_store.py b/server/osa/infrastructure/data/postgres_data_read_store.py index 973ff79..8e42691 100644 --- a/server/osa/infrastructure/data/postgres_data_read_store.py +++ b/server/osa/infrastructure/data/postgres_data_read_store.py @@ -216,16 +216,26 @@ def _records_sort(self, plan: QueryPlan, metadata_table: Any) -> tuple[list[Any] try: decoded = decode_cursor(str(plan.pagination.cursor)) - sort_value = self._coerce_cursor_value(decoded["s"], primary.column) + sort_value = self._coerce_cursor_value(decoded["s"], sort_expr) cursor_after = page.after((sort_value, decoded["id"])) except ValueError as exc: raise _invalid_cursor(exc) from exc return page.order_by(), cursor_after @staticmethod - def _coerce_cursor_value(value: Any, column: str) -> Any: - if column in ("created_at", "published_at") and isinstance(value, str): - return datetime.fromisoformat(value) + def _coerce_cursor_value(value: Any, sort_expr: sa.ColumnElement[Any]) -> Any: + """Coerce a decoded cursor sort value to the sort column's Python type. + + Cursors carry date/datetime values as ISO strings (``encode_cursor`` + renders with ``default=str``), but asyncpg rejects a str bound against + a DATE / TIMESTAMP column — including dynamic metadata columns of + FieldType.DATE, so dispatch on the column type, not the column name. 
+ """ + if isinstance(value, str): + if isinstance(sort_expr.type, sa.DateTime): + return datetime.fromisoformat(value) + if isinstance(sort_expr.type, sa.Date): + return date.fromisoformat(value) return value # ------------------------------------------------------------------ # diff --git a/server/tests/integration/test_data_read_store_postgres.py b/server/tests/integration/test_data_read_store_postgres.py index 3882eaa..77111fb 100644 --- a/server/tests/integration/test_data_read_store_postgres.py +++ b/server/tests/integration/test_data_read_store_postgres.py @@ -15,11 +15,13 @@ from osa.domain.data.model.filter import FilterOperator, MetadataFieldRef, Predicate from osa.domain.data.model.query_plan import ( + PaginationCursor, PaginationParams, QueryPlan, SortDirection, SortSpec, TableKind, + encode_cursor, ) from osa.domain.data.model.query_plan import TableKind as TK from osa.domain.semantics.model.schema import Schema @@ -248,8 +250,6 @@ async def test_cursor_advances_without_overlap( rs = PostgresDataReadStore(pg_session, Domain("localhost")) # Sort by created_at desc (default). Page 1: take 2, derive cursor from row 2. - from osa.domain.data.model.query_plan import PaginationCursor, encode_cursor - all_rows = await _drain(rs, _records_plan()) assert [r["id"] for r in all_rows] == ["rec4", "rec3", "rec2", "rec1", "rec0"] @@ -263,3 +263,65 @@ async def test_cursor_advances_without_overlap( ) page2 = await _drain(rs, plan2) assert [r["id"] for r in page2] == ["rec2", "rec1", "rec0"] + + +ASSAY_SCHEMA = SchemaId.parse("assay@1.0.0") + + +@pytest.mark.asyncio +class TestMetadataDateCursor: + """Pagination sorted by a FieldType.DATE metadata column must round-trip. 
+ + The cursor carries the sort value as an ISO string; the records sort must + coerce it back to a Python date before binding — asyncpg rejects a str + bound against the dynamic table's sa.Date column.""" + + async def test_date_sort_cursor_round_trips( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + fields = [ + FieldDefinition( + name="assay_date", + type=FieldType.DATE, + required=True, + cardinality=Cardinality.EXACTLY_ONE, + ), + ] + store = PostgresMetadataStore(pg_engine, pg_session) + await store.ensure_table(ASSAY_SCHEMA, fields) + await PostgresSemanticsSchemaRepository(pg_session).save( + Schema(id=ASSAY_SCHEMA, title="assay", fields=fields, created_at=datetime.now(UTC)) + ) + for i in range(3): + srn = RecordSRN.parse(f"urn:osa:localhost:rec:arec{i}@1") + await seed_record( + pg_engine, + srn=str(srn), + schema_id=ASSAY_SCHEMA.id.root, + schema_version=ASSAY_SCHEMA.version.root, + metadata={"assay_date": f"2026-01-0{i + 1}"}, + published_at=datetime(2026, 1, 1, tzinfo=UTC), + ) + await store.insert(ASSAY_SCHEMA, srn, {"assay_date": f"2026-01-0{i + 1}"}) + await pg_session.commit() + + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + sort = [SortSpec(column="assay_date", direction=SortDirection.ASC)] + all_rows = await _drain( + rs, QueryPlan(schema_id=ASSAY_SCHEMA, table_kind=TableKind.RECORDS, sort=sort) + ) + assert [r["id"] for r in all_rows] == ["arec0", "arec1", "arec2"] + + # Encode the cursor as the route layer does — the date sort value + # serializes to an ISO string via default=str. 
+ cursor = encode_cursor(all_rows[1]["assay_date"], all_rows[1]["srn"]) + page2 = await _drain( + rs, + QueryPlan( + schema_id=ASSAY_SCHEMA, + table_kind=TableKind.RECORDS, + sort=sort, + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + ), + ) + assert [r["id"] for r in page2] == ["arec2"] diff --git a/server/tests/unit/infrastructure/data/test_cursor_validation.py b/server/tests/unit/infrastructure/data/test_cursor_validation.py index c819f21..4c189dc 100644 --- a/server/tests/unit/infrastructure/data/test_cursor_validation.py +++ b/server/tests/unit/infrastructure/data/test_cursor_validation.py @@ -11,9 +11,10 @@ import base64 import json -from datetime import datetime +from datetime import date, datetime import pytest +import sqlalchemy as sa from osa.domain.data.model.query_plan import ( PaginationCursor, @@ -79,6 +80,21 @@ def test_missing_keys_raises_validation_error(self) -> None: _store()._records_sort(_records_plan(bad), records_table) assert exc.value.field == "cursor" + def test_date_metadata_cursor_value_coerced_to_date(self) -> None: + # A FieldType.DATE metadata column is a sa.Date in the dynamic table; + # the cursor carries its value as an ISO string — asyncpg rejects a + # str bound against DATE, so the decoder must coerce by column type. 
+ value = PostgresDataReadStore._coerce_cursor_value( + "2026-01-02", sa.Column("assay_date", sa.Date()) + ) + assert isinstance(value, date) + + def test_datetime_metadata_cursor_value_coerced_to_datetime(self) -> None: + value = PostgresDataReadStore._coerce_cursor_value( + "2026-01-02T03:04:05+00:00", sa.Column("measured_at", sa.DateTime(timezone=True)) + ) + assert isinstance(value, datetime) + class TestFeatureCursorValidation: def test_corrupt_base64_raises_validation_error(self) -> None: From 0c08fe2b422a651ece0641aba45dfa2b53e6bdd0 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 11:41:07 +0100 Subject: [PATCH 09/23] refactor: unify cursor coercion on column types; share the keyset-after path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The records and feature sorts each had their own cursor handling: a type-driven coercer on the records side, a FeatureSchema-driven one on the feature side (which needed a name-based special case for the implicit created_at column), and two copies of the decode/coerce/after block with function-local decode_cursor imports. Collapse to one _cursor_after helper and one _coerce_cursor_value that dispatches on the bound column's SQLAlchemy type — covering DATE/TIMESTAMP/BIGINT for both table kinds — and drop the now-dead fschema parameter from _features_sort. 
--- .../data/postgres_data_read_store.py | 106 +++++++++--------- .../data/test_cursor_validation.py | 27 +++-- 2 files changed, 67 insertions(+), 66 deletions(-) diff --git a/server/osa/infrastructure/data/postgres_data_read_store.py b/server/osa/infrastructure/data/postgres_data_read_store.py index 8e42691..cb1258f 100644 --- a/server/osa/infrastructure/data/postgres_data_read_store.py +++ b/server/osa/infrastructure/data/postgres_data_read_store.py @@ -56,7 +56,12 @@ SchemaManifest, TableResource, ) -from osa.domain.data.model.query_plan import QueryPlan, SortDirection, TableKind +from osa.domain.data.model.query_plan import ( + QueryPlan, + SortDirection, + TableKind, + decode_cursor, +) from osa.domain.data.model.record_summary import RecordSummary from osa.domain.semantics.model.value import FieldType from osa.domain.shared.error import NotFoundError, ValidationError @@ -210,32 +215,53 @@ def _records_sort(self, plan: QueryPlan, metadata_table: Any) -> tuple[list[Any] SortKey(t.c.srn, descending=is_desc), ] ) - cursor_after = None - if plan.pagination.cursor is not None: - from osa.domain.data.model.query_plan import decode_cursor - - try: - decoded = decode_cursor(str(plan.pagination.cursor)) - sort_value = self._coerce_cursor_value(decoded["s"], sort_expr) - cursor_after = page.after((sort_value, decoded["id"])) - except ValueError as exc: - raise _invalid_cursor(exc) from exc - return page.order_by(), cursor_after + return page.order_by(), self._cursor_after(plan, page, sort_expr, t.c.srn) + + def _cursor_after( + self, + plan: QueryPlan, + page: KeysetPage, + sort_expr: sa.ColumnElement[Any], + id_expr: sa.ColumnElement[Any], + ) -> Any | None: + """Build the keyset ``after`` condition from the plan's opaque cursor. + + Shared by the records and feature sorts: both cursor components are + coerced to their column's Python type; a non-conforming value raises + ``ValueError`` → 400 (the cursor is client-supplied input). 
+ """ + if plan.pagination.cursor is None: + return None + try: + decoded = decode_cursor(str(plan.pagination.cursor)) + return page.after( + ( + self._coerce_cursor_value(decoded["s"], sort_expr), + self._coerce_cursor_value(decoded["id"], id_expr), + ) + ) + except ValueError as exc: + raise _invalid_cursor(exc) from exc @staticmethod def _coerce_cursor_value(value: Any, sort_expr: sa.ColumnElement[Any]) -> Any: - """Coerce a decoded cursor sort value to the sort column's Python type. + """Coerce a decoded cursor value to the bound column's Python type. Cursors carry date/datetime values as ISO strings (``encode_cursor`` - renders with ``default=str``), but asyncpg rejects a str bound against - a DATE / TIMESTAMP column — including dynamic metadata columns of - FieldType.DATE, so dispatch on the column type, not the column name. + renders with ``default=str``), but asyncpg rejects mistyped binds — + str vs DATE / TIMESTAMP (including dynamic metadata columns of + FieldType.DATE), str vs BIGINT — so dispatch on the column type, not + the column name. 
""" - if isinstance(value, str): - if isinstance(sort_expr.type, sa.DateTime): - return datetime.fromisoformat(value) - if isinstance(sort_expr.type, sa.Date): - return date.fromisoformat(value) + if value is None: + return None + col_type = sort_expr.type + if isinstance(col_type, sa.DateTime) and isinstance(value, str): + return datetime.fromisoformat(value) + if isinstance(col_type, sa.Date) and isinstance(value, str): + return date.fromisoformat(value) + if isinstance(col_type, sa.Integer): + return int(value) return value # ------------------------------------------------------------------ # @@ -245,7 +271,7 @@ def _coerce_cursor_value(value: Any, sort_expr: sa.ColumnElement[Any]) -> Any: async def _stream_features(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: if plan.feature_name is None: # guarded by QueryPlan, narrowed for the type checker raise ValidationError("feature_name is required for a FEATURE plan", field="feature") - ft, fschema = await self._resolve_feature_table(plan.schema_id, plan.feature_name) + ft, _ = await self._resolve_feature_table(plan.schema_id, plan.feature_name) conditions: list[Any] = [] if plan.filter is not None: @@ -253,7 +279,7 @@ async def _stream_features(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, self._compile_feature_filter(plan.filter, ft=ft, feature_name=plan.feature_name) ) - order_keys, cursor_after = self._features_sort(plan, ft, fschema) + order_keys, cursor_after = self._features_sort(plan, ft) if cursor_after is not None: conditions.append(cursor_after) @@ -339,9 +365,7 @@ async def _feature_count(self, ft: sa.Table, schema_id: SchemaId) -> int: ) return int((await self.session.execute(stmt)).scalar_one()) - def _features_sort( - self, plan: QueryPlan, ft: sa.Table, fschema: FeatureSchema - ) -> tuple[list[Any], Any | None]: + def _features_sort(self, plan: QueryPlan, ft: sa.Table) -> tuple[list[Any], Any | None]: primary = plan.sort[0] is_desc = primary.direction == SortDirection.DESC if 
primary.column == "id": @@ -360,35 +384,7 @@ def _features_sort( SortKey(ft.c.id, descending=is_desc), ] ) - cursor_after = None - if plan.pagination.cursor is not None: - from osa.domain.data.model.query_plan import decode_cursor - - try: - decoded = decode_cursor(str(plan.pagination.cursor)) - sort_value = self._coerce_feature_cursor_value( - decoded["s"], primary.column, fschema - ) - cursor_after = page.after((sort_value, int(decoded["id"]))) - except ValueError as exc: - raise _invalid_cursor(exc) from exc - return page.order_by(), cursor_after - - @staticmethod - def _coerce_feature_cursor_value(value: Any, column: str, fschema: FeatureSchema) -> Any: - if column == "id": - return None if value is None else int(value) - # created_at is an implicit column (no ColumnDef in the hook's schema); - # the cursor carries it as an ISO string but the bind needs a datetime. - if column == "created_at" and isinstance(value, str): - return datetime.fromisoformat(value) - col_def = next((c for c in fschema.columns if c.name == column), None) - if col_def is not None and col_def.json_type == "string" and isinstance(value, str): - if col_def.format == "date-time": - return datetime.fromisoformat(value) - if col_def.format == "date": - return date.fromisoformat(value) - return value + return page.order_by(), self._cursor_after(plan, page, sort_expr, ft.c.id) def _compile_feature_filter(self, expr: FilterExpr, *, ft: sa.Table, feature_name: str) -> Any: if isinstance(expr, Predicate): diff --git a/server/tests/unit/infrastructure/data/test_cursor_validation.py b/server/tests/unit/infrastructure/data/test_cursor_validation.py index 4c189dc..b8eb655 100644 --- a/server/tests/unit/infrastructure/data/test_cursor_validation.py +++ b/server/tests/unit/infrastructure/data/test_cursor_validation.py @@ -64,7 +64,7 @@ def _feature_plan(cursor: str) -> QueryPlan: def _feature_table(): fschema = FeatureSchema(columns=[ColumnDef(name="score", json_type="number", required=True)]) - return 
build_feature_table("chem_features", fschema), fschema + return build_feature_table("chem_features", fschema) class TestRecordsCursorValidation: @@ -98,25 +98,30 @@ def test_datetime_metadata_cursor_value_coerced_to_datetime(self) -> None: class TestFeatureCursorValidation: def test_corrupt_base64_raises_validation_error(self) -> None: - ft, fschema = _feature_table() + ft = _feature_table() with pytest.raises(ValidationError) as exc: - _store()._features_sort(_feature_plan("!!not-base64!!"), ft, fschema) + _store()._features_sort(_feature_plan("!!not-base64!!"), ft) assert exc.value.field == "cursor" def test_non_integer_id_raises_validation_error(self) -> None: - ft, fschema = _feature_table() + ft = _feature_table() # Well-formed cursor whose id component is not an int → int(...) ValueError. cursor = encode_cursor(5, "not-an-int") with pytest.raises(ValidationError) as exc: - _store()._features_sort(_feature_plan(cursor), ft, fschema) + _store()._features_sort(_feature_plan(cursor), ft) assert exc.value.field == "cursor" def test_created_at_cursor_value_coerced_to_datetime(self) -> None: - # created_at is an implicit feature column (no ColumnDef in the hook's - # FeatureSchema), so the decoder must coerce its ISO string itself — - # asyncpg rejects a str bound against DateTime(timezone=True). - _, fschema = _feature_table() - value = PostgresDataReadStore._coerce_feature_cursor_value( - "2026-01-02T03:04:05+00:00", "created_at", fschema + # created_at is an implicit feature column; the type-driven coercion + # must turn its ISO string back into a datetime before binding. + ft = _feature_table() + value = PostgresDataReadStore._coerce_cursor_value( + "2026-01-02T03:04:05+00:00", ft.c.created_at ) assert isinstance(value, datetime) + + def test_integer_id_sort_value_coerced_to_int(self) -> None: + # Sorting by id, the cursor sort value binds against BIGINT — a str + # is coerced (and garbage raises ValueError → 400 via the sort path). 
+ ft = _feature_table() + assert PostgresDataReadStore._coerce_cursor_value("7", ft.c.id) == 7 From 3f6b434aa928e0e51f9fca759e2681cef977418d Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 12:11:31 +0100 Subject: [PATCH 10/23] fix: encode the srn as the cursor sort value for records sorted by id MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sort=id on records aliases to the srn column in the store, but _next_cursor encoded the flattened row's bare id as the cursor's sort value. The keyset condition compared srn > '{bare id}', which every SRN satisfies (they all start with 'urn:'), so next_cursor kept regenerating and pagination never advanced. Encode the tiebreaker value (records srn / feature id) as the sort value whenever the sort column is id — the store compares both cursor components against that same column. --- .../api/v1/routes/data/_streaming.py | 12 +++-- .../test_data_read_store_postgres.py | 37 +++++++++++++ .../api/v1/routes/data/test_streaming.py | 54 ++++++++++++++++++- 3 files changed, 98 insertions(+), 5 deletions(-) diff --git a/server/osa/application/api/v1/routes/data/_streaming.py b/server/osa/application/api/v1/routes/data/_streaming.py index 7325719..82c6bf0 100644 --- a/server/osa/application/api/v1/routes/data/_streaming.py +++ b/server/osa/application/api/v1/routes/data/_streaming.py @@ -94,8 +94,12 @@ def _next_cursor(page: list[Mapping[str, Any]], plan: QueryPlan) -> str | None: if not page: return None last = page[-1] + # The keyset tiebreaker is the records srn / feature id (see + # PostgresDataReadStore). ``sort=id`` aliases to that same column in the + # store, so the cursor's sort value must be the tiebreaker value too — + # encoding the bare record id would compare it against the srn column, + # match every row, and never advance. 
+ tiebreak = last.get("srn", last.get("id")) sort_column = plan.sort[0].column - sort_value = last.get(sort_column) - # The keyset tiebreaker is the record SRN (see PostgresDataReadStore), so the - # cursor's id component is the full srn, not the bare id. - return encode_cursor(sort_value, last.get("srn", last.get("id"))) + sort_value = tiebreak if sort_column == "id" else last.get(sort_column) + return encode_cursor(sort_value, tiebreak) diff --git a/server/tests/integration/test_data_read_store_postgres.py b/server/tests/integration/test_data_read_store_postgres.py index 77111fb..ebb1405 100644 --- a/server/tests/integration/test_data_read_store_postgres.py +++ b/server/tests/integration/test_data_read_store_postgres.py @@ -264,6 +264,43 @@ async def test_cursor_advances_without_overlap( page2 = await _drain(rs, plan2) assert [r["id"] for r in page2] == ["rec2", "rec1", "rec0"] + async def test_id_sort_cursor_round_trips( + self, pg_engine: AsyncEngine, pg_session: AsyncSession + ): + # ``sort=id`` aliases to the srn column; the route layer encodes the + # srn as the cursor sort value (see _next_cursor). A bare-id cursor + # would sort before every SRN and pagination would never advance. 
+ store = await _setup_schema(pg_engine, pg_session) + for i in range(3): + await _publish( + pg_engine, + store, + f"rec{i}", + "Homo sapiens", + float(i), + datetime(2026, 1, 1, tzinfo=UTC), + ) + await pg_session.commit() + rs = PostgresDataReadStore(pg_session, Domain("localhost")) + + sort = [SortSpec(column="id", direction=SortDirection.ASC)] + all_rows = await _drain( + rs, QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, sort=sort) + ) + assert [r["id"] for r in all_rows] == ["rec0", "rec1", "rec2"] + + cursor = encode_cursor(all_rows[1]["srn"], all_rows[1]["srn"]) + page2 = await _drain( + rs, + QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + sort=sort, + pagination=PaginationParams(cursor=PaginationCursor(value=cursor)), + ), + ) + assert [r["id"] for r in page2] == ["rec2"] + ASSAY_SCHEMA = SchemaId.parse("assay@1.0.0") diff --git a/server/tests/unit/application/api/v1/routes/data/test_streaming.py b/server/tests/unit/application/api/v1/routes/data/test_streaming.py index b1e2c97..73a18de 100644 --- a/server/tests/unit/application/api/v1/routes/data/test_streaming.py +++ b/server/tests/unit/application/api/v1/routes/data/test_streaming.py @@ -13,7 +13,12 @@ from osa.application.api.v1.routes.data._streaming import build_table_response from osa.domain.data.model.format import FORMATS from osa.domain.data.model.manifest import ColumnSpec -from osa.domain.data.model.query_plan import QueryPlan, TableKind +from osa.domain.data.model.query_plan import ( + QueryPlan, + SortDirection, + SortSpec, + TableKind, +) from osa.domain.semantics.model.value import FieldType from osa.domain.shared.model.srn import SchemaId @@ -101,3 +106,50 @@ async def test_paginated_has_more_emits_cursor() -> None: # default RECORDS sort is created_at desc, tiebreaker srn assert decoded["s"] == "2026-01-02" assert decoded["id"] == "s2" + + +@pytest.mark.asyncio +async def test_paginated_records_id_sort_encodes_srn() -> None: + # ``sort=id`` on records aliases 
to the srn column in the store, so the + # cursor's sort value must be the srn too. Encoding the bare id would + # compare it against the srn column — every SRN sorts after a bare id, so + # the keyset matches all rows and pagination never advances. + rows = [ + {"id": "a", "srn": "urn:osa:localhost:rec:a@1"}, + {"id": "b", "srn": "urn:osa:localhost:rec:b@1"}, + {"id": "c", "srn": "urn:osa:localhost:rec:c@1"}, + ] + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + pagination={"limit": 2}, + sort=[SortSpec(column="id", direction=SortDirection.ASC)], + ) + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, plan) + parsed = json.loads(await _body(resp)) + decoded = json.loads(base64.urlsafe_b64decode(parsed["next_cursor"])) + assert decoded["s"] == "urn:osa:localhost:rec:b@1" + assert decoded["id"] == "urn:osa:localhost:rec:b@1" + + +@pytest.mark.asyncio +async def test_paginated_features_id_sort_encodes_row_id() -> None: + # Feature rows have a real integer id column (no srn key), so ``sort=id`` + # keeps encoding the row's own id. 
+ rows = [ + {"id": 1, "record_srn": "urn:osa:localhost:rec:a@1"}, + {"id": 2, "record_srn": "urn:osa:localhost:rec:a@1"}, + {"id": 3, "record_srn": "urn:osa:localhost:rec:a@1"}, + ] + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name="chem_features", + pagination={"limit": 2}, + # default FEATURE sort is id asc + ) + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, plan) + parsed = json.loads(await _body(resp)) + decoded = json.loads(base64.urlsafe_b64decode(parsed["next_cursor"])) + assert decoded["s"] == 2 + assert decoded["id"] == 2 From 2f20fa04803ea93a923d188ea359b0a9ee6c9be1 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 12:34:35 +0100 Subject: [PATCH 11/23] fix: select the cursor tiebreak key by table kind, not key presence _next_cursor fell back from 'srn' to 'id' by dict lookup, but a hook may legally declare a data column named 'srn' (only id/record_srn/created_at are reserved feature auto columns). That column's string value would hijack the feature tiebreaker and 400 on the next page when coerced against the BIGINT id column. Dispatch on plan.table_kind instead: records tiebreak on srn, features on id, unconditionally. 
--- .../api/v1/routes/data/_streaming.py | 9 ++++--- .../api/v1/routes/data/test_streaming.py | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/server/osa/application/api/v1/routes/data/_streaming.py b/server/osa/application/api/v1/routes/data/_streaming.py index 82c6bf0..c155053 100644 --- a/server/osa/application/api/v1/routes/data/_streaming.py +++ b/server/osa/application/api/v1/routes/data/_streaming.py @@ -22,7 +22,7 @@ from osa.domain.data.model.format import DataResponseFormat from osa.domain.data.model.manifest import ColumnSpec -from osa.domain.data.model.query_plan import QueryPlan, encode_cursor +from osa.domain.data.model.query_plan import QueryPlan, TableKind, encode_cursor async def build_table_response( @@ -98,8 +98,11 @@ def _next_cursor(page: list[Mapping[str, Any]], plan: QueryPlan) -> str | None: # PostgresDataReadStore). ``sort=id`` aliases to that same column in the # store, so the cursor's sort value must be the tiebreaker value too — # encoding the bare record id would compare it against the srn column, - # match every row, and never advance. - tiebreak = last.get("srn", last.get("id")) + # match every row, and never advance. Select the key by table kind, not + # by key presence: a hook may declare a data column named "srn", which + # must not hijack the feature tiebreaker. 
+ tiebreak_key = "srn" if plan.table_kind == TableKind.RECORDS else "id" + tiebreak = last.get(tiebreak_key) sort_column = plan.sort[0].column sort_value = tiebreak if sort_column == "id" else last.get(sort_column) return encode_cursor(sort_value, tiebreak) diff --git a/server/tests/unit/application/api/v1/routes/data/test_streaming.py b/server/tests/unit/application/api/v1/routes/data/test_streaming.py index 73a18de..21a6a0f 100644 --- a/server/tests/unit/application/api/v1/routes/data/test_streaming.py +++ b/server/tests/unit/application/api/v1/routes/data/test_streaming.py @@ -153,3 +153,28 @@ async def test_paginated_features_id_sort_encodes_row_id() -> None: decoded = json.loads(base64.urlsafe_b64decode(parsed["next_cursor"])) assert decoded["s"] == 2 assert decoded["id"] == 2 + + +@pytest.mark.asyncio +async def test_paginated_feature_hook_column_named_srn_does_not_hijack_tiebreak() -> None: + # A hook may legally declare a data column named "srn" (it's not a feature + # auto column). The cursor tiebreaker must still be the integer row id — + # picking the hook column's string would 400 on the next page when it's + # coerced against the BIGINT id column. 
+ rows = [ + {"id": 1, "record_srn": "urn:osa:localhost:rec:a@1", "srn": "hook-value-1"}, + {"id": 2, "record_srn": "urn:osa:localhost:rec:a@1", "srn": "hook-value-2"}, + {"id": 3, "record_srn": "urn:osa:localhost:rec:a@1", "srn": "hook-value-3"}, + ] + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.FEATURE, + feature_name="chem_features", + pagination={"limit": 2}, + # default FEATURE sort is id asc + ) + resp = await build_table_response(_aiter(rows), JSON_FMT, COLUMNS, plan) + parsed = json.loads(await _body(resp)) + decoded = json.loads(base64.urlsafe_b64decode(parsed["next_cursor"])) + assert decoded["s"] == 2 + assert decoded["id"] == 2 From 2c95f7eb1cc4bb024c4a67c16ae56be18f476954 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 13:31:22 +0100 Subject: [PATCH 12/23] chore: remove stale legacy-router docstring, narrow cursor except, relocate record test --- server/osa/application/api/v1/routes/data/__init__.py | 3 --- server/osa/domain/data/model/query_plan.py | 3 ++- .../{ => domain/record}/test_record_schema_srn_immutable.py | 0 3 files changed, 2 insertions(+), 4 deletions(-) rename server/tests/unit/{ => domain/record}/test_record_schema_srn_immutable.py (100%) diff --git a/server/osa/application/api/v1/routes/data/__init__.py b/server/osa/application/api/v1/routes/data/__init__.py index 8109e25..4234ae8 100644 --- a/server/osa/application/api/v1/routes/data/__init__.py +++ b/server/osa/application/api/v1/routes/data/__init__.py @@ -5,9 +5,6 @@ - single record by ID (``GET /data/records/{id}``) — US4 - records table matrix (``/data/{schema}/records*``) — US1/US2 via the factory - feature table matrix (``/data/{schema}/{feature}*``) — US5 via the factory - -The router coexists with the legacy ``/discovery``, ``/records``, ``/search`` -routers until US6 removes them (research §10). 
""" from __future__ import annotations diff --git a/server/osa/domain/data/model/query_plan.py b/server/osa/domain/data/model/query_plan.py index d193e82..19c0900 100644 --- a/server/osa/domain/data/model/query_plan.py +++ b/server/osa/domain/data/model/query_plan.py @@ -103,7 +103,8 @@ def decode_cursor(cursor: str) -> dict[str, Any]: try: raw = base64.urlsafe_b64decode(cursor.encode()) data = json.loads(raw) - except Exception as exc: + except ValueError as exc: + # binascii.Error and json.JSONDecodeError are both ValueError subclasses. raise ValueError(f"Malformed cursor: {exc}") from exc if not isinstance(data, dict) or "s" not in data or "id" not in data: raise ValueError("Cursor must contain 's' and 'id' keys") diff --git a/server/tests/unit/test_record_schema_srn_immutable.py b/server/tests/unit/domain/record/test_record_schema_srn_immutable.py similarity index 100% rename from server/tests/unit/test_record_schema_srn_immutable.py rename to server/tests/unit/domain/record/test_record_schema_srn_immutable.py From 1d8f784ab2ac1426df1c8186f3e4703df6efb09b Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 13:33:30 +0100 Subject: [PATCH 13/23] refactor: make PaginationParams the single owner of the limit-clamp policy --- .../application/api/v1/routes/data/_params.py | 19 +++++-------------- server/osa/domain/data/model/query_plan.py | 16 ++++++++++++++-- .../tests/unit/domain/data/test_query_plan.py | 13 ++++++++++--- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/server/osa/application/api/v1/routes/data/_params.py b/server/osa/application/api/v1/routes/data/_params.py index d36c334..f45b41d 100644 --- a/server/osa/application/api/v1/routes/data/_params.py +++ b/server/osa/application/api/v1/routes/data/_params.py @@ -2,7 +2,7 @@ from __future__ import annotations -from pydantic import BaseModel, Field +from pydantic import BaseModel from osa.domain.data.model.filter import FilterExpr from osa.domain.data.model.query_plan import ( 
@@ -17,24 +17,15 @@ from osa.domain.shared.model.ids import HookName from osa.domain.shared.model.srn import SchemaId -# Page-size ceiling (mirrors PaginationParams.limit's bound). Requests above this -# are clamped, not rejected (FR — a consumer asking for "everything" with a big -# number gets the max page, not a 422). -MAX_LIMIT = 1000 - - -def clamp_limit(limit: int) -> int: - return max(1, min(limit, MAX_LIMIT)) - class FilterRequestBody(BaseModel): """POST body shared by every table format (records + feature).""" filter: FilterExpr | None = None cursor: str | None = None - # Unbounded at the edge so an over-large request is clamped (build_plan), - # not rejected; the canonical bound lives on PaginationParams. - limit: int = Field(default=50, ge=1) + # Unbounded at the edge: PaginationParams clamps out-of-range values to + # [1, MAX_LIMIT] — over-large requests get the max page, not a 422. + limit: int = 50 sort: str | None = None @@ -74,7 +65,7 @@ def build_plan( ) -> QueryPlan: pagination = PaginationParams( cursor=PaginationCursor(value=cursor) if cursor else None, - limit=clamp_limit(limit), + limit=limit, ) return QueryPlan( schema_id=schema_id, diff --git a/server/osa/domain/data/model/query_plan.py b/server/osa/domain/data/model/query_plan.py index 19c0900..4cd946e 100644 --- a/server/osa/domain/data/model/query_plan.py +++ b/server/osa/domain/data/model/query_plan.py @@ -18,7 +18,7 @@ from enum import StrEnum from typing import Any -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from osa.domain.data.model.filter import FilterExpr from osa.domain.shared.model.ids import HookName @@ -51,9 +51,21 @@ def __str__(self) -> str: return self.value +# Page-size ceiling for paginated reads. Single source of truth — route +# helpers must not redeclare it. 
+MAX_LIMIT = 1000 + + class PaginationParams(BaseModel): cursor: PaginationCursor | None = None - limit: int = Field(default=50, ge=1, le=1000) + limit: int = Field(default=50, ge=1, le=MAX_LIMIT) + + @field_validator("limit", mode="before") + @classmethod + def _clamp_limit(cls, value: int) -> int: + # Clamp, don't reject: a consumer asking for "everything" with a big + # number gets the max page, not a 422. + return max(1, min(int(value), MAX_LIMIT)) # Default sort keys per table kind (data-model.md §PaginationParams). diff --git a/server/tests/unit/domain/data/test_query_plan.py b/server/tests/unit/domain/data/test_query_plan.py index 3aabf74..18e2947 100644 --- a/server/tests/unit/domain/data/test_query_plan.py +++ b/server/tests/unit/domain/data/test_query_plan.py @@ -43,9 +43,16 @@ def test_records_kind_rejects_feature_name() -> None: QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, feature_name="x") -def test_limit_capped_at_1000() -> None: - with pytest.raises(PydanticValidationError): - PaginationParams(limit=5000) +def test_limit_above_max_clamps_to_1000() -> None: + # Clamp, don't reject: a consumer asking for "everything" with a big + # number gets the max page, not a 422. The clamp policy lives HERE, on + # the canonical model — not in route helpers. 
+ assert PaginationParams(limit=5000).limit == 1000 + + +def test_limit_below_one_clamps_to_one() -> None: + assert PaginationParams(limit=0).limit == 1 + assert PaginationParams(limit=-5).limit == 1 def test_limit_default_is_50() -> None: From ffcf49e03aff88376f3dd3e27ae80e7f6e170a1d Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 13:35:22 +0100 Subject: [PATCH 14/23] refactor: replace route-local id@version parsing with a RecordRef value object --- .../application/api/v1/routes/data/records.py | 21 ++--------- server/osa/domain/shared/model/ids.py | 37 ++++++++++++++++++- .../unit/domain/shared/test_record_ref.py | 28 ++++++++++++++ 3 files changed, 67 insertions(+), 19 deletions(-) create mode 100644 server/tests/unit/domain/shared/test_record_ref.py diff --git a/server/osa/application/api/v1/routes/data/records.py b/server/osa/application/api/v1/routes/data/records.py index 0910e49..40aad6d 100644 --- a/server/osa/application/api/v1/routes/data/records.py +++ b/server/osa/application/api/v1/routes/data/records.py @@ -13,32 +13,17 @@ from osa.application.api.v1.routes.data.models import RecordResponse from osa.domain.data.service.data_catalog import DataCatalogService -from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.ids import RecordRef router = APIRouter(route_class=DishkaRoute) -def _parse_id_and_version(raw: str) -> tuple[RecordId, int | None]: - """Split ``{id}`` or ``{id}@{version}`` into a typed id + optional version.""" - if "@" in raw: - id_part, version_part = raw.split("@", 1) - try: - return RecordId(id_part), int(version_part) - except ValueError as exc: - raise ValidationError( - f"Invalid record version in {raw!r}; expected an integer.", - field="id", - ) from exc - return RecordId(raw), None - - @router.get( "/records/{record_id}", operation_id="data_get_record_by_id", response_model=RecordResponse ) async def get_record_by_id( record_id: str, 
service: FromDishka[DataCatalogService] ) -> RecordResponse: - rid, version = _parse_id_and_version(record_id) - summary = await service.get_record_by_id(rid, version) + ref = RecordRef.parse(record_id) + summary = await service.get_record_by_id(ref.id, ref.version) return RecordResponse.from_summary(summary) diff --git a/server/osa/domain/shared/model/ids.py b/server/osa/domain/shared/model/ids.py index f48c6c9..4fb9c4c 100644 --- a/server/osa/domain/shared/model/ids.py +++ b/server/osa/domain/shared/model/ids.py @@ -11,10 +11,45 @@ from typing import NewType +from osa.domain.shared.error import ValidationError from osa.domain.shared.model.hook import HookName +from osa.domain.shared.model.value import ValueObject # Bare internal record identifier (UUIDv7 / ULID). Validation of the exact # charset/length happens at the API boundary; internally it is opaque. RecordId = NewType("RecordId", str) -__all__ = ["RecordId", "HookName"] + +class RecordRef(ValueObject): + """A record reference: bare internal id plus optional integer version. + + Wire form is the URL segment ``{id}`` or ``{id}@{version}`` — the record + analogue of :class:`~osa.domain.shared.model.srn.SchemaId`'s + ``@``. ``version is None`` means "latest published". 
+ """ + + id: RecordId + version: int | None = None + + @classmethod + def parse(cls, raw: str) -> RecordRef: + """Parse ``{id}`` or ``{id}@{version}``; raises ``ValidationError``.""" + if "@" not in raw: + return cls(id=RecordId(raw)) + id_part, version_part = raw.split("@", 1) + try: + return cls(id=RecordId(id_part), version=int(version_part)) + except ValueError as exc: + raise ValidationError( + f"Invalid record version in {raw!r}; expected an integer.", + field="id", + ) from exc + + def render(self) -> str: + return self.id if self.version is None else f"{self.id}@{self.version}" + + def __str__(self) -> str: + return self.render() + + +__all__ = ["RecordId", "RecordRef", "HookName"] diff --git a/server/tests/unit/domain/shared/test_record_ref.py b/server/tests/unit/domain/shared/test_record_ref.py new file mode 100644 index 0000000..40769dd --- /dev/null +++ b/server/tests/unit/domain/shared/test_record_ref.py @@ -0,0 +1,28 @@ +"""RecordRef — parse of ``{id}`` / ``{id}@{version}`` URL segments.""" + +import pytest + +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.ids import RecordRef + + +def test_parse_bare_id_has_no_version() -> None: + ref = RecordRef.parse("0190a1b2") + assert ref.id == "0190a1b2" + assert ref.version is None + + +def test_parse_id_with_version() -> None: + ref = RecordRef.parse("0190a1b2@3") + assert ref.id == "0190a1b2" + assert ref.version == 3 + + +def test_parse_non_integer_version_raises_validation_error() -> None: + with pytest.raises(ValidationError): + RecordRef.parse("0190a1b2@latest") + + +def test_render_round_trips() -> None: + assert RecordRef.parse("abc@7").render() == "abc@7" + assert RecordRef.parse("abc").render() == "abc" From 711957abc9a35ec83c709aa7e9b5ef09894180f2 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 13:40:59 +0100 Subject: [PATCH 15/23] refactor: source the page-size ceiling from DataConfig.max_page_limit --- 
.../application/api/v1/routes/data/_params.py | 8 +++-- .../api/v1/routes/data/features_table.py | 5 ++++ .../api/v1/routes/data/records_table.py | 5 ++++ server/osa/config.py | 8 +++-- server/osa/domain/data/model/query_plan.py | 29 +++++++++++-------- .../api/v1/routes/data/test_params.py | 16 ++++++++++ server/tests/unit/config/test_config.py | 11 +++++++ .../tests/unit/domain/data/test_query_plan.py | 20 +++++++++---- 8 files changed, 78 insertions(+), 24 deletions(-) diff --git a/server/osa/application/api/v1/routes/data/_params.py b/server/osa/application/api/v1/routes/data/_params.py index f45b41d..73e0d05 100644 --- a/server/osa/application/api/v1/routes/data/_params.py +++ b/server/osa/application/api/v1/routes/data/_params.py @@ -23,8 +23,8 @@ class FilterRequestBody(BaseModel): filter: FilterExpr | None = None cursor: str | None = None - # Unbounded at the edge: PaginationParams clamps out-of-range values to - # [1, MAX_LIMIT] — over-large requests get the max page, not a 422. + # Unbounded at the edge: build_plan clamps to [1, DataConfig.max_page_limit] + # via PaginationParams.clamped — over-large requests get the max page, not a 422. 
limit: int = 50 sort: str | None = None @@ -61,11 +61,13 @@ def build_plan( filter_expr: FilterExpr | None, cursor: str | None, limit: int, + max_limit: int, sort: str | None, ) -> QueryPlan: - pagination = PaginationParams( + pagination = PaginationParams.clamped( cursor=PaginationCursor(value=cursor) if cursor else None, limit=limit, + max_limit=max_limit, ) return QueryPlan( schema_id=schema_id, diff --git a/server/osa/application/api/v1/routes/data/features_table.py b/server/osa/application/api/v1/routes/data/features_table.py index 47797a0..43e0e8a 100644 --- a/server/osa/application/api/v1/routes/data/features_table.py +++ b/server/osa/application/api/v1/routes/data/features_table.py @@ -20,6 +20,7 @@ from osa.application.api.v1.routes.data._runtime import apply_statement_timeout from osa.application.api.v1.routes.data._streaming import build_table_response from osa.application.api.v1.routes.data.tables import format_key, register_table_routes +from osa.config import Config from osa.domain.data.model.format import DataResponseFormat from osa.domain.data.model.manifest import ColumnSpec from osa.domain.data.model.query_plan import TableKind @@ -59,6 +60,7 @@ async def endpoint( query_service: FromDishka[DataQueryService], catalog_service: FromDishka[DataCatalogService], session: FromDishka[AsyncSession], + config: FromDishka[Config], cursor: str | None = None, limit: int = 50, sort: str | None = None, @@ -72,6 +74,7 @@ async def endpoint( filter_expr=None, cursor=cursor, limit=limit, + max_limit=config.data.max_page_limit, sort=sort, ) rows = query_service.stream_features(plan) @@ -89,6 +92,7 @@ async def endpoint( query_service: FromDishka[DataQueryService], catalog_service: FromDishka[DataCatalogService], session: FromDishka[AsyncSession], + config: FromDishka[Config], ) -> StreamingResponse: columns, schema_id = await _feature_columns(catalog_service, schema, feature) await apply_statement_timeout(session, fmt) @@ -99,6 +103,7 @@ async def endpoint( 
filter_expr=body.filter, cursor=body.cursor, limit=body.limit, + max_limit=config.data.max_page_limit, sort=body.sort, ) rows = query_service.stream_features(plan) diff --git a/server/osa/application/api/v1/routes/data/records_table.py b/server/osa/application/api/v1/routes/data/records_table.py index 9567d27..6004d4c 100644 --- a/server/osa/application/api/v1/routes/data/records_table.py +++ b/server/osa/application/api/v1/routes/data/records_table.py @@ -21,6 +21,7 @@ from osa.application.api.v1.routes.data._runtime import apply_statement_timeout from osa.application.api.v1.routes.data._streaming import build_table_response from osa.application.api.v1.routes.data.tables import format_key, register_table_routes +from osa.config import Config from osa.domain.data.model.format import DataResponseFormat from osa.domain.data.model.manifest import ColumnSpec from osa.domain.data.model.query_plan import TableKind @@ -45,6 +46,7 @@ async def endpoint( query_service: FromDishka[DataQueryService], catalog_service: FromDishka[DataCatalogService], session: FromDishka[AsyncSession], + config: FromDishka[Config], cursor: str | None = None, limit: int = 50, sort: str | None = None, @@ -58,6 +60,7 @@ async def endpoint( filter_expr=None, cursor=cursor, limit=limit, + max_limit=config.data.max_page_limit, sort=sort, ) rows = query_service.stream_records(plan) @@ -74,6 +77,7 @@ async def endpoint( query_service: FromDishka[DataQueryService], catalog_service: FromDishka[DataCatalogService], session: FromDishka[AsyncSession], + config: FromDishka[Config], ) -> StreamingResponse: columns, schema_id = await _records_columns(catalog_service, schema) await apply_statement_timeout(session, fmt) @@ -84,6 +88,7 @@ async def endpoint( filter_expr=body.filter, cursor=body.cursor, limit=body.limit, + max_limit=config.data.max_page_limit, sort=body.sort, ) rows = query_service.stream_records(plan) diff --git a/server/osa/config.py b/server/osa/config.py index fc44cba..69e15ee 100644 --- 
a/server/osa/config.py +++ b/server/osa/config.py @@ -263,14 +263,16 @@ def _normalize_pg_url(url: str) -> str: class DataConfig(BaseModel): """Bounds for the unified ``/data/`` read surface (nested in Config). - Caps the cost of a filter tree before it is compiled to SQL: maximum tree - depth, total predicate count, and distinct feature-hook joins. Overridable - via ``OSA_DATA__MAX_FILTER_DEPTH`` etc. (FR-012, features 076/137). + Caps the cost of a filter tree before it is compiled to SQL (maximum tree + depth, total predicate count, distinct feature-hook joins) and the + paginated-JSON page size. Overridable via ``OSA_DATA__MAX_FILTER_DEPTH`` + etc. (FR-012, features 076/137). """ max_filter_depth: int = 10 max_predicates: int = 200 max_feature_joins: int = 10 # distinct features. references in one filter + max_page_limit: int = 1000 # page-size ceiling; over-large requests are clamped, not 422d class Config(BaseSettings): diff --git a/server/osa/domain/data/model/query_plan.py b/server/osa/domain/data/model/query_plan.py index 4cd946e..f2eb846 100644 --- a/server/osa/domain/data/model/query_plan.py +++ b/server/osa/domain/data/model/query_plan.py @@ -18,7 +18,7 @@ from enum import StrEnum from typing import Any -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, model_validator from osa.domain.data.model.filter import FilterExpr from osa.domain.shared.model.ids import HookName @@ -51,21 +51,26 @@ def __str__(self) -> str: return self.value -# Page-size ceiling for paginated reads. Single source of truth — route -# helpers must not redeclare it. 
-MAX_LIMIT = 1000 - - class PaginationParams(BaseModel): cursor: PaginationCursor | None = None - limit: int = Field(default=50, ge=1, le=MAX_LIMIT) + limit: int = Field(default=50, ge=1) - @field_validator("limit", mode="before") @classmethod - def _clamp_limit(cls, value: int) -> int: - # Clamp, don't reject: a consumer asking for "everything" with a big - # number gets the max page, not a 422. - return max(1, min(int(value), MAX_LIMIT)) + def clamped( + cls, + *, + cursor: PaginationCursor | None = None, + limit: int, + max_limit: int, + ) -> "PaginationParams": + """Build params with ``limit`` clamped into ``[1, max_limit]``. + + Clamp, don't reject: a consumer asking for "everything" with a big + number gets the max page, not a 422. The ceiling is operator + configuration (``DataConfig.max_page_limit``), so it arrives as an + argument rather than living here as a constant. + """ + return cls(cursor=cursor, limit=max(1, min(limit, max_limit))) # Default sort keys per table kind (data-model.md §PaginationParams). 
diff --git a/server/tests/unit/application/api/v1/routes/data/test_params.py b/server/tests/unit/application/api/v1/routes/data/test_params.py index 7895a92..d2380b2 100644 --- a/server/tests/unit/application/api/v1/routes/data/test_params.py +++ b/server/tests/unit/application/api/v1/routes/data/test_params.py @@ -42,6 +42,7 @@ def test_build_plan_wraps_cursor_and_limit() -> None: filter_expr=None, cursor="CUR", limit=10, + max_limit=1000, sort="mw:asc", ) assert plan.pagination.limit == 10 @@ -49,6 +50,20 @@ def test_build_plan_wraps_cursor_and_limit() -> None: assert plan.sort[0].column == "mw" +def test_build_plan_clamps_limit_to_configured_max() -> None: + plan = build_plan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + feature_name=None, + filter_expr=None, + cursor=None, + limit=99999, + max_limit=200, + sort=None, + ) + assert plan.pagination.limit == 200 + + def test_build_plan_no_cursor_is_none() -> None: plan = build_plan( schema_id=SCHEMA, @@ -57,6 +72,7 @@ def test_build_plan_no_cursor_is_none() -> None: filter_expr=None, cursor=None, limit=50, + max_limit=1000, sort=None, ) assert plan.pagination.cursor is None diff --git a/server/tests/unit/config/test_config.py b/server/tests/unit/config/test_config.py index e5fb1cb..de537b2 100644 --- a/server/tests/unit/config/test_config.py +++ b/server/tests/unit/config/test_config.py @@ -168,3 +168,14 @@ def test_empty_providers_default(self): """auth.providers defaults to empty OrcidConfig.""" cfg = config_from_yaml({}) assert cfg.auth.providers.orcid.client_id == "" + + +class TestDataConfig: + """The /data/ page-size ceiling is operator configuration, not a code constant.""" + + def test_max_page_limit_default(self): + assert config_from_yaml({}).data.max_page_limit == 1000 + + def test_max_page_limit_yaml_override(self): + cfg = config_from_yaml({"data": {"max_page_limit": 250}}) + assert cfg.data.max_page_limit == 250 diff --git a/server/tests/unit/domain/data/test_query_plan.py 
b/server/tests/unit/domain/data/test_query_plan.py index 18e2947..3419718 100644 --- a/server/tests/unit/domain/data/test_query_plan.py +++ b/server/tests/unit/domain/data/test_query_plan.py @@ -6,6 +6,7 @@ from pydantic import ValidationError as PydanticValidationError from osa.domain.data.model.query_plan import ( + PaginationCursor, PaginationParams, QueryPlan, SortDirection, @@ -43,16 +44,23 @@ def test_records_kind_rejects_feature_name() -> None: QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS, feature_name="x") -def test_limit_above_max_clamps_to_1000() -> None: +def test_limit_above_max_clamps_to_max() -> None: # Clamp, don't reject: a consumer asking for "everything" with a big - # number gets the max page, not a 422. The clamp policy lives HERE, on - # the canonical model — not in route helpers. - assert PaginationParams(limit=5000).limit == 1000 + # number gets the max page, not a 422. The clamp *policy* lives HERE, on + # the canonical model; the *bound* comes from config (DataConfig). 
+ assert PaginationParams.clamped(limit=5000, max_limit=1000).limit == 1000 + assert PaginationParams.clamped(limit=5000, max_limit=200).limit == 200 def test_limit_below_one_clamps_to_one() -> None: - assert PaginationParams(limit=0).limit == 1 - assert PaginationParams(limit=-5).limit == 1 + assert PaginationParams.clamped(limit=0, max_limit=1000).limit == 1 + assert PaginationParams.clamped(limit=-5, max_limit=1000).limit == 1 + + +def test_clamped_preserves_cursor() -> None: + p = PaginationParams.clamped(cursor=PaginationCursor(value="CUR"), limit=10, max_limit=1000) + assert str(p.cursor) == "CUR" + assert p.limit == 10 def test_limit_default_is_50() -> None: From 5c340f6f3121c22626871adb932de355154fc52f Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 13:46:12 +0100 Subject: [PATCH 16/23] refactor: move manifest-structure knowledge into DataCatalogService.resolve_table --- .../api/v1/routes/data/features_table.py | 40 +----- .../api/v1/routes/data/records_table.py | 26 +--- server/osa/domain/data/model/manifest.py | 14 +++ .../osa/domain/data/service/data_catalog.py | 33 ++++- .../domain/data/test_data_catalog_service.py | 119 ++++++++++++++++++ 5 files changed, 176 insertions(+), 56 deletions(-) create mode 100644 server/tests/unit/domain/data/test_data_catalog_service.py diff --git a/server/osa/application/api/v1/routes/data/features_table.py b/server/osa/application/api/v1/routes/data/features_table.py index 43e0e8a..310d081 100644 --- a/server/osa/application/api/v1/routes/data/features_table.py +++ b/server/osa/application/api/v1/routes/data/features_table.py @@ -8,8 +8,6 @@ from __future__ import annotations -from collections.abc import Sequence - from dishka.integrations.fastapi import FromDishka from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse @@ -22,35 +20,9 @@ from osa.application.api.v1.routes.data.tables import format_key, register_table_routes from osa.config import Config from 
osa.domain.data.model.format import DataResponseFormat -from osa.domain.data.model.manifest import ColumnSpec from osa.domain.data.model.query_plan import TableKind from osa.domain.data.service.data_catalog import DataCatalogService from osa.domain.data.service.data_query import DataQueryService -from osa.domain.shared.error import NotFoundError -from osa.domain.shared.model.srn import SchemaId - - -async def _feature_columns( - catalog: DataCatalogService, schema: str, feature: str -) -> tuple[Sequence[ColumnSpec], SchemaId]: - """Resolve schema + feature (404 if unknown/reserved) → feature columns + SchemaId.""" - schema_id = await catalog.resolve_schema(schema) - manifest = await catalog.get_schema_manifest(schema_id) - resource = next( - ( - tr - for tr in manifest.table_resources - if tr.name == feature and tr.kind == TableKind.FEATURE - ), - None, - ) - if resource is None: - raise NotFoundError( - f"No feature table '{feature}' on schema '{schema}'. " - f"See /api/v1/data/{schema_id.render()} for its table resources.", - code="feature_not_found", - ) - return resource.columns, schema_id def _make_get_endpoint(fmt: DataResponseFormat): @@ -65,10 +37,10 @@ async def endpoint( limit: int = 50, sort: str | None = None, ) -> StreamingResponse: - columns, schema_id = await _feature_columns(catalog_service, schema, feature) + table = await catalog_service.resolve_table(schema, TableKind.FEATURE, feature_name=feature) await apply_statement_timeout(session, fmt) plan = build_plan( - schema_id=schema_id, + schema_id=table.schema_id, table_kind=TableKind.FEATURE, feature_name=feature, filter_expr=None, @@ -78,7 +50,7 @@ async def endpoint( sort=sort, ) rows = query_service.stream_features(plan) - return await build_table_response(rows, fmt, columns, plan) + return await build_table_response(rows, fmt, table.columns, plan) return endpoint @@ -94,10 +66,10 @@ async def endpoint( session: FromDishka[AsyncSession], config: FromDishka[Config], ) -> StreamingResponse: - 
columns, schema_id = await _feature_columns(catalog_service, schema, feature) + table = await catalog_service.resolve_table(schema, TableKind.FEATURE, feature_name=feature) await apply_statement_timeout(session, fmt) plan = build_plan( - schema_id=schema_id, + schema_id=table.schema_id, table_kind=TableKind.FEATURE, feature_name=feature, filter_expr=body.filter, @@ -107,7 +79,7 @@ async def endpoint( sort=body.sort, ) rows = query_service.stream_features(plan) - return await build_table_response(rows, fmt, columns, plan) + return await build_table_response(rows, fmt, table.columns, plan) # Unique name before the limiter — see records_table._make_post_endpoint. endpoint.__name__ = f"feature_post_{format_key(fmt)}" diff --git a/server/osa/application/api/v1/routes/data/records_table.py b/server/osa/application/api/v1/routes/data/records_table.py index 6004d4c..ee498ae 100644 --- a/server/osa/application/api/v1/routes/data/records_table.py +++ b/server/osa/application/api/v1/routes/data/records_table.py @@ -9,8 +9,6 @@ from __future__ import annotations -from collections.abc import Sequence - from dishka.integrations.fastapi import FromDishka from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse @@ -23,21 +21,9 @@ from osa.application.api.v1.routes.data.tables import format_key, register_table_routes from osa.config import Config from osa.domain.data.model.format import DataResponseFormat -from osa.domain.data.model.manifest import ColumnSpec from osa.domain.data.model.query_plan import TableKind from osa.domain.data.service.data_catalog import DataCatalogService from osa.domain.data.service.data_query import DataQueryService -from osa.domain.shared.model.srn import SchemaId - - -async def _records_columns( - catalog: DataCatalogService, schema: str -) -> tuple[Sequence[ColumnSpec], SchemaId]: - """Resolve schema (404 if unknown/reserved) and return its records columns + SchemaId.""" - schema_id = await catalog.resolve_schema(schema) - 
manifest = await catalog.get_schema_manifest(schema_id) - records_resource = next(tr for tr in manifest.table_resources if tr.name == "records") - return records_resource.columns, schema_id def _make_get_endpoint(fmt: DataResponseFormat): @@ -51,10 +37,10 @@ async def endpoint( limit: int = 50, sort: str | None = None, ) -> StreamingResponse: - columns, schema_id = await _records_columns(catalog_service, schema) + table = await catalog_service.resolve_table(schema, TableKind.RECORDS) await apply_statement_timeout(session, fmt) plan = build_plan( - schema_id=schema_id, + schema_id=table.schema_id, table_kind=TableKind.RECORDS, feature_name=None, filter_expr=None, @@ -64,7 +50,7 @@ async def endpoint( sort=sort, ) rows = query_service.stream_records(plan) - return await build_table_response(rows, fmt, columns, plan) + return await build_table_response(rows, fmt, table.columns, plan) return endpoint @@ -79,10 +65,10 @@ async def endpoint( session: FromDishka[AsyncSession], config: FromDishka[Config], ) -> StreamingResponse: - columns, schema_id = await _records_columns(catalog_service, schema) + table = await catalog_service.resolve_table(schema, TableKind.RECORDS) await apply_statement_timeout(session, fmt) plan = build_plan( - schema_id=schema_id, + schema_id=table.schema_id, table_kind=TableKind.RECORDS, feature_name=None, filter_expr=body.filter, @@ -92,7 +78,7 @@ async def endpoint( sort=body.sort, ) rows = query_service.stream_records(plan) - return await build_table_response(rows, fmt, columns, plan) + return await build_table_response(rows, fmt, table.columns, plan) # slowapi scopes a rate limit by the decorated function's ``module.__name__`` # captured at decoration time. 
These builder closures are all named diff --git a/server/osa/domain/data/model/manifest.py b/server/osa/domain/data/model/manifest.py index 30046a7..749cb61 100644 --- a/server/osa/domain/data/model/manifest.py +++ b/server/osa/domain/data/model/manifest.py @@ -12,6 +12,7 @@ from osa.domain.data.model.query_plan import TableKind from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.model.srn import SchemaId class FieldSpec(BaseModel): @@ -68,3 +69,16 @@ class SchemaManifest(BaseModel): srn: str # full schema SRN fields: list[FieldSpec] table_resources: list[TableResource] + + +class ResolvedTable(BaseModel): + """A table resolved for reading: the owning schema plus its column schema. + + Produced by ``DataCatalogService.resolve_table`` — the single owner of + manifest-structure knowledge (the records resource is named ``records``; + feature resources carry kind ``FEATURE``). Route code consumes this + instead of picking through ``table_resources`` itself. + """ + + schema_id: SchemaId + columns: list[ColumnSpec] diff --git a/server/osa/domain/data/service/data_catalog.py b/server/osa/domain/data/service/data_catalog.py index ba2fb49..2c4abe8 100644 --- a/server/osa/domain/data/service/data_catalog.py +++ b/server/osa/domain/data/service/data_catalog.py @@ -8,11 +8,12 @@ from __future__ import annotations from osa.domain.data.model.catalog import NodeCatalog -from osa.domain.data.model.manifest import SchemaManifest +from osa.domain.data.model.manifest import ResolvedTable, SchemaManifest +from osa.domain.data.model.query_plan import TableKind from osa.domain.data.model.record_summary import RecordSummary from osa.domain.data.port.data_read_store import DataReadStore from osa.domain.shared.error import NotFoundError -from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.ids import HookName, RecordId from osa.domain.shared.model.reserved import RESERVED_NAMES from osa.domain.shared.model.srn import SchemaId from 
osa.domain.shared.service import Service @@ -67,6 +68,34 @@ async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest: ) return manifest + async def resolve_table( + self, + schema: str, + table_kind: TableKind, + feature_name: HookName | None = None, + ) -> ResolvedTable: + """Resolve a URL schema segment + table selector to its column schema. + + Owns the manifest-structure facts the routes must not re-derive: the + records resource is always named ``records``; a feature resource is + matched by name AND kind (a hook cannot shadow the records slot). + Unknown schema or table raises ``NotFoundError`` (404 before bytes). + """ + schema_id = await self.resolve_schema(schema) + manifest = await self.get_schema_manifest(schema_id) + name = "records" if table_kind == TableKind.RECORDS else feature_name + resource = next( + (tr for tr in manifest.table_resources if tr.name == name and tr.kind == table_kind), + None, + ) + if resource is None: + raise NotFoundError( + f"No table '{name}' on schema '{schema}'. " + f"See /api/v1/data/{schema_id.render()} for its table resources.", + code="table_not_found", + ) + return ResolvedTable(schema_id=schema_id, columns=resource.columns) + async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary: record = await self.read_store.get_record_by_id(id, version) if record is None: diff --git a/server/tests/unit/domain/data/test_data_catalog_service.py b/server/tests/unit/domain/data/test_data_catalog_service.py new file mode 100644 index 0000000..39f63e8 --- /dev/null +++ b/server/tests/unit/domain/data/test_data_catalog_service.py @@ -0,0 +1,119 @@ +"""DataCatalogService.resolve_table — one owner for "which columns does this table have". + +The records/feature route files previously each held a private helper that +re-derived manifest structure (records resource is named "records"; a feature +resource has kind FEATURE). That knowledge belongs to the catalog service. 
+""" + +from collections.abc import AsyncIterator, Mapping +from typing import Any + +import pytest + +from osa.domain.data.model.catalog import NodeCatalog +from osa.domain.data.model.manifest import ( + ColumnSpec, + ResolvedTable, + SchemaManifest, + TableResource, +) +from osa.domain.data.model.query_plan import QueryPlan, TableKind +from osa.domain.data.model.record_summary import RecordSummary +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.model.ids import HookName, RecordId +from osa.domain.shared.model.srn import SchemaId + +RECORDS_COLUMNS = [ + ColumnSpec(name="id", type=FieldType.TEXT), + ColumnSpec(name="srn", type=FieldType.TEXT), +] +FEATURE_COLUMNS = [ + ColumnSpec(name="id", type=FieldType.NUMBER), + ColumnSpec(name="mw", type=FieldType.NUMBER), +] + +MANIFEST = SchemaManifest( + id="compound", + version="1.0.0", + srn="urn:osa:localhost:schema:compound@1.0.0", + fields=[], + table_resources=[ + TableResource( + name="records", + kind=TableKind.RECORDS, + columns=RECORDS_COLUMNS, + row_count=2, + formats=["", "csv", "csv.gz"], + ), + TableResource( + name="chem_features", + kind=TableKind.FEATURE, + columns=FEATURE_COLUMNS, + row_count=2, + formats=["", "csv", "csv.gz"], + ), + ], +) + + +class FakeReadStore: + """Minimal DataReadStore fake for catalog reads.""" + + def __init__(self, manifest: SchemaManifest | None = MANIFEST) -> None: + self.manifest = manifest + + def stream_rows(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + raise NotImplementedError + + async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary | None: + return None + + async def get_node_catalog(self) -> NodeCatalog: + raise NotImplementedError + + async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest | None: + return self.manifest + + async def 
get_latest_schema_id(self, schema_short_id: str) -> SchemaId | None: + return SchemaId.parse(f"{schema_short_id}@1.0.0") + + +def _service(store: FakeReadStore | None = None) -> DataCatalogService: + return DataCatalogService(read_store=store or FakeReadStore()) + + +@pytest.mark.asyncio +async def test_resolve_table_records() -> None: + resolved = await _service().resolve_table("compound@1.0.0", TableKind.RECORDS) + assert isinstance(resolved, ResolvedTable) + assert resolved.schema_id == SchemaId.parse("compound@1.0.0") + assert resolved.columns == RECORDS_COLUMNS + + +@pytest.mark.asyncio +async def test_resolve_table_feature() -> None: + resolved = await _service().resolve_table( + "compound@1.0.0", TableKind.FEATURE, feature_name=HookName("chem_features") + ) + assert resolved.columns == FEATURE_COLUMNS + + +@pytest.mark.asyncio +async def test_resolve_table_unknown_feature_404s() -> None: + with pytest.raises(NotFoundError) as exc: + await _service().resolve_table( + "compound@1.0.0", TableKind.FEATURE, feature_name=HookName("nope") + ) + assert exc.value.code == "table_not_found" + + +@pytest.mark.asyncio +async def test_resolve_table_kind_must_match() -> None: + # A resource named "records" exists, but with kind RECORDS — asking for a + # FEATURE of that name must not match it. + with pytest.raises(NotFoundError): + await _service().resolve_table( + "compound@1.0.0", TableKind.FEATURE, feature_name=HookName("records") + ) From 522323d24b6a0df997df8a6dae5736d6090aeb2a Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 13:51:58 +0100 Subject: [PATCH 17/23] refactor: move serializers/format registry to the route layer; adapter owns statement timeout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DataResponseFormat carried an HTTP media type and a Postgres timeout literal inside domain/data/model — presentation and infrastructure concerns. 
The registry and serializers now live with the routes that consume them; the media type is derived from the serializer (single owner); the timeout is a timedelta execution budget threaded through the DataReadStore port and applied by the Postgres adapter via SET LOCAL (integer ms). Routes no longer inject AsyncSession; _runtime.py is gone. --- .../api/v1/routes/data/_runtime.py | 32 ---------- .../api/v1/routes/data/_streaming.py | 2 +- .../api/v1/routes/data/features_table.py | 12 +--- .../application/api/v1/routes/data/formats.py | 60 +++++++++++++++++++ .../api/v1/routes/data/records_table.py | 12 +--- .../v1/routes/data/serializers}/__init__.py | 0 .../api/v1/routes/data/serializers}/csv.py | 0 .../v1/routes/data/serializers}/csv_gzip.py | 2 +- .../api/v1/routes/data/serializers}/json.py | 0 .../v1/routes/data/serializers}/protocol.py | 0 .../application/api/v1/routes/data/tables.py | 2 +- server/osa/domain/data/model/format.py | 54 ----------------- .../osa/domain/data/port/data_read_store.py | 11 +++- server/osa/domain/data/service/data_query.py | 13 ++-- .../data/postgres_data_read_store.py | 22 ++++++- .../v1/routes/data/serializers}/__init__.py | 0 .../v1/routes/data/serializers}/test_csv.py | 2 +- .../routes/data/serializers}/test_csv_gzip.py | 2 +- .../v1/routes/data/serializers}/test_json.py | 2 +- .../api/v1/routes/data/test_streaming.py | 2 +- .../api/v1/routes/data/test_tables_factory.py | 2 +- .../domain/data/test_data_query_service.py | 26 +++++++- .../data/test_statement_timeout.py | 20 +++++++ 23 files changed, 156 insertions(+), 122 deletions(-) delete mode 100644 server/osa/application/api/v1/routes/data/_runtime.py create mode 100644 server/osa/application/api/v1/routes/data/formats.py rename server/osa/{domain/data/serializer => application/api/v1/routes/data/serializers}/__init__.py (100%) rename server/osa/{domain/data/serializer => application/api/v1/routes/data/serializers}/csv.py (100%) rename server/osa/{domain/data/serializer => 
application/api/v1/routes/data/serializers}/csv_gzip.py (94%) rename server/osa/{domain/data/serializer => application/api/v1/routes/data/serializers}/json.py (100%) rename server/osa/{domain/data/serializer => application/api/v1/routes/data/serializers}/protocol.py (100%) delete mode 100644 server/osa/domain/data/model/format.py rename server/tests/unit/{domain/data/serializer => application/api/v1/routes/data/serializers}/__init__.py (100%) rename server/tests/unit/{domain/data/serializer => application/api/v1/routes/data/serializers}/test_csv.py (96%) rename server/tests/unit/{domain/data/serializer => application/api/v1/routes/data/serializers}/test_csv_gzip.py (95%) rename server/tests/unit/{domain/data/serializer => application/api/v1/routes/data/serializers}/test_json.py (95%) create mode 100644 server/tests/unit/infrastructure/data/test_statement_timeout.py diff --git a/server/osa/application/api/v1/routes/data/_runtime.py b/server/osa/application/api/v1/routes/data/_runtime.py deleted file mode 100644 index 9149f96..0000000 --- a/server/osa/application/api/v1/routes/data/_runtime.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Per-route Postgres statement-timeout dependency (research §7). - -Each :class:`DataResponseFormat` carries its own ``statement_timeout`` (30s for -paginated JSON, 30min for streaming dumps). ``apply_statement_timeout`` issues -``SET LOCAL statement_timeout = '...'`` against the active session at the start -of the request; ``SET LOCAL`` reverts on commit/rollback so the value never -leaks to a later request reusing the pooled connection. - -``idle_in_transaction_session_timeout`` is set globally in Postgres config (see -quickstart), not per-route — it's the same everywhere and protects the pool -from any route's misbehaviour. 
-""" - -from __future__ import annotations - -import re - -from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncSession - -from osa.domain.data.model.format import DataResponseFormat - -# Timeout literals are operator-controlled (from FORMATS), never user input, but -# validate anyway so a malformed registry value can't reach raw SQL. -_TIMEOUT_RE = re.compile(r"^\d+(ms|s|min)?$") - - -async def apply_statement_timeout(session: AsyncSession, fmt: DataResponseFormat) -> None: - timeout = fmt.statement_timeout - if not _TIMEOUT_RE.match(timeout): - raise ValueError(f"Invalid statement_timeout literal: {timeout!r}") - await session.execute(text(f"SET LOCAL statement_timeout = '{timeout}'")) diff --git a/server/osa/application/api/v1/routes/data/_streaming.py b/server/osa/application/api/v1/routes/data/_streaming.py index c155053..142df6e 100644 --- a/server/osa/application/api/v1/routes/data/_streaming.py +++ b/server/osa/application/api/v1/routes/data/_streaming.py @@ -20,7 +20,7 @@ from fastapi.responses import StreamingResponse -from osa.domain.data.model.format import DataResponseFormat +from osa.application.api.v1.routes.data.formats import DataResponseFormat from osa.domain.data.model.manifest import ColumnSpec from osa.domain.data.model.query_plan import QueryPlan, TableKind, encode_cursor diff --git a/server/osa/application/api/v1/routes/data/features_table.py b/server/osa/application/api/v1/routes/data/features_table.py index 310d081..ef8a124 100644 --- a/server/osa/application/api/v1/routes/data/features_table.py +++ b/server/osa/application/api/v1/routes/data/features_table.py @@ -11,15 +11,13 @@ from dishka.integrations.fastapi import FromDishka from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse -from sqlalchemy.ext.asyncio import AsyncSession from osa.application.api.v1.routes.data._limiter import POST_RATE_LIMIT, limiter from osa.application.api.v1.routes.data._params import FilterRequestBody, 
build_plan -from osa.application.api.v1.routes.data._runtime import apply_statement_timeout from osa.application.api.v1.routes.data._streaming import build_table_response from osa.application.api.v1.routes.data.tables import format_key, register_table_routes from osa.config import Config -from osa.domain.data.model.format import DataResponseFormat +from osa.application.api.v1.routes.data.formats import DataResponseFormat from osa.domain.data.model.query_plan import TableKind from osa.domain.data.service.data_catalog import DataCatalogService from osa.domain.data.service.data_query import DataQueryService @@ -31,14 +29,12 @@ async def endpoint( feature: str, query_service: FromDishka[DataQueryService], catalog_service: FromDishka[DataCatalogService], - session: FromDishka[AsyncSession], config: FromDishka[Config], cursor: str | None = None, limit: int = 50, sort: str | None = None, ) -> StreamingResponse: table = await catalog_service.resolve_table(schema, TableKind.FEATURE, feature_name=feature) - await apply_statement_timeout(session, fmt) plan = build_plan( schema_id=table.schema_id, table_kind=TableKind.FEATURE, @@ -49,7 +45,7 @@ async def endpoint( max_limit=config.data.max_page_limit, sort=sort, ) - rows = query_service.stream_features(plan) + rows = query_service.stream_features(plan, timeout=fmt.timeout) return await build_table_response(rows, fmt, table.columns, plan) return endpoint @@ -63,11 +59,9 @@ async def endpoint( body: FilterRequestBody, query_service: FromDishka[DataQueryService], catalog_service: FromDishka[DataCatalogService], - session: FromDishka[AsyncSession], config: FromDishka[Config], ) -> StreamingResponse: table = await catalog_service.resolve_table(schema, TableKind.FEATURE, feature_name=feature) - await apply_statement_timeout(session, fmt) plan = build_plan( schema_id=table.schema_id, table_kind=TableKind.FEATURE, @@ -78,7 +72,7 @@ async def endpoint( max_limit=config.data.max_page_limit, sort=body.sort, ) - rows = 
query_service.stream_features(plan) + rows = query_service.stream_features(plan, timeout=fmt.timeout) return await build_table_response(rows, fmt, table.columns, plan) # Unique name before the limiter — see records_table._make_post_endpoint. diff --git a/server/osa/application/api/v1/routes/data/formats.py b/server/osa/application/api/v1/routes/data/formats.py new file mode 100644 index 0000000..311d6ba --- /dev/null +++ b/server/osa/application/api/v1/routes/data/formats.py @@ -0,0 +1,60 @@ +"""Response-format registry for the ``/data/`` surface. + +Each :class:`DataResponseFormat` ties a URL suffix to its serializer, +pagination semantics, and per-route statement-timeout budget (research §7). +Adding a format (e.g. ``ndjson``, ``parquet``) is a one-line append plus one +serializer class — the route factory picks it up automatically. + +This is presentation configuration, so it lives with the routes: media types +and Postgres execution budgets are HTTP/adapter concerns, not domain model. +The domain's :class:`~osa.domain.data.model.query_plan.QueryPlan` stays a pure +query description. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta + +from osa.application.api.v1.routes.data.serializers.csv import CsvSerializer +from osa.application.api.v1.routes.data.serializers.csv_gzip import CsvGzipSerializer +from osa.application.api.v1.routes.data.serializers.json import JsonSerializer +from osa.application.api.v1.routes.data.serializers.protocol import Serializer + + +@dataclass(frozen=True) +class DataResponseFormat: + serializer_cls: type[Serializer] + paginated: bool # True → JSON envelope + cursor; False → unbounded stream + suffix: str # "" (json), "csv", "csv.gz" + timeout: timedelta # statement budget: short for paginated, long for dumps + + @property + def media_type(self) -> str: + # The serializer is the single owner of its media type. 
+ return self.serializer_cls.media_type + + def make_serializer(self) -> Serializer: + return self.serializer_cls() + + +FORMATS: tuple[DataResponseFormat, ...] = ( + DataResponseFormat( + serializer_cls=JsonSerializer, + paginated=True, + suffix="", + timeout=timedelta(seconds=30), + ), + DataResponseFormat( + serializer_cls=CsvSerializer, + paginated=False, + suffix="csv", + timeout=timedelta(minutes=30), + ), + DataResponseFormat( + serializer_cls=CsvGzipSerializer, + paginated=False, + suffix="csv.gz", + timeout=timedelta(minutes=30), + ), +) diff --git a/server/osa/application/api/v1/routes/data/records_table.py b/server/osa/application/api/v1/routes/data/records_table.py index ee498ae..6420fd3 100644 --- a/server/osa/application/api/v1/routes/data/records_table.py +++ b/server/osa/application/api/v1/routes/data/records_table.py @@ -12,15 +12,13 @@ from dishka.integrations.fastapi import FromDishka from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse -from sqlalchemy.ext.asyncio import AsyncSession from osa.application.api.v1.routes.data._limiter import POST_RATE_LIMIT, limiter from osa.application.api.v1.routes.data._params import FilterRequestBody, build_plan -from osa.application.api.v1.routes.data._runtime import apply_statement_timeout from osa.application.api.v1.routes.data._streaming import build_table_response from osa.application.api.v1.routes.data.tables import format_key, register_table_routes from osa.config import Config -from osa.domain.data.model.format import DataResponseFormat +from osa.application.api.v1.routes.data.formats import DataResponseFormat from osa.domain.data.model.query_plan import TableKind from osa.domain.data.service.data_catalog import DataCatalogService from osa.domain.data.service.data_query import DataQueryService @@ -31,14 +29,12 @@ async def endpoint( schema: str, query_service: FromDishka[DataQueryService], catalog_service: FromDishka[DataCatalogService], - session: 
FromDishka[AsyncSession], config: FromDishka[Config], cursor: str | None = None, limit: int = 50, sort: str | None = None, ) -> StreamingResponse: table = await catalog_service.resolve_table(schema, TableKind.RECORDS) - await apply_statement_timeout(session, fmt) plan = build_plan( schema_id=table.schema_id, table_kind=TableKind.RECORDS, @@ -49,7 +45,7 @@ async def endpoint( max_limit=config.data.max_page_limit, sort=sort, ) - rows = query_service.stream_records(plan) + rows = query_service.stream_records(plan, timeout=fmt.timeout) return await build_table_response(rows, fmt, table.columns, plan) return endpoint @@ -62,11 +58,9 @@ async def endpoint( body: FilterRequestBody, query_service: FromDishka[DataQueryService], catalog_service: FromDishka[DataCatalogService], - session: FromDishka[AsyncSession], config: FromDishka[Config], ) -> StreamingResponse: table = await catalog_service.resolve_table(schema, TableKind.RECORDS) - await apply_statement_timeout(session, fmt) plan = build_plan( schema_id=table.schema_id, table_kind=TableKind.RECORDS, @@ -77,7 +71,7 @@ async def endpoint( max_limit=config.data.max_page_limit, sort=body.sort, ) - rows = query_service.stream_records(plan) + rows = query_service.stream_records(plan, timeout=fmt.timeout) return await build_table_response(rows, fmt, table.columns, plan) # slowapi scopes a rate limit by the decorated function's ``module.__name__`` diff --git a/server/osa/domain/data/serializer/__init__.py b/server/osa/application/api/v1/routes/data/serializers/__init__.py similarity index 100% rename from server/osa/domain/data/serializer/__init__.py rename to server/osa/application/api/v1/routes/data/serializers/__init__.py diff --git a/server/osa/domain/data/serializer/csv.py b/server/osa/application/api/v1/routes/data/serializers/csv.py similarity index 100% rename from server/osa/domain/data/serializer/csv.py rename to server/osa/application/api/v1/routes/data/serializers/csv.py diff --git 
a/server/osa/domain/data/serializer/csv_gzip.py b/server/osa/application/api/v1/routes/data/serializers/csv_gzip.py similarity index 94% rename from server/osa/domain/data/serializer/csv_gzip.py rename to server/osa/application/api/v1/routes/data/serializers/csv_gzip.py index f6534d7..4f159fe 100644 --- a/server/osa/domain/data/serializer/csv_gzip.py +++ b/server/osa/application/api/v1/routes/data/serializers/csv_gzip.py @@ -15,7 +15,7 @@ from typing import Any, ClassVar from osa.domain.data.model.manifest import ColumnSpec -from osa.domain.data.serializer.csv import CsvSerializer +from osa.application.api.v1.routes.data.serializers.csv import CsvSerializer class CsvGzipSerializer: diff --git a/server/osa/domain/data/serializer/json.py b/server/osa/application/api/v1/routes/data/serializers/json.py similarity index 100% rename from server/osa/domain/data/serializer/json.py rename to server/osa/application/api/v1/routes/data/serializers/json.py diff --git a/server/osa/domain/data/serializer/protocol.py b/server/osa/application/api/v1/routes/data/serializers/protocol.py similarity index 100% rename from server/osa/domain/data/serializer/protocol.py rename to server/osa/application/api/v1/routes/data/serializers/protocol.py diff --git a/server/osa/application/api/v1/routes/data/tables.py b/server/osa/application/api/v1/routes/data/tables.py index 9fcba32..42c41a5 100644 --- a/server/osa/application/api/v1/routes/data/tables.py +++ b/server/osa/application/api/v1/routes/data/tables.py @@ -20,7 +20,7 @@ from fastapi import APIRouter -from osa.domain.data.model.format import FORMATS, DataResponseFormat +from osa.application.api.v1.routes.data.formats import FORMATS, DataResponseFormat EndpointBuilder = Callable[[DataResponseFormat], Callable[..., Any]] diff --git a/server/osa/domain/data/model/format.py b/server/osa/domain/data/model/format.py deleted file mode 100644 index 276d74c..0000000 --- a/server/osa/domain/data/model/format.py +++ /dev/null @@ -1,54 +0,0 @@ 
-"""Response-format registry for the ``/data/`` surface. - -Each :class:`DataResponseFormat` ties a URL suffix to its serializer, media -type, pagination semantics, and per-route statement timeout (research §7). -Adding a format (e.g. ``ndjson``, ``parquet``) is a one-line append plus one -serializer class — the route factory and timeout dependency pick it up -automatically. -""" - -from __future__ import annotations - -from dataclasses import dataclass - -from osa.domain.data.serializer.csv import CsvSerializer -from osa.domain.data.serializer.csv_gzip import CsvGzipSerializer -from osa.domain.data.serializer.json import JsonSerializer -from osa.domain.data.serializer.protocol import Serializer - - -@dataclass(frozen=True) -class DataResponseFormat: - serializer_cls: type[Serializer] - paginated: bool # True → JSON envelope + cursor; False → unbounded stream - suffix: str # "" (json), "csv", "csv.gz" - media_type: str - statement_timeout: str # "30s" paginated, "30min" streaming - - def make_serializer(self) -> Serializer: - return self.serializer_cls() - - -FORMATS: tuple[DataResponseFormat, ...] 
= ( - DataResponseFormat( - serializer_cls=JsonSerializer, - paginated=True, - suffix="", - media_type="application/json", - statement_timeout="30s", - ), - DataResponseFormat( - serializer_cls=CsvSerializer, - paginated=False, - suffix="csv", - media_type="text/csv", - statement_timeout="30min", - ), - DataResponseFormat( - serializer_cls=CsvGzipSerializer, - paginated=False, - suffix="csv.gz", - media_type="application/gzip", - statement_timeout="30min", - ), -) diff --git a/server/osa/domain/data/port/data_read_store.py b/server/osa/domain/data/port/data_read_store.py index 60a9c2d..d82ef3a 100644 --- a/server/osa/domain/data/port/data_read_store.py +++ b/server/osa/domain/data/port/data_read_store.py @@ -12,6 +12,7 @@ from __future__ import annotations from collections.abc import AsyncIterator, Mapping +from datetime import timedelta from typing import TYPE_CHECKING, Any, Protocol if TYPE_CHECKING: @@ -24,8 +25,14 @@ class DataReadStore(Protocol): - def stream_rows(self, plan: "QueryPlan") -> AsyncIterator[Mapping[str, Any]]: - """Stream projected rows for the plan via a server-side cursor.""" + def stream_rows( + self, plan: "QueryPlan", timeout: timedelta | None = None + ) -> AsyncIterator[Mapping[str, Any]]: + """Stream projected rows for the plan via a server-side cursor. + + ``timeout`` is the caller's execution budget for the read; the adapter + enforces it on its own connection (the routes hold no SQL session). + """ ... 
async def get_record_by_id(self, id: "RecordId", version: int | None) -> "RecordSummary | None": diff --git a/server/osa/domain/data/service/data_query.py b/server/osa/domain/data/service/data_query.py index e5c346d..9437c67 100644 --- a/server/osa/domain/data/service/data_query.py +++ b/server/osa/domain/data/service/data_query.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import AsyncIterator, Iterator, Mapping +from datetime import timedelta from typing import Any from osa.config import Config @@ -29,18 +30,22 @@ class DataQueryService(Service): read_store: DataReadStore config: Config - async def stream_records(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + async def stream_records( + self, plan: QueryPlan, timeout: timedelta | None = None + ) -> AsyncIterator[Mapping[str, Any]]: if plan.table_kind != TableKind.RECORDS: raise ValidationError("stream_records requires a RECORDS plan", field="table_kind") self._validate_filter_bounds(plan.filter) - async for row in self.read_store.stream_rows(plan): + async for row in self.read_store.stream_rows(plan, timeout): yield row - async def stream_features(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + async def stream_features( + self, plan: QueryPlan, timeout: timedelta | None = None + ) -> AsyncIterator[Mapping[str, Any]]: if plan.table_kind != TableKind.FEATURE: raise ValidationError("stream_features requires a FEATURE plan", field="table_kind") self._validate_filter_bounds(plan.filter) - async for row in self.read_store.stream_rows(plan): + async for row in self.read_store.stream_rows(plan, timeout): yield row # ------------------------------------------------------------------ # diff --git a/server/osa/infrastructure/data/postgres_data_read_store.py b/server/osa/infrastructure/data/postgres_data_read_store.py index cb1258f..237b95c 100644 --- a/server/osa/infrastructure/data/postgres_data_read_store.py +++ 
b/server/osa/infrastructure/data/postgres_data_read_store.py @@ -16,7 +16,7 @@ import logging from collections.abc import AsyncIterator, Mapping -from datetime import date, datetime +from datetime import date, datetime, timedelta from typing import Any import sqlalchemy as sa @@ -97,7 +97,7 @@ "object": FieldType.TEXT, } -# All URL-exposed format suffixes (mirrors model.format.FORMATS suffixes). +# All URL-exposed format suffixes (mirrors the route-layer FORMATS registry). _ALL_FORMATS = ["", "csv", "csv.gz"] @@ -105,6 +105,16 @@ def _escape_like(value: str) -> str: return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") +def statement_timeout_sql(timeout: timedelta) -> str: + """Render the caller's execution budget as a ``SET LOCAL`` statement. + + Integer milliseconds, so no operator- or caller-supplied string ever + reaches raw SQL. ``SET LOCAL`` reverts on commit/rollback, so the value + never leaks to a later request reusing the pooled connection. + """ + return f"SET LOCAL statement_timeout = '{int(timeout.total_seconds() * 1000)}ms'" + + def _invalid_cursor(exc: Exception) -> ValidationError: """Map a cursor decode/coerce ``ValueError`` to a 400, not a 500. @@ -129,7 +139,13 @@ def __init__(self, session: AsyncSession, node_domain: Domain) -> None: # Streaming primitive # ------------------------------------------------------------------ # - async def stream_rows(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + async def stream_rows( + self, plan: QueryPlan, timeout: timedelta | None = None + ) -> AsyncIterator[Mapping[str, Any]]: + # The generator body runs on the route's pre-flight ``__anext__`` pull, + # so the timeout is in place before the SELECT opens its cursor. 
+ if timeout is not None: + await self.session.execute(sa.text(statement_timeout_sql(timeout))) if plan.table_kind == TableKind.RECORDS: async for row in self._stream_records(plan): yield row diff --git a/server/tests/unit/domain/data/serializer/__init__.py b/server/tests/unit/application/api/v1/routes/data/serializers/__init__.py similarity index 100% rename from server/tests/unit/domain/data/serializer/__init__.py rename to server/tests/unit/application/api/v1/routes/data/serializers/__init__.py diff --git a/server/tests/unit/domain/data/serializer/test_csv.py b/server/tests/unit/application/api/v1/routes/data/serializers/test_csv.py similarity index 96% rename from server/tests/unit/domain/data/serializer/test_csv.py rename to server/tests/unit/application/api/v1/routes/data/serializers/test_csv.py index b3fdafb..f0bdb9d 100644 --- a/server/tests/unit/domain/data/serializer/test_csv.py +++ b/server/tests/unit/application/api/v1/routes/data/serializers/test_csv.py @@ -8,7 +8,7 @@ import pytest from osa.domain.data.model.manifest import ColumnSpec -from osa.domain.data.serializer.csv import CsvSerializer +from osa.application.api.v1.routes.data.serializers.csv import CsvSerializer from osa.domain.semantics.model.value import FieldType COLUMNS = [ diff --git a/server/tests/unit/domain/data/serializer/test_csv_gzip.py b/server/tests/unit/application/api/v1/routes/data/serializers/test_csv_gzip.py similarity index 95% rename from server/tests/unit/domain/data/serializer/test_csv_gzip.py rename to server/tests/unit/application/api/v1/routes/data/serializers/test_csv_gzip.py index 77f983f..1d81eee 100644 --- a/server/tests/unit/domain/data/serializer/test_csv_gzip.py +++ b/server/tests/unit/application/api/v1/routes/data/serializers/test_csv_gzip.py @@ -9,7 +9,7 @@ import pytest from osa.domain.data.model.manifest import ColumnSpec -from osa.domain.data.serializer.csv_gzip import CsvGzipSerializer +from osa.application.api.v1.routes.data.serializers.csv_gzip import 
CsvGzipSerializer from osa.domain.semantics.model.value import FieldType COLUMNS = [ diff --git a/server/tests/unit/domain/data/serializer/test_json.py b/server/tests/unit/application/api/v1/routes/data/serializers/test_json.py similarity index 95% rename from server/tests/unit/domain/data/serializer/test_json.py rename to server/tests/unit/application/api/v1/routes/data/serializers/test_json.py index 33eed01..084122e 100644 --- a/server/tests/unit/domain/data/serializer/test_json.py +++ b/server/tests/unit/application/api/v1/routes/data/serializers/test_json.py @@ -7,7 +7,7 @@ import pytest from osa.domain.data.model.manifest import ColumnSpec -from osa.domain.data.serializer.json import JsonSerializer +from osa.application.api.v1.routes.data.serializers.json import JsonSerializer from osa.domain.semantics.model.value import FieldType COLUMNS = [ diff --git a/server/tests/unit/application/api/v1/routes/data/test_streaming.py b/server/tests/unit/application/api/v1/routes/data/test_streaming.py index 21a6a0f..5d5c255 100644 --- a/server/tests/unit/application/api/v1/routes/data/test_streaming.py +++ b/server/tests/unit/application/api/v1/routes/data/test_streaming.py @@ -11,7 +11,7 @@ import pytest from osa.application.api.v1.routes.data._streaming import build_table_response -from osa.domain.data.model.format import FORMATS +from osa.application.api.v1.routes.data.formats import FORMATS from osa.domain.data.model.manifest import ColumnSpec from osa.domain.data.model.query_plan import ( QueryPlan, diff --git a/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py b/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py index 2d2e2f5..b7ef790 100644 --- a/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py +++ b/server/tests/unit/application/api/v1/routes/data/test_tables_factory.py @@ -8,7 +8,7 @@ path_for, register_table_routes, ) -from osa.domain.data.model.format import FORMATS +from 
osa.application.api.v1.routes.data.formats import FORMATS def _noop_builder(_fmt): diff --git a/server/tests/unit/domain/data/test_data_query_service.py b/server/tests/unit/domain/data/test_data_query_service.py index f0b0cf4..5966285 100644 --- a/server/tests/unit/domain/data/test_data_query_service.py +++ b/server/tests/unit/domain/data/test_data_query_service.py @@ -5,6 +5,7 @@ from collections.abc import AsyncIterator, Mapping from dataclasses import dataclass, field +from datetime import timedelta from typing import Any import pytest @@ -38,9 +39,13 @@ class FakeReadStore: def __init__(self, rows: list[Mapping[str, Any]]) -> None: self._rows = rows self.received_plan: QueryPlan | None = None + self.received_timeout: timedelta | None = None - async def stream_rows(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, Any]]: + async def stream_rows( + self, plan: QueryPlan, timeout: timedelta | None = None + ) -> AsyncIterator[Mapping[str, Any]]: self.received_plan = plan + self.received_timeout = timeout for row in self._rows: yield row @@ -117,3 +122,22 @@ async def test_no_filter_streams_cleanly() -> None: service, _ = _service([{"id": "x"}]) plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) assert await _drain(service.stream_records(plan)) == [{"id": "x"}] + + +@pytest.mark.asyncio +async def test_stream_records_passes_timeout_to_store() -> None: + # The statement timeout is an execution budget owned by the caller (the + # route's format registry); the service forwards it to the port so the + # adapter — the only layer holding a SQL session — can apply it. 
+ service, store = _service([{"id": "a"}]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + await _drain(service.stream_records(plan, timeout=timedelta(seconds=30))) + assert store.received_timeout == timedelta(seconds=30) + + +@pytest.mark.asyncio +async def test_stream_features_passes_timeout_to_store() -> None: + service, store = _service([{"id": 1}]) + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="f") + await _drain(service.stream_features(plan, timeout=timedelta(minutes=30))) + assert store.received_timeout == timedelta(minutes=30) diff --git a/server/tests/unit/infrastructure/data/test_statement_timeout.py b/server/tests/unit/infrastructure/data/test_statement_timeout.py new file mode 100644 index 0000000..961fba6 --- /dev/null +++ b/server/tests/unit/infrastructure/data/test_statement_timeout.py @@ -0,0 +1,20 @@ +"""The adapter owns SET LOCAL statement_timeout (routes hold no SQL session).""" + +from datetime import timedelta + +from osa.infrastructure.data.postgres_data_read_store import statement_timeout_sql + + +def test_statement_timeout_sql_renders_integer_milliseconds() -> None: + assert statement_timeout_sql(timedelta(seconds=30)) == "SET LOCAL statement_timeout = '30000ms'" + assert ( + statement_timeout_sql(timedelta(minutes=30)) == "SET LOCAL statement_timeout = '1800000ms'" + ) + + +def test_statement_timeout_sql_truncates_subsecond() -> None: + # int milliseconds — no operator-supplied string ever reaches raw SQL. 
+ assert ( + statement_timeout_sql(timedelta(milliseconds=1500.7)) + == "SET LOCAL statement_timeout = '1500ms'" + ) From b70c4ae439b2cd6559c973a2a1702c45de4c9c7a Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 13:54:38 +0100 Subject: [PATCH 18/23] refactor: route /stats through a GetStats query handler instead of injecting RecordRepository --- server/osa/application/api/v1/routes/stats.py | 8 +++---- server/osa/domain/record/query/get_stats.py | 21 ++++++++++++++++ server/osa/domain/record/service/record.py | 4 ++++ server/osa/infrastructure/persistence/di.py | 2 ++ .../domain/record/test_get_stats_handler.py | 24 +++++++++++++++++++ 5 files changed, 55 insertions(+), 4 deletions(-) create mode 100644 server/osa/domain/record/query/get_stats.py create mode 100644 server/tests/unit/domain/record/test_get_stats_handler.py diff --git a/server/osa/application/api/v1/routes/stats.py b/server/osa/application/api/v1/routes/stats.py index 64185ad..64ee3ec 100644 --- a/server/osa/application/api/v1/routes/stats.py +++ b/server/osa/application/api/v1/routes/stats.py @@ -4,7 +4,7 @@ from fastapi import APIRouter from pydantic import BaseModel -from osa.domain.record.port.repository import RecordRepository +from osa.domain.record.query.get_stats import GetStats, GetStatsHandler router = APIRouter( prefix="/stats", @@ -28,8 +28,8 @@ class StatsResponse(BaseModel): @router.get("") async def get_stats( - record_repo: FromDishka[RecordRepository], + handler: FromDishka[GetStatsHandler], ) -> StatsResponse: """Get system statistics.""" - record_count = await record_repo.count() - return StatsResponse(records=record_count) + result = await handler.run(GetStats()) + return StatsResponse(records=result.records) diff --git a/server/osa/domain/record/query/get_stats.py b/server/osa/domain/record/query/get_stats.py new file mode 100644 index 0000000..92b04f9 --- /dev/null +++ b/server/osa/domain/record/query/get_stats.py @@ -0,0 +1,21 @@ +"""GetStats query handler — 
public node statistics (record count).""" + +from osa.domain.record.service.record import RecordService +from osa.domain.shared.authorization.gate import public +from osa.domain.shared.query import Query, QueryHandler, Result + + +class GetStats(Query): + pass + + +class StatsResult(Result): + records: int + + +class GetStatsHandler(QueryHandler[GetStats, StatsResult]): + __auth__ = public() + record_service: RecordService + + async def run(self, cmd: GetStats) -> StatsResult: + return StatsResult(records=await self.record_service.count()) diff --git a/server/osa/domain/record/service/record.py b/server/osa/domain/record/service/record.py index e84512b..9c7d1f0 100644 --- a/server/osa/domain/record/service/record.py +++ b/server/osa/domain/record/service/record.py @@ -55,6 +55,10 @@ async def get(self, srn: RecordSRN) -> Record: raise NotFoundError(f"Record not found: {srn}") return record + async def count(self) -> int: + """Total published records on this node.""" + return await self.record_repo.count() + async def _resolve_schema_id(self, convention_srn: ConventionSRN) -> SchemaId: """Resolve a convention to its schema id at publication time.""" convention = await self.convention_repo.get(convention_srn) diff --git a/server/osa/infrastructure/persistence/di.py b/server/osa/infrastructure/persistence/di.py index 9f5316b..60fdc19 100644 --- a/server/osa/infrastructure/persistence/di.py +++ b/server/osa/infrastructure/persistence/di.py @@ -16,6 +16,7 @@ from osa.domain.record.port.feature_reader import FeatureReader from osa.domain.record.port.repository import RecordRepository from osa.domain.record.query.get_record import GetRecordHandler +from osa.domain.record.query.get_stats import GetStatsHandler from osa.domain.record.service import RecordService from osa.infrastructure.persistence.adapter.feature_reader import PostgresFeatureReader from osa.domain.feature.port.storage import FeatureStoragePort @@ -176,3 +177,4 @@ def get_data_read_store(self, session: 
AsyncSession, config: Config) -> Postgres # Record query handlers get_record_handler = provide(GetRecordHandler, scope=Scope.UOW) + get_stats_handler = provide(GetStatsHandler, scope=Scope.UOW) diff --git a/server/tests/unit/domain/record/test_get_stats_handler.py b/server/tests/unit/domain/record/test_get_stats_handler.py new file mode 100644 index 0000000..985c83d --- /dev/null +++ b/server/tests/unit/domain/record/test_get_stats_handler.py @@ -0,0 +1,24 @@ +"""GetStats query handler — the stats route must not touch RecordRepository. + +Layering regression guard: Router → QueryHandler → Service → Repository. The +/stats route previously injected RecordRepository directly. +""" + +from unittest.mock import AsyncMock + +import pytest + +from osa.domain.record.query.get_stats import GetStats, GetStatsHandler + + +class TestGetStatsHandler: + @pytest.mark.asyncio + async def test_returns_record_count_from_service(self): + service = AsyncMock() + service.count.return_value = 42 + + handler = GetStatsHandler(record_service=service) + result = await handler.run(GetStats()) + + assert result.records == 42 + service.count.assert_awaited_once() From d29e8e8a78fbb545ef894126dfab560077e0033a Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 13:58:19 +0100 Subject: [PATCH 19/23] refactor: Keyset on QueryPlan owns the pagination contract for both encode and decode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The tiebreak column choice and the sort=id aliasing rule were re-derived independently by _next_cursor (route) and the store's sort builders — the split that produced three cursor bugs. plan.keyset is now the single source of truth: the store maps its column names to SQL expressions, the response assembly calls keyset.cursor_from_row. 
--- .../api/v1/routes/data/_streaming.py | 21 ++----- server/osa/domain/data/model/query_plan.py | 47 +++++++++++++++ .../data/postgres_data_read_store.py | 47 ++++++++------- .../tests/unit/domain/data/test_query_plan.py | 57 +++++++++++++++++++ 4 files changed, 136 insertions(+), 36 deletions(-) diff --git a/server/osa/application/api/v1/routes/data/_streaming.py b/server/osa/application/api/v1/routes/data/_streaming.py index 142df6e..09bfca7 100644 --- a/server/osa/application/api/v1/routes/data/_streaming.py +++ b/server/osa/application/api/v1/routes/data/_streaming.py @@ -22,7 +22,7 @@ from osa.application.api.v1.routes.data.formats import DataResponseFormat from osa.domain.data.model.manifest import ColumnSpec -from osa.domain.data.model.query_plan import QueryPlan, TableKind, encode_cursor +from osa.domain.data.model.query_plan import QueryPlan async def build_table_response( @@ -91,18 +91,7 @@ async def page_iter() -> AsyncIterator[Mapping[str, Any]]: def _next_cursor(page: list[Mapping[str, Any]], plan: QueryPlan) -> str | None: - if not page: - return None - last = page[-1] - # The keyset tiebreaker is the records srn / feature id (see - # PostgresDataReadStore). ``sort=id`` aliases to that same column in the - # store, so the cursor's sort value must be the tiebreaker value too — - # encoding the bare record id would compare it against the srn column, - # match every row, and never advance. Select the key by table kind, not - # by key presence: a hook may declare a data column named "srn", which - # must not hijack the feature tiebreaker. 
- tiebreak_key = "srn" if plan.table_kind == TableKind.RECORDS else "id" - tiebreak = last.get(tiebreak_key) - sort_column = plan.sort[0].column - sort_value = tiebreak if sort_column == "id" else last.get(sort_column) - return encode_cursor(sort_value, tiebreak) + # Tiebreak selection and sort=id aliasing live on plan.keyset — the same + # object the store builds its ORDER BY / after-condition from, so the + # encode side cannot drift from the decode side. + return plan.keyset.cursor_from_row(page[-1]) if page else None diff --git a/server/osa/domain/data/model/query_plan.py b/server/osa/domain/data/model/query_plan.py index f2eb846..72e8295 100644 --- a/server/osa/domain/data/model/query_plan.py +++ b/server/osa/domain/data/model/query_plan.py @@ -15,6 +15,7 @@ import base64 import json +from collections.abc import Mapping from enum import StrEnum from typing import Any @@ -73,6 +74,37 @@ def clamped( return cls(cursor=cursor, limit=max(1, min(limit, max_limit))) +class Keyset(BaseModel): + """The keyset-pagination contract for a plan — the single source of truth + for which column is the effective primary sort and which breaks ties. + + Consumed by BOTH sides of the pagination round-trip: the Postgres adapter + (ORDER BY + the cursor's ``after`` condition) and the response assembly + (encoding ``next_cursor`` from the last page row). Keeping the two sides + on one object is what makes encode/decode asymmetries unrepresentable. + """ + + sort_column: str # effective primary sort, post-aliasing (``id`` → tiebreak) + tiebreak_column: str # records: ``srn`` (the PK); features: ``id`` (BIGSERIAL) + + def cursor_from_row(self, row: Mapping[str, Any]) -> str: + """Encode the opaque ``next_cursor`` from the last row of a page.""" + tiebreak = row.get(self.tiebreak_column) + sort_value = ( + tiebreak if self.sort_column == self.tiebreak_column else row.get(self.sort_column) + ) + return encode_cursor(sort_value, tiebreak) + + +# Keyset tiebreak column per table kind. 
Records tiebreak on ``srn`` — the PK; +# a bare record id is NOT unique across versions. Features tiebreak on the +# BIGSERIAL ``id``. A hook may legally declare a data column named ``srn``, +# which must never hijack the feature tiebreaker — hence by-kind, not by-key. +_TIEBREAK_COLUMNS: dict[TableKind, str] = { + TableKind.RECORDS: "srn", + TableKind.FEATURE: "id", +} + # Default sort keys per table kind (data-model.md §PaginationParams). _DEFAULT_SORTS: dict[TableKind, list[SortSpec]] = { TableKind.RECORDS: [ @@ -103,6 +135,21 @@ def _validate_and_default(self) -> "QueryPlan": self.sort = list(_DEFAULT_SORTS[self.table_kind]) return self + @property + def keyset(self) -> Keyset: + """The pagination contract for this plan. + + ``sort=id`` aliases to the tiebreak column: for records the addressable + ``id`` is a projection of the srn PK (sorting by it would be ambiguous + across versions), and for features ``id`` already is the tiebreaker. + """ + tiebreak = _TIEBREAK_COLUMNS[self.table_kind] + primary = self.sort[0].column + return Keyset( + sort_column=tiebreak if primary == "id" else primary, + tiebreak_column=tiebreak, + ) + def encode_cursor(sort_value: Any, id_value: Any) -> str: """Encode a cursor as urlsafe base64 of ``{"s": sort_value, "id": id_value}``. diff --git a/server/osa/infrastructure/data/postgres_data_read_store.py b/server/osa/infrastructure/data/postgres_data_read_store.py index 237b95c..a336e26 100644 --- a/server/osa/infrastructure/data/postgres_data_read_store.py +++ b/server/osa/infrastructure/data/postgres_data_read_store.py @@ -212,26 +212,30 @@ def _records_row_to_mapping(self, row: RowMapping, col_names: list[str]) -> dict def _records_sort(self, plan: QueryPlan, metadata_table: Any) -> tuple[list[Any], Any | None]: t = records_table - # First sort key drives the cursor value; ``id`` (srn) is the tiebreaker. 
- primary = plan.sort[0] - is_desc = primary.direction == SortDirection.DESC - if primary.column in ("created_at", "published_at"): + # plan.keyset owns tiebreak choice and sort=id aliasing; this method + # only maps its column names onto SQL expressions. + keyset = plan.keyset + is_desc = plan.sort[0].direction == SortDirection.DESC + tiebreak_expr = t.c[keyset.tiebreak_column] + if keyset.sort_column == keyset.tiebreak_column: + sort_expr: sa.ColumnElement[Any] = tiebreak_expr + elif keyset.sort_column in ("created_at", "published_at"): sort_expr = t.c.published_at - elif primary.column == "id": - sort_expr = t.c.srn - elif primary.column in metadata_table.c: - sort_expr = metadata_table.c[primary.column] + elif keyset.sort_column in metadata_table.c: + sort_expr = metadata_table.c[keyset.sort_column] else: raise ValidationError( - f"Unknown sort column '{primary.column}'.", field="sort", code="unknown_sort_field" + f"Unknown sort column '{keyset.sort_column}'.", + field="sort", + code="unknown_sort_field", ) page = KeysetPage( [ SortKey(sort_expr, descending=is_desc, nulls_last=True), - SortKey(t.c.srn, descending=is_desc), + SortKey(tiebreak_expr, descending=is_desc), ] ) - return page.order_by(), self._cursor_after(plan, page, sort_expr, t.c.srn) + return page.order_by(), self._cursor_after(plan, page, sort_expr, tiebreak_expr) def _cursor_after( self, @@ -382,25 +386,28 @@ async def _feature_count(self, ft: sa.Table, schema_id: SchemaId) -> int: return int((await self.session.execute(stmt)).scalar_one()) def _features_sort(self, plan: QueryPlan, ft: sa.Table) -> tuple[list[Any], Any | None]: - primary = plan.sort[0] - is_desc = primary.direction == SortDirection.DESC - if primary.column == "id": - sort_expr = ft.c.id - elif primary.column in ft.c: - sort_expr = ft.c[primary.column] + # plan.keyset owns tiebreak choice and sort=id aliasing; this method + # only maps its column names onto SQL expressions. 
+ keyset = plan.keyset + is_desc = plan.sort[0].direction == SortDirection.DESC + tiebreak_expr = ft.c[keyset.tiebreak_column] + if keyset.sort_column == keyset.tiebreak_column: + sort_expr: sa.ColumnElement[Any] = tiebreak_expr + elif keyset.sort_column in ft.c: + sort_expr = ft.c[keyset.sort_column] else: raise ValidationError( - f"Unknown sort column '{primary.column}'.", + f"Unknown sort column '{keyset.sort_column}'.", field="sort", code="unknown_sort_field", ) page = KeysetPage( [ SortKey(sort_expr, descending=is_desc, nulls_last=True), - SortKey(ft.c.id, descending=is_desc), + SortKey(tiebreak_expr, descending=is_desc), ] ) - return page.order_by(), self._cursor_after(plan, page, sort_expr, ft.c.id) + return page.order_by(), self._cursor_after(plan, page, sort_expr, tiebreak_expr) def _compile_feature_filter(self, expr: FilterExpr, *, ft: sa.Table, feature_name: str) -> Any: if isinstance(expr, Predicate): diff --git a/server/tests/unit/domain/data/test_query_plan.py b/server/tests/unit/domain/data/test_query_plan.py index 3419718..963abf8 100644 --- a/server/tests/unit/domain/data/test_query_plan.py +++ b/server/tests/unit/domain/data/test_query_plan.py @@ -10,6 +10,7 @@ PaginationParams, QueryPlan, SortDirection, + SortSpec, TableKind, decode_cursor, encode_cursor, @@ -85,3 +86,59 @@ def test_encode_cursor_accepts_datetime_sort_value() -> None: decoded = decode_cursor(encode_cursor(ts, 7)) assert datetime.fromisoformat(decoded["s"]) == ts assert decoded["id"] == 7 + + +# --------------------------------------------------------------------------- # +# Keyset — single owner of the pagination contract (tiebreak + sort aliasing) +# --------------------------------------------------------------------------- # + + +def test_records_keyset_defaults() -> None: + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + assert plan.keyset.tiebreak_column == "srn" + assert plan.keyset.sort_column == "created_at" + + +def 
test_records_keyset_id_sort_aliases_to_tiebreak() -> None: + # ``sort=id`` on records sorts by the srn column (the PK; a bare id is not + # unique across versions), so the keyset's sort column IS the tiebreaker. + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + sort=[SortSpec(column="id", direction=SortDirection.ASC)], + ) + assert plan.keyset.sort_column == "srn" + assert plan.keyset.tiebreak_column == "srn" + + +def test_feature_keyset_defaults() -> None: + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="f") + assert plan.keyset.tiebreak_column == "id" + assert plan.keyset.sort_column == "id" + + +def test_keyset_cursor_from_row_records_default_sort() -> None: + plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.RECORDS) + cursor = plan.keyset.cursor_from_row({"id": "a", "srn": "s1", "created_at": "2026-01-02"}) + assert decode_cursor(cursor) == {"s": "2026-01-02", "id": "s1"} + + +def test_keyset_cursor_from_row_records_id_sort_encodes_srn() -> None: + plan = QueryPlan( + schema_id=SCHEMA, + table_kind=TableKind.RECORDS, + sort=[SortSpec(column="id", direction=SortDirection.ASC)], + ) + cursor = plan.keyset.cursor_from_row({"id": "a", "srn": "urn:osa:localhost:rec:a@1"}) + assert decode_cursor(cursor) == { + "s": "urn:osa:localhost:rec:a@1", + "id": "urn:osa:localhost:rec:a@1", + } + + +def test_keyset_cursor_from_row_feature_hook_column_named_srn_is_inert() -> None: + # A hook may declare a data column named "srn"; the feature tiebreaker is + # still the integer row id. 
+ plan = QueryPlan(schema_id=SCHEMA, table_kind=TableKind.FEATURE, feature_name="f") + cursor = plan.keyset.cursor_from_row({"id": 2, "srn": "hook-value-2"}) + assert decode_cursor(cursor) == {"s": 2, "id": 2} From ee7ae1398136458461b5bb9511bda1077daac8b0 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 14:04:37 +0100 Subject: [PATCH 20/23] =?UTF-8?q?feat:=20restore=20Router=20=E2=86=92=20Qu?= =?UTF-8?q?eryHandler=20=E2=86=92=20Service=20layering=20on=20the=20/data/?= =?UTF-8?q?=20surface?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The data domain's query/ folder was empty and routes called services directly — the only domain bypassing the documented entry-point pattern, and with it the __auth__ gate the deferred private-schema auth model will need. ReadRecordsTable/ReadFeatureTable handlers now own table resolution, plan construction, and config-driven limit clamping (routes no longer inject Config); GetNodeCatalog/GetSchemaManifest/GetDataRecord cover the catalog reads. The QueryHandler result TypeVar is unbounded so handlers can return domain read models and streaming reads without ceremonial wrappers. 
--- .../application/api/v1/routes/data/_params.py | 43 +----- .../application/api/v1/routes/data/catalog.py | 20 +-- .../api/v1/routes/data/features_table.py | 67 ++++------ .../application/api/v1/routes/data/records.py | 7 +- .../api/v1/routes/data/records_table.py | 65 ++++----- server/osa/domain/data/query/catalog.py | 48 +++++++ server/osa/domain/data/query/read_table.py | 107 +++++++++++++++ server/osa/domain/data/util/di/provider.py | 17 ++- server/osa/domain/shared/query.py | 5 +- .../api/v1/routes/data/test_params.py | 56 +------- .../domain/data/test_table_read_handlers.py | 124 ++++++++++++++++++ 11 files changed, 377 insertions(+), 182 deletions(-) create mode 100644 server/osa/domain/data/query/catalog.py create mode 100644 server/osa/domain/data/query/read_table.py create mode 100644 server/tests/unit/domain/data/test_table_read_handlers.py diff --git a/server/osa/application/api/v1/routes/data/_params.py b/server/osa/application/api/v1/routes/data/_params.py index 73e0d05..55ddbf0 100644 --- a/server/osa/application/api/v1/routes/data/_params.py +++ b/server/osa/application/api/v1/routes/data/_params.py @@ -1,21 +1,12 @@ -"""Shared request parsing for table routes — sort spec, filter body, plan build.""" +"""Shared request parsing for table routes — sort spec + filter body.""" from __future__ import annotations from pydantic import BaseModel from osa.domain.data.model.filter import FilterExpr -from osa.domain.data.model.query_plan import ( - PaginationCursor, - PaginationParams, - QueryPlan, - SortDirection, - SortSpec, - TableKind, -) +from osa.domain.data.model.query_plan import SortDirection, SortSpec from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.ids import HookName -from osa.domain.shared.model.srn import SchemaId class FilterRequestBody(BaseModel): @@ -23,8 +14,8 @@ class FilterRequestBody(BaseModel): filter: FilterExpr | None = None cursor: str | None = None - # Unbounded at the edge: build_plan clamps to [1, 
DataConfig.max_page_limit] - # via PaginationParams.clamped — over-large requests get the max page, not a 422. + # Unbounded at the edge: the table-read handler clamps to + # [1, DataConfig.max_page_limit] — over-large requests get the max page, not a 422. limit: int = 50 sort: str | None = None @@ -51,29 +42,3 @@ def parse_sort(raw: str | None) -> list[SortSpec]: column, dir_enum = token, SortDirection.ASC specs.append(SortSpec(column=column.strip(), direction=dir_enum)) return specs - - -def build_plan( - *, - schema_id: SchemaId, - table_kind: TableKind, - feature_name: HookName | None, - filter_expr: FilterExpr | None, - cursor: str | None, - limit: int, - max_limit: int, - sort: str | None, -) -> QueryPlan: - pagination = PaginationParams.clamped( - cursor=PaginationCursor(value=cursor) if cursor else None, - limit=limit, - max_limit=max_limit, - ) - return QueryPlan( - schema_id=schema_id, - table_kind=table_kind, - feature_name=feature_name, - filter=filter_expr, - pagination=pagination, - sort=parse_sort(sort), - ) diff --git a/server/osa/application/api/v1/routes/data/catalog.py b/server/osa/application/api/v1/routes/data/catalog.py index 8b1a3c5..b8007f9 100644 --- a/server/osa/application/api/v1/routes/data/catalog.py +++ b/server/osa/application/api/v1/routes/data/catalog.py @@ -1,7 +1,7 @@ -"""Catalog & manifest handlers — ``GET /data`` and ``GET /data/{schema}``. +"""Catalog & manifest routes — ``GET /data`` and ``GET /data/{schema}``. JSON-only (no format suffix). Reserved schema names (``records``, ``datasets``) -and unknown schemas surface as 404 via the service's ``NotFoundError``. +and unknown schemas surface as 404 via the handler's ``NotFoundError``. 
The node-catalog handler is registered directly on the prefixed ``/data`` router (an empty sub-path can't be ``include_router``-ed); the manifest handler @@ -16,22 +16,26 @@ from osa.domain.data.model.catalog import NodeCatalog from osa.domain.data.model.manifest import SchemaManifest -from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.data.query.catalog import ( + GetNodeCatalog, + GetNodeCatalogHandler, + GetSchemaManifest, + GetSchemaManifestHandler, +) manifest_router = APIRouter(route_class=DishkaRoute) -async def get_node_catalog(service: FromDishka[DataCatalogService]) -> NodeCatalog: +async def get_node_catalog(handler: FromDishka[GetNodeCatalogHandler]) -> NodeCatalog: """List schemas hosted at this node.""" - return await service.get_node_catalog() + return await handler.run(GetNodeCatalog()) @manifest_router.get( "/{schema}", operation_id="data_get_schema_manifest", response_model=SchemaManifest ) async def get_schema_manifest( - schema: str, service: FromDishka[DataCatalogService] + schema: str, handler: FromDishka[GetSchemaManifestHandler] ) -> SchemaManifest: """Machine-readable manifest for a schema (`` or `@`).""" - schema_id = await service.resolve_schema(schema) - return await service.get_schema_manifest(schema_id) + return await handler.run(GetSchemaManifest(schema=schema)) diff --git a/server/osa/application/api/v1/routes/data/features_table.py b/server/osa/application/api/v1/routes/data/features_table.py index ef8a124..a23bc57 100644 --- a/server/osa/application/api/v1/routes/data/features_table.py +++ b/server/osa/application/api/v1/routes/data/features_table.py @@ -1,9 +1,8 @@ """Feature-table routes — ``/data/{schema}/{feature}[.csv|.csv.gz]`` (US5). Identical shape to the records table — same factory, same streaming/pagination -engine — but the path carries a ``{feature}`` segment and the plan is a -``TableKind.FEATURE`` plan. 
The feature must be a table resource on the resolved -schema's manifest; an unknown feature 404s before any bytes (T090). +engine — but the path carries a ``{feature}`` segment and the handler builds a +``TableKind.FEATURE`` plan. An unknown feature 404s before any bytes (T090). """ from __future__ import annotations @@ -13,40 +12,33 @@ from fastapi.responses import StreamingResponse from osa.application.api.v1.routes.data._limiter import POST_RATE_LIMIT, limiter -from osa.application.api.v1.routes.data._params import FilterRequestBody, build_plan +from osa.application.api.v1.routes.data._params import FilterRequestBody, parse_sort from osa.application.api.v1.routes.data._streaming import build_table_response -from osa.application.api.v1.routes.data.tables import format_key, register_table_routes -from osa.config import Config from osa.application.api.v1.routes.data.formats import DataResponseFormat -from osa.domain.data.model.query_plan import TableKind -from osa.domain.data.service.data_catalog import DataCatalogService -from osa.domain.data.service.data_query import DataQueryService +from osa.application.api.v1.routes.data.tables import format_key, register_table_routes +from osa.domain.data.query.read_table import ReadFeatureTable, ReadFeatureTableHandler def _make_get_endpoint(fmt: DataResponseFormat): async def endpoint( schema: str, feature: str, - query_service: FromDishka[DataQueryService], - catalog_service: FromDishka[DataCatalogService], - config: FromDishka[Config], + handler: FromDishka[ReadFeatureTableHandler], cursor: str | None = None, limit: int = 50, sort: str | None = None, ) -> StreamingResponse: - table = await catalog_service.resolve_table(schema, TableKind.FEATURE, feature_name=feature) - plan = build_plan( - schema_id=table.schema_id, - table_kind=TableKind.FEATURE, - feature_name=feature, - filter_expr=None, - cursor=cursor, - limit=limit, - max_limit=config.data.max_page_limit, - sort=sort, + result = await handler.run( + 
ReadFeatureTable( + schema=schema, + feature=feature, + cursor=cursor, + limit=limit, + sort=parse_sort(sort), + timeout=fmt.timeout, + ) ) - rows = query_service.stream_features(plan, timeout=fmt.timeout) - return await build_table_response(rows, fmt, table.columns, plan) + return await build_table_response(result.rows, fmt, result.columns, result.plan) return endpoint @@ -57,23 +49,20 @@ async def endpoint( schema: str, feature: str, body: FilterRequestBody, - query_service: FromDishka[DataQueryService], - catalog_service: FromDishka[DataCatalogService], - config: FromDishka[Config], + handler: FromDishka[ReadFeatureTableHandler], ) -> StreamingResponse: - table = await catalog_service.resolve_table(schema, TableKind.FEATURE, feature_name=feature) - plan = build_plan( - schema_id=table.schema_id, - table_kind=TableKind.FEATURE, - feature_name=feature, - filter_expr=body.filter, - cursor=body.cursor, - limit=body.limit, - max_limit=config.data.max_page_limit, - sort=body.sort, + result = await handler.run( + ReadFeatureTable( + schema=schema, + feature=feature, + filter=body.filter, + cursor=body.cursor, + limit=body.limit, + sort=parse_sort(body.sort), + timeout=fmt.timeout, + ) ) - rows = query_service.stream_features(plan, timeout=fmt.timeout) - return await build_table_response(rows, fmt, table.columns, plan) + return await build_table_response(result.rows, fmt, result.columns, result.plan) # Unique name before the limiter — see records_table._make_post_endpoint. 
endpoint.__name__ = f"feature_post_{format_key(fmt)}" diff --git a/server/osa/application/api/v1/routes/data/records.py b/server/osa/application/api/v1/routes/data/records.py index 40aad6d..8790647 100644 --- a/server/osa/application/api/v1/routes/data/records.py +++ b/server/osa/application/api/v1/routes/data/records.py @@ -12,7 +12,7 @@ from fastapi import APIRouter from osa.application.api.v1.routes.data.models import RecordResponse -from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.data.query.catalog import GetDataRecord, GetDataRecordHandler from osa.domain.shared.model.ids import RecordRef router = APIRouter(route_class=DishkaRoute) @@ -22,8 +22,7 @@ "/records/{record_id}", operation_id="data_get_record_by_id", response_model=RecordResponse ) async def get_record_by_id( - record_id: str, service: FromDishka[DataCatalogService] + record_id: str, handler: FromDishka[GetDataRecordHandler] ) -> RecordResponse: - ref = RecordRef.parse(record_id) - summary = await service.get_record_by_id(ref.id, ref.version) + summary = await handler.run(GetDataRecord(ref=RecordRef.parse(record_id))) return RecordResponse.from_summary(summary) diff --git a/server/osa/application/api/v1/routes/data/records_table.py b/server/osa/application/api/v1/routes/data/records_table.py index 6420fd3..73318e7 100644 --- a/server/osa/application/api/v1/routes/data/records_table.py +++ b/server/osa/application/api/v1/routes/data/records_table.py @@ -3,8 +3,9 @@ Endpoint *builders* capture a :class:`DataResponseFormat` by closure; the generic :func:`register_table_routes` factory registers the GET/POST × format matrix from them. GET streams (or paginates JSON) with no body; POST takes a -``FilterExpr`` body and is rate-limited per IP. Every route applies the -format's ``statement_timeout`` and resolves the schema (404 before bytes). +``FilterExpr`` body and is rate-limited per IP. 
Everything between HTTP and +the row stream — table resolution, plan construction, limit clamping — lives +in :class:`ReadRecordsTableHandler`. """ from __future__ import annotations @@ -14,39 +15,31 @@ from fastapi.responses import StreamingResponse from osa.application.api.v1.routes.data._limiter import POST_RATE_LIMIT, limiter -from osa.application.api.v1.routes.data._params import FilterRequestBody, build_plan +from osa.application.api.v1.routes.data._params import FilterRequestBody, parse_sort from osa.application.api.v1.routes.data._streaming import build_table_response -from osa.application.api.v1.routes.data.tables import format_key, register_table_routes -from osa.config import Config from osa.application.api.v1.routes.data.formats import DataResponseFormat -from osa.domain.data.model.query_plan import TableKind -from osa.domain.data.service.data_catalog import DataCatalogService -from osa.domain.data.service.data_query import DataQueryService +from osa.application.api.v1.routes.data.tables import format_key, register_table_routes +from osa.domain.data.query.read_table import ReadRecordsTable, ReadRecordsTableHandler def _make_get_endpoint(fmt: DataResponseFormat): async def endpoint( schema: str, - query_service: FromDishka[DataQueryService], - catalog_service: FromDishka[DataCatalogService], - config: FromDishka[Config], + handler: FromDishka[ReadRecordsTableHandler], cursor: str | None = None, limit: int = 50, sort: str | None = None, ) -> StreamingResponse: - table = await catalog_service.resolve_table(schema, TableKind.RECORDS) - plan = build_plan( - schema_id=table.schema_id, - table_kind=TableKind.RECORDS, - feature_name=None, - filter_expr=None, - cursor=cursor, - limit=limit, - max_limit=config.data.max_page_limit, - sort=sort, + result = await handler.run( + ReadRecordsTable( + schema=schema, + cursor=cursor, + limit=limit, + sort=parse_sort(sort), + timeout=fmt.timeout, + ) ) - rows = query_service.stream_records(plan, timeout=fmt.timeout) - 
return await build_table_response(rows, fmt, table.columns, plan) + return await build_table_response(result.rows, fmt, result.columns, result.plan) return endpoint @@ -56,23 +49,19 @@ async def endpoint( request: Request, schema: str, body: FilterRequestBody, - query_service: FromDishka[DataQueryService], - catalog_service: FromDishka[DataCatalogService], - config: FromDishka[Config], + handler: FromDishka[ReadRecordsTableHandler], ) -> StreamingResponse: - table = await catalog_service.resolve_table(schema, TableKind.RECORDS) - plan = build_plan( - schema_id=table.schema_id, - table_kind=TableKind.RECORDS, - feature_name=None, - filter_expr=body.filter, - cursor=body.cursor, - limit=body.limit, - max_limit=config.data.max_page_limit, - sort=body.sort, + result = await handler.run( + ReadRecordsTable( + schema=schema, + filter=body.filter, + cursor=body.cursor, + limit=body.limit, + sort=parse_sort(body.sort), + timeout=fmt.timeout, + ) ) - rows = query_service.stream_records(plan, timeout=fmt.timeout) - return await build_table_response(rows, fmt, table.columns, plan) + return await build_table_response(result.rows, fmt, result.columns, result.plan) # slowapi scopes a rate limit by the decorated function's ``module.__name__`` # captured at decoration time. 
These builder closures are all named diff --git a/server/osa/domain/data/query/catalog.py b/server/osa/domain/data/query/catalog.py new file mode 100644 index 0000000..85d864b --- /dev/null +++ b/server/osa/domain/data/query/catalog.py @@ -0,0 +1,48 @@ +"""Catalog-shaped query handlers — node catalog, schema manifest, record by id.""" + +from __future__ import annotations + +from osa.domain.data.model.catalog import NodeCatalog +from osa.domain.data.model.manifest import SchemaManifest +from osa.domain.data.model.record_summary import RecordSummary +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.shared.authorization.gate import public +from osa.domain.shared.model.ids import RecordRef +from osa.domain.shared.query import Query, QueryHandler + + +class GetNodeCatalog(Query): + pass + + +class GetNodeCatalogHandler(QueryHandler[GetNodeCatalog, NodeCatalog]): + __auth__ = public() + catalog_service: DataCatalogService + + async def run(self, cmd: GetNodeCatalog) -> NodeCatalog: + return await self.catalog_service.get_node_catalog() + + +class GetSchemaManifest(Query): + schema: str # URL segment: ``<id>`` or ``<id>@<version>`` + + +class GetSchemaManifestHandler(QueryHandler[GetSchemaManifest, SchemaManifest]): + __auth__ = public() + catalog_service: DataCatalogService + + async def run(self, cmd: GetSchemaManifest) -> SchemaManifest: + schema_id = await self.catalog_service.resolve_schema(cmd.schema) + return await self.catalog_service.get_schema_manifest(schema_id) + + +class GetDataRecord(Query): + ref: RecordRef + + +class GetDataRecordHandler(QueryHandler[GetDataRecord, RecordSummary]): + __auth__ = public() + catalog_service: DataCatalogService + + async def run(self, cmd: GetDataRecord) -> RecordSummary: + return await self.catalog_service.get_record_by_id(cmd.ref.id, cmd.ref.version) diff --git a/server/osa/domain/data/query/read_table.py b/server/osa/domain/data/query/read_table.py new file mode 100644 index 0000000..38b19d1 ---
/dev/null +++ b/server/osa/domain/data/query/read_table.py @@ -0,0 +1,107 @@ +"""Table-read query handlers — one entry point per ``/data/`` table request. + +The handler owns everything between HTTP parsing and row streaming: table +resolution (404 before bytes), plan construction with the config-clamped page +limit, and delegation to the query service. Routes only translate HTTP +primitives into the Query DTO and wrap the result in a response. + +``__auth__ = public()`` is the seam where the deferred private-schema auth +model will land. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Mapping +from dataclasses import dataclass +from datetime import timedelta +from typing import Any + +from pydantic import Field + +from osa.config import Config +from osa.domain.data.model.filter import FilterExpr +from osa.domain.data.model.manifest import ColumnSpec +from osa.domain.data.model.query_plan import ( + PaginationCursor, + PaginationParams, + QueryPlan, + SortSpec, + TableKind, +) +from osa.domain.data.service.data_catalog import DataCatalogService +from osa.domain.data.service.data_query import DataQueryService +from osa.domain.shared.authorization.gate import public +from osa.domain.shared.model.ids import HookName +from osa.domain.shared.query import Query, QueryHandler + + +class ReadRecordsTable(Query): + schema: str # URL segment: ``<id>`` or ``<id>@<version>`` + filter: FilterExpr | None = None + cursor: str | None = None + limit: int = 50 + sort: list[SortSpec] = Field(default_factory=list) + timeout: timedelta | None = None # execution budget chosen by the response format + + +class ReadFeatureTable(ReadRecordsTable): + feature: HookName + + +@dataclass +class TableRead: + """A resolved table read: the plan (pagination contract), the column + schema (wire order), and the lazily-evaluated row stream.""" + + plan: QueryPlan + columns: list[ColumnSpec] + rows: AsyncIterator[Mapping[str, Any]] + + +def _pagination(cmd: ReadRecordsTable,
config: Config) -> PaginationParams: + return PaginationParams.clamped( + cursor=PaginationCursor(value=cmd.cursor) if cmd.cursor else None, + limit=cmd.limit, + max_limit=config.data.max_page_limit, + ) + + +class ReadRecordsTableHandler(QueryHandler[ReadRecordsTable, TableRead]): + __auth__ = public() + catalog_service: DataCatalogService + query_service: DataQueryService + config: Config + + async def run(self, cmd: ReadRecordsTable) -> TableRead: + table = await self.catalog_service.resolve_table(cmd.schema, TableKind.RECORDS) + plan = QueryPlan( + schema_id=table.schema_id, + table_kind=TableKind.RECORDS, + filter=cmd.filter, + pagination=_pagination(cmd, self.config), + sort=cmd.sort, + ) + rows = self.query_service.stream_records(plan, cmd.timeout) + return TableRead(plan=plan, columns=table.columns, rows=rows) + + +class ReadFeatureTableHandler(QueryHandler[ReadFeatureTable, TableRead]): + __auth__ = public() + catalog_service: DataCatalogService + query_service: DataQueryService + config: Config + + async def run(self, cmd: ReadFeatureTable) -> TableRead: + table = await self.catalog_service.resolve_table( + cmd.schema, TableKind.FEATURE, feature_name=cmd.feature + ) + plan = QueryPlan( + schema_id=table.schema_id, + table_kind=TableKind.FEATURE, + feature_name=cmd.feature, + filter=cmd.filter, + pagination=_pagination(cmd, self.config), + sort=cmd.sort, + ) + rows = self.query_service.stream_features(plan, cmd.timeout) + return TableRead(plan=plan, columns=table.columns, rows=rows) diff --git a/server/osa/domain/data/util/di/provider.py b/server/osa/domain/data/util/di/provider.py index 3230397..7294700 100644 --- a/server/osa/domain/data/util/di/provider.py +++ b/server/osa/domain/data/util/di/provider.py @@ -1,9 +1,18 @@ -"""Dishka DI provider for the data domain (services).""" +"""Dishka DI provider for the data domain (services + query handlers).""" from dishka import provide from osa.config import Config from osa.domain.data.port.data_read_store 
import DataReadStore +from osa.domain.data.query.catalog import ( + GetDataRecordHandler, + GetNodeCatalogHandler, + GetSchemaManifestHandler, +) +from osa.domain.data.query.read_table import ( + ReadFeatureTableHandler, + ReadRecordsTableHandler, +) from osa.domain.data.service.data_catalog import DataCatalogService from osa.domain.data.service.data_query import DataQueryService from osa.util.di.base import Provider @@ -18,3 +27,9 @@ def get_data_query_service(self, read_store: DataReadStore, config: Config) -> D @provide(scope=Scope.UOW) def get_data_catalog_service(self, read_store: DataReadStore) -> DataCatalogService: return DataCatalogService(read_store=read_store) + + read_records_table_handler = provide(ReadRecordsTableHandler, scope=Scope.UOW) + read_feature_table_handler = provide(ReadFeatureTableHandler, scope=Scope.UOW) + get_node_catalog_handler = provide(GetNodeCatalogHandler, scope=Scope.UOW) + get_schema_manifest_handler = provide(GetSchemaManifestHandler, scope=Scope.UOW) + get_data_record_handler = provide(GetDataRecordHandler, scope=Scope.UOW) diff --git a/server/osa/domain/shared/query.py b/server/osa/domain/shared/query.py index e896edb..c04d98f 100644 --- a/server/osa/domain/shared/query.py +++ b/server/osa/domain/shared/query.py @@ -21,7 +21,10 @@ class Result(BaseModel): ... C = TypeVar("C", bound=Query) -R = TypeVar("R", bound=Result) +# Unbound: query results may be Result DTOs, domain read models (e.g. a +# catalog/manifest that IS the wire shape), or streaming reads. Result remains +# the conventional base for handler-specific DTOs. 
+R = TypeVar("R") # Unbound async handler method: (self, cmd) -> Coroutine -> Result _HandlerMethod = Callable[..., Coroutine[Any, Any, Any]] diff --git a/server/tests/unit/application/api/v1/routes/data/test_params.py b/server/tests/unit/application/api/v1/routes/data/test_params.py index d2380b2..16c5907 100644 --- a/server/tests/unit/application/api/v1/routes/data/test_params.py +++ b/server/tests/unit/application/api/v1/routes/data/test_params.py @@ -1,13 +1,11 @@ -"""Unit tests for table-route request parsing (sort spec + plan build).""" +"""Unit tests for table-route request parsing (sort spec). Plan construction +lives on the table-read handlers (tests/unit/domain/data/test_table_read_handlers.py).""" import pytest -from osa.application.api.v1.routes.data._params import build_plan, parse_sort -from osa.domain.data.model.query_plan import SortDirection, TableKind +from osa.application.api.v1.routes.data._params import parse_sort +from osa.domain.data.model.query_plan import SortDirection from osa.domain.shared.error import ValidationError -from osa.domain.shared.model.srn import SchemaId - -SCHEMA = SchemaId.parse("compound@1.0.0") def test_parse_sort_empty_is_empty_list() -> None: @@ -32,49 +30,3 @@ def test_parse_sort_bare_column_defaults_asc() -> None: def test_parse_sort_invalid_direction_raises() -> None: with pytest.raises(ValidationError): parse_sort("mw:sideways") - - -def test_build_plan_wraps_cursor_and_limit() -> None: - plan = build_plan( - schema_id=SCHEMA, - table_kind=TableKind.RECORDS, - feature_name=None, - filter_expr=None, - cursor="CUR", - limit=10, - max_limit=1000, - sort="mw:asc", - ) - assert plan.pagination.limit == 10 - assert str(plan.pagination.cursor) == "CUR" - assert plan.sort[0].column == "mw" - - -def test_build_plan_clamps_limit_to_configured_max() -> None: - plan = build_plan( - schema_id=SCHEMA, - table_kind=TableKind.RECORDS, - feature_name=None, - filter_expr=None, - cursor=None, - limit=99999, - max_limit=200, - 
sort=None, - ) - assert plan.pagination.limit == 200 - - -def test_build_plan_no_cursor_is_none() -> None: - plan = build_plan( - schema_id=SCHEMA, - table_kind=TableKind.RECORDS, - feature_name=None, - filter_expr=None, - cursor=None, - limit=50, - max_limit=1000, - sort=None, - ) - assert plan.pagination.cursor is None - # default RECORDS sort applied - assert plan.sort[0].column == "created_at" diff --git a/server/tests/unit/domain/data/test_table_read_handlers.py b/server/tests/unit/domain/data/test_table_read_handlers.py new file mode 100644 index 0000000..9a9da74 --- /dev/null +++ b/server/tests/unit/domain/data/test_table_read_handlers.py @@ -0,0 +1,124 @@ +"""Data-domain query handlers — the route's single entry point per request. + +Restores the documented Router → QueryHandler → Service layering on the /data/ +surface (and with it the __auth__ seam the deferred private-schema auth model +will need). Handlers own plan construction: table resolution, limit clamping +against config, and default sorts — routes only parse HTTP. 
+""" + +from collections.abc import AsyncIterator, Mapping +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any + +import pytest + +from osa.domain.data.model.manifest import ColumnSpec, ResolvedTable +from osa.domain.data.model.query_plan import ( + QueryPlan, + SortDirection, + SortSpec, + TableKind, +) +from osa.domain.data.query.read_table import ( + ReadFeatureTable, + ReadFeatureTableHandler, + ReadRecordsTable, + ReadRecordsTableHandler, +) +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.model.srn import SchemaId + +SCHEMA_ID = SchemaId.parse("compound@1.0.0") +COLUMNS = [ColumnSpec(name="id", type=FieldType.TEXT)] + + +@dataclass +class FakeDataConfig: + max_filter_depth: int = 10 + max_predicates: int = 200 + max_feature_joins: int = 10 + max_page_limit: int = 1000 + + +@dataclass +class FakeConfig: + data: FakeDataConfig = field(default_factory=FakeDataConfig) + + +class FakeCatalogService: + def __init__(self) -> None: + self.resolved_with: tuple[Any, ...] 
| None = None + + async def resolve_table(self, schema, table_kind, feature_name=None) -> ResolvedTable: + self.resolved_with = (schema, table_kind, feature_name) + return ResolvedTable(schema_id=SCHEMA_ID, columns=COLUMNS) + + +class FakeQueryService: + def __init__(self) -> None: + self.received: tuple[QueryPlan, timedelta | None] | None = None + + async def _rows(self) -> AsyncIterator[Mapping[str, Any]]: + yield {"id": "a"} + + def stream_records(self, plan: QueryPlan, timeout: timedelta | None = None): + self.received = (plan, timeout) + return self._rows() + + def stream_features(self, plan: QueryPlan, timeout: timedelta | None = None): + self.received = (plan, timeout) + return self._rows() + + +def _records_handler() -> tuple[ReadRecordsTableHandler, FakeCatalogService, FakeQueryService]: + catalog, query = FakeCatalogService(), FakeQueryService() + handler = ReadRecordsTableHandler( + catalog_service=catalog, query_service=query, config=FakeConfig() + ) + return handler, catalog, query + + +@pytest.mark.asyncio +async def test_records_handler_resolves_and_streams() -> None: + handler, catalog, query = _records_handler() + result = await handler.run(ReadRecordsTable(schema="compound@1.0.0")) + + assert catalog.resolved_with == ("compound@1.0.0", TableKind.RECORDS, None) + assert result.columns == COLUMNS + assert result.plan.schema_id == SCHEMA_ID + assert result.plan.table_kind == TableKind.RECORDS + assert [r async for r in result.rows] == [{"id": "a"}] + + +@pytest.mark.asyncio +async def test_records_handler_clamps_limit_to_config() -> None: + handler, _, _ = _records_handler() + handler.config.data.max_page_limit = 100 + result = await handler.run(ReadRecordsTable(schema="compound@1.0.0", limit=5000)) + assert result.plan.pagination.limit == 100 + + +@pytest.mark.asyncio +async def test_records_handler_forwards_timeout_and_sort() -> None: + handler, _, query = _records_handler() + sort = [SortSpec(column="mw", direction=SortDirection.DESC)] + result = 
await handler.run( + ReadRecordsTable(schema="compound@1.0.0", sort=sort, timeout=timedelta(seconds=30)) + ) + plan, timeout = query.received + assert plan is result.plan + assert timeout == timedelta(seconds=30) + assert plan.sort[0].column == "mw" + + +@pytest.mark.asyncio +async def test_feature_handler_resolves_feature_and_builds_feature_plan() -> None: + catalog, query = FakeCatalogService(), FakeQueryService() + handler = ReadFeatureTableHandler( + catalog_service=catalog, query_service=query, config=FakeConfig() + ) + result = await handler.run(ReadFeatureTable(schema="compound@1.0.0", feature="chem_features")) + assert catalog.resolved_with == ("compound@1.0.0", TableKind.FEATURE, "chem_features") + assert result.plan.table_kind == TableKind.FEATURE + assert result.plan.feature_name == "chem_features" From 54241140da4e885cdd11d79c6bc51c9bbe2c1e30 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 14:10:31 +0100 Subject: [PATCH 21/23] refactor: split the data read-store god adapter along the service seam MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit postgres_data_read_store.py was a 665-line adapter carrying both the streaming read path and the catalog/manifest/record-by-id reads behind one DataReadStore port. Split into DataTableReadStore (PostgresTableReadStore: streaming, filter SQL, keyset) and DataCatalogReadStore (PostgresCatalogReadStore: catalog, manifest, record lookup), matching the DataQueryService / DataCatalogService split. Schema→feature-table resolution shared by both adapters lives in schema_features.py. 
--- .../osa/domain/data/port/data_read_store.py | 11 +- .../osa/domain/data/service/data_catalog.py | 6 +- server/osa/domain/data/service/data_query.py | 6 +- server/osa/domain/data/util/di/provider.py | 11 +- .../data/postgres_catalog_read_store.py | 228 +++++++++++++++ ..._store.py => postgres_table_read_store.py} | 268 +----------------- .../infrastructure/data/schema_features.py | 71 +++++ server/osa/infrastructure/persistence/di.py | 20 +- .../test_data_features_postgres.py | 23 +- .../test_data_read_store_postgres.py | 29 +- ...test_data_streaming_guarantees_postgres.py | 8 +- .../data/test_cursor_validation.py | 16 +- .../data/test_statement_timeout.py | 2 +- 13 files changed, 391 insertions(+), 308 deletions(-) create mode 100644 server/osa/infrastructure/data/postgres_catalog_read_store.py rename server/osa/infrastructure/data/{postgres_data_read_store.py => postgres_table_read_store.py} (63%) create mode 100644 server/osa/infrastructure/data/schema_features.py diff --git a/server/osa/domain/data/port/data_read_store.py b/server/osa/domain/data/port/data_read_store.py index d82ef3a..dc77068 100644 --- a/server/osa/domain/data/port/data_read_store.py +++ b/server/osa/domain/data/port/data_read_store.py @@ -1,12 +1,15 @@ -"""DataReadStore port — read-only access feeding every ``/data/`` format. +"""Read-store ports feeding the ``/data/`` surface — split along the service seam. -``stream_rows`` is the single streaming primitive behind both the paginated +``DataTableReadStore`` is the streaming primitive behind both the paginated JSON path and the unbounded CSV/CSV.gz path; the concrete adapter backs it with a Postgres server-side cursor (research §2) so memory stays bounded regardless of result size. Rows are yielded as column→value mappings (already projected): records-table rows include the implicit ``id``/``srn``/``schema_id``/ ``version``/``created_at`` columns plus metadata fields; feature-table rows carry the hook's declared columns. 
+ +``DataCatalogReadStore`` serves the non-streaming reads: node catalog, schema +manifest, latest-schema resolution, and single-record-by-id. """ from __future__ import annotations @@ -24,7 +27,7 @@ from osa.domain.shared.model.srn import SchemaId -class DataReadStore(Protocol): +class DataTableReadStore(Protocol): def stream_rows( self, plan: "QueryPlan", timeout: timedelta | None = None ) -> AsyncIterator[Mapping[str, Any]]: @@ -35,6 +38,8 @@ def stream_rows( """ ... + +class DataCatalogReadStore(Protocol): async def get_record_by_id(self, id: "RecordId", version: int | None) -> "RecordSummary | None": """Resolve a single record by bare ID (schema resolved via PK). ``None`` if absent.""" ... diff --git a/server/osa/domain/data/service/data_catalog.py b/server/osa/domain/data/service/data_catalog.py index 2c4abe8..f461b33 100644 --- a/server/osa/domain/data/service/data_catalog.py +++ b/server/osa/domain/data/service/data_catalog.py @@ -1,6 +1,6 @@ """DataCatalogService — catalog, manifest, and single-record-by-ID reads. -Read-only orchestration over the :class:`DataReadStore` port. Catalog + manifest +Read-only orchestration over the :class:`DataCatalogReadStore` port. Catalog + manifest (US3) and record-by-ID (US4). The reserved-name 404 is enforced defensively here in addition to the write-side aggregate invariants (research §6). 
""" @@ -11,7 +11,7 @@ from osa.domain.data.model.manifest import ResolvedTable, SchemaManifest from osa.domain.data.model.query_plan import TableKind from osa.domain.data.model.record_summary import RecordSummary -from osa.domain.data.port.data_read_store import DataReadStore +from osa.domain.data.port.data_read_store import DataCatalogReadStore from osa.domain.shared.error import NotFoundError from osa.domain.shared.model.ids import HookName, RecordId from osa.domain.shared.model.reserved import RESERVED_NAMES @@ -20,7 +20,7 @@ class DataCatalogService(Service): - read_store: DataReadStore + read_store: DataCatalogReadStore async def resolve_schema(self, raw: str) -> SchemaId: """Resolve a URL schema segment (```` or ``@``) to a SchemaId. diff --git a/server/osa/domain/data/service/data_query.py b/server/osa/domain/data/service/data_query.py index 9437c67..0f595d8 100644 --- a/server/osa/domain/data/service/data_query.py +++ b/server/osa/domain/data/service/data_query.py @@ -2,7 +2,7 @@ Validates the filter tree's bounds (depth, predicate count, distinct feature joins) against config, then delegates row streaming to the -:class:`DataReadStore` port. The returned async iterators are wrapped by the +:class:`DataTableReadStore` port. The returned async iterators are wrapped by the route layer in a serializer + ``StreamingResponse``. Validation runs at the top of the generator body, so it surfaces on the route's pre-flight ``__anext__`` pull — before any response bytes (research §4). 
@@ -21,13 +21,13 @@ from osa.config import Config from osa.domain.data.model.filter import And, FeatureFieldRef, FilterExpr, Not, Or, Predicate from osa.domain.data.model.query_plan import QueryPlan, TableKind -from osa.domain.data.port.data_read_store import DataReadStore +from osa.domain.data.port.data_read_store import DataTableReadStore from osa.domain.shared.error import ValidationError from osa.domain.shared.service import Service class DataQueryService(Service): - read_store: DataReadStore + read_store: DataTableReadStore config: Config async def stream_records( diff --git a/server/osa/domain/data/util/di/provider.py b/server/osa/domain/data/util/di/provider.py index 7294700..1c5d236 100644 --- a/server/osa/domain/data/util/di/provider.py +++ b/server/osa/domain/data/util/di/provider.py @@ -3,7 +3,10 @@ from dishka import provide from osa.config import Config -from osa.domain.data.port.data_read_store import DataReadStore +from osa.domain.data.port.data_read_store import ( + DataCatalogReadStore, + DataTableReadStore, +) from osa.domain.data.query.catalog import ( GetDataRecordHandler, GetNodeCatalogHandler, @@ -21,11 +24,13 @@ class DataProvider(Provider): @provide(scope=Scope.UOW) - def get_data_query_service(self, read_store: DataReadStore, config: Config) -> DataQueryService: + def get_data_query_service( + self, read_store: DataTableReadStore, config: Config + ) -> DataQueryService: return DataQueryService(read_store=read_store, config=config) @provide(scope=Scope.UOW) - def get_data_catalog_service(self, read_store: DataReadStore) -> DataCatalogService: + def get_data_catalog_service(self, read_store: DataCatalogReadStore) -> DataCatalogService: return DataCatalogService(read_store=read_store) read_records_table_handler = provide(ReadRecordsTableHandler, scope=Scope.UOW) diff --git a/server/osa/infrastructure/data/postgres_catalog_read_store.py b/server/osa/infrastructure/data/postgres_catalog_read_store.py new file mode 100644 index 0000000..a396dc6 
--- /dev/null +++ b/server/osa/infrastructure/data/postgres_catalog_read_store.py @@ -0,0 +1,228 @@ +"""Postgres adapter for the ``DataCatalogReadStore`` port. + +Catalog, manifest, latest-schema resolution, and single-record-by-id — the +non-streaming reads behind ``GET /data``, ``GET /data/{schema}``, and +``GET /data/records/{id}``. Table streaming lives in +:class:`~osa.infrastructure.data.postgres_table_read_store.PostgresTableReadStore`. +""" + +from __future__ import annotations + +import logging + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.data.model.catalog import ( + CatalogEntry, + NodeCatalog, + TableResourceSummary, +) +from osa.domain.data.model.manifest import ( + IMPLICIT_FEATURE_COLUMN_SPECS, + IMPLICIT_RECORD_COLUMN_SPECS, + ColumnSpec, + FieldSpec, + SchemaManifest, + TableResource, +) +from osa.domain.data.model.query_plan import TableKind +from osa.domain.data.model.record_summary import RecordSummary +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.model.ids import RecordId +from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId +from osa.infrastructure.data.schema_features import ( + feature_count_stmt, + schema_feature_tables, +) +from osa.infrastructure.persistence.feature_table import ( + FeatureSchema, + build_feature_table, +) +from osa.infrastructure.persistence.tables import records_table, schemas_table + +logger = logging.getLogger(__name__) + +# A feature column's JSON-primitive type → the manifest's semantic FieldType. +_JSON_TYPE_TO_FIELD_TYPE: dict[str, FieldType] = { + "string": FieldType.TEXT, + "number": FieldType.NUMBER, + "integer": FieldType.NUMBER, + "boolean": FieldType.BOOLEAN, + "array": FieldType.TEXT, + "object": FieldType.TEXT, +} + +# All URL-exposed format suffixes (mirrors the route-layer FORMATS registry). 
+_ALL_FORMATS = ["", "csv", "csv.gz"] + + +def _escape_like(value: str) -> str: + return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + +class PostgresCatalogReadStore: + def __init__(self, session: AsyncSession, node_domain: Domain) -> None: + self.session = session + # Only the node's DNS domain is needed (to render SRNs in the catalog / + # manifest) — not the whole Config. + self.node_domain = node_domain + + # ------------------------------------------------------------------ # + # Single record by ID + # ------------------------------------------------------------------ # + + async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary | None: + # The records PK is the SRN ``urn:osa:{domain}:rec:{id}@{version}``. + # Match the id segment; resolve version (pin or latest published). + pattern = f"urn:osa:%:rec:{_escape_like(str(id))}@%" + t = records_table + stmt = ( + select(t.c.srn, t.c.schema_id, t.c.schema_version, t.c.published_at, t.c.metadata) + .where(t.c.srn.like(pattern, escape="\\")) + .order_by(t.c.published_at.desc()) + ) + result = await self.session.execute(stmt) + rows = result.mappings().all() + if not rows: + return None + + chosen = None + for row in rows: + srn = RecordSRN.parse(row["srn"]) + if srn.id.root != str(id): + continue + if version is None: + chosen = (srn, row) + break + if int(srn.version.root) == version: + chosen = (srn, row) + break + if chosen is None: + return None + srn, row = chosen + return RecordSummary( + id=RecordId(srn.id.root), + srn=srn, + schema_id=SchemaId.parse(f"{row['schema_id']}@{row['schema_version']}"), + version=int(srn.version.root), + metadata=row["metadata"] or {}, + created_at=row["published_at"], + ) + + # ------------------------------------------------------------------ # + # Catalog & manifest + # ------------------------------------------------------------------ # + + async def get_node_catalog(self) -> NodeCatalog: + stmt = 
select(schemas_table.c.id, schemas_table.c.version) + result = await self.session.execute(stmt) + schema_rows = [(row["id"], row["version"]) for row in result.mappings()] + entries: list[CatalogEntry] = [] + for short_id, version in schema_rows: + schema_id = SchemaId.parse(f"{short_id}@{version}") + resources = [TableResourceSummary(name="records", kind=TableKind.RECORDS)] + for hook_name, _ in await schema_feature_tables(self.session, schema_id): + resources.append(TableResourceSummary(name=hook_name, kind=TableKind.FEATURE)) + entries.append( + CatalogEntry( + id=short_id, + version=version, + srn=schema_id.to_srn(self.node_domain).render(), + table_resources=resources, + ) + ) + return NodeCatalog(node_domain=self.node_domain.root, schemas=entries) + + async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest | None: + stmt = select(schemas_table.c.fields).where( + schemas_table.c.id == schema_id.id.root, + schemas_table.c.version == schema_id.version.root, + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + if row is None: + return None + + field_specs: list[FieldSpec] = [] + column_specs: list[ColumnSpec] = [] + for f in row["fields"]: + ftype = FieldType(f["type"]) + field_specs.append( + FieldSpec( + name=f["name"], + type=ftype, + ontology_id=f.get("ontology_id"), + ontology_version=f.get("ontology_version"), + ) + ) + column_specs.append(ColumnSpec(name=f["name"], type=ftype)) + + record_count = await self._records_count(schema_id) + records_resource = TableResource( + name="records", + kind=TableKind.RECORDS, + # Implicit columns (id, srn, schema_id, version, created_at) precede + # the schema's declared metadata fields — this is the CSV header order. 
+ columns=[*IMPLICIT_RECORD_COLUMN_SPECS, *column_specs], + row_count=record_count, + formats=list(_ALL_FORMATS), + ) + feature_resources = await self._feature_resources(schema_id) + return SchemaManifest( + id=schema_id.id.root, + version=schema_id.version.root, + srn=schema_id.to_srn(self.node_domain).render(), + fields=field_specs, + table_resources=[records_resource, *feature_resources], + ) + + async def _feature_resources(self, schema_id: SchemaId) -> list[TableResource]: + """Build a TableResource for each feature table registered on the schema.""" + resources: list[TableResource] = [] + for hook_name, fschema in await schema_feature_tables(self.session, schema_id): + ft = build_feature_table(hook_name, fschema) + count = int( + (await self.session.execute(feature_count_stmt(ft, schema_id))).scalar_one() + ) + resources.append( + TableResource( + name=hook_name, + kind=TableKind.FEATURE, + # Implicit columns (id, record_srn, created_at) precede the + # hook's declared data columns — this is the CSV header order. + columns=[*IMPLICIT_FEATURE_COLUMN_SPECS, *_feature_column_specs(fschema)], + row_count=count, + formats=list(_ALL_FORMATS), + ) + ) + return resources + + async def get_latest_schema_id(self, schema_short_id: str) -> SchemaId | None: + stmt = select(schemas_table.c.version).where(schemas_table.c.id == schema_short_id) + result = await self.session.execute(stmt) + versions = [row[0] for row in result.all()] + if not versions: + return None + # Pick the highest SemVer (string sort is wrong for e.g. 1.10.0 vs 1.9.0). 
+ latest = max(versions, key=lambda v: tuple(int(p) for p in v.split("-")[0].split("."))) + return SchemaId.parse(f"{schema_short_id}@{latest}") + + async def _records_count(self, schema_id: SchemaId) -> int: + t = records_table + stmt = ( + select(func.count()) + .select_from(t) + .where( + t.c.schema_id == schema_id.id.root, + t.c.schema_version == schema_id.version.root, + ) + ) + return int((await self.session.execute(stmt)).scalar_one()) + + +def _feature_column_specs(fschema: FeatureSchema) -> list[ColumnSpec]: + """Map a feature table's declared columns to manifest ColumnSpecs.""" + return [ + ColumnSpec(name=c.name, type=_JSON_TYPE_TO_FIELD_TYPE[c.json_type]) for c in fschema.columns + ] diff --git a/server/osa/infrastructure/data/postgres_data_read_store.py b/server/osa/infrastructure/data/postgres_table_read_store.py similarity index 63% rename from server/osa/infrastructure/data/postgres_data_read_store.py rename to server/osa/infrastructure/data/postgres_table_read_store.py index a336e26..aa177de 100644 --- a/server/osa/infrastructure/data/postgres_data_read_store.py +++ b/server/osa/infrastructure/data/postgres_table_read_store.py @@ -1,15 +1,14 @@ -"""Postgres adapter for the ``DataReadStore`` port. +"""Postgres adapter for the ``DataTableReadStore`` port (streaming reads). The streaming primitive (:meth:`stream_rows`) builds the records / feature SELECT, then iterates it through ``AsyncSession.stream()`` — a server-side cursor (research §2). Rows are yielded one at a time as flattened column→value -mappings, so memory stays bounded regardless of result size. The ``async with`` +mappings, so memory stays bounded regardless of result size. The try/finally around the streaming result closes the cursor on client disconnect, returning the connection to the pool. 
-The records filter compilation reuses the proven approach from the discovery -adapter, ported onto the ``data`` domain's own filter model so the adapter has -no dependency on ``discovery`` once that domain is deleted (research §10). +Catalog/manifest/record-by-id reads live in +:class:`~osa.infrastructure.data.postgres_catalog_read_store.PostgresCatalogReadStore`. """ from __future__ import annotations @@ -20,24 +19,9 @@ from typing import Any import sqlalchemy as sa -from sqlalchemy import ( - RowMapping, - String, - and_, - cast, - false, - func, - not_, - or_, - select, -) +from sqlalchemy import RowMapping, String, and_, cast, false, func, not_, or_, select from sqlalchemy.ext.asyncio import AsyncSession -from osa.domain.data.model.catalog import ( - CatalogEntry, - NodeCatalog, - TableResourceSummary, -) from osa.domain.data.model.filter import ( And, FeatureFieldRef, @@ -48,14 +32,6 @@ Or, Predicate, ) -from osa.domain.data.model.manifest import ( - IMPLICIT_FEATURE_COLUMN_SPECS, - IMPLICIT_RECORD_COLUMN_SPECS, - ColumnSpec, - FieldSpec, - SchemaManifest, - TableResource, -) from osa.domain.data.model.query_plan import ( QueryPlan, SortDirection, @@ -63,10 +39,13 @@ decode_cursor, ) from osa.domain.data.model.record_summary import RecordSummary -from osa.domain.semantics.model.value import FieldType from osa.domain.shared.error import NotFoundError, ValidationError from osa.domain.shared.model.ids import RecordId -from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId +from osa.domain.shared.model.srn import RecordSRN, SchemaId +from osa.infrastructure.data.schema_features import ( + feature_schema_scope, + schema_feature_tables, +) from osa.infrastructure.persistence.feature_table import ( FeatureSchema, build_feature_table, @@ -78,28 +57,12 @@ build_metadata_table, ) from osa.infrastructure.persistence.tables import ( - conventions_table, - feature_tables_table, metadata_tables_table, records_table, - schemas_table, ) logger = 
logging.getLogger(__name__) -# A feature column's JSON-primitive type → the manifest's semantic FieldType. -_JSON_TYPE_TO_FIELD_TYPE: dict[str, FieldType] = { - "string": FieldType.TEXT, - "number": FieldType.NUMBER, - "integer": FieldType.NUMBER, - "boolean": FieldType.BOOLEAN, - "array": FieldType.TEXT, - "object": FieldType.TEXT, -} - -# All URL-exposed format suffixes (mirrors the route-layer FORMATS registry). -_ALL_FORMATS = ["", "csv", "csv.gz"] - def _escape_like(value: str) -> str: return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") @@ -128,12 +91,9 @@ def _invalid_cursor(exc: Exception) -> ValidationError: ) -class PostgresDataReadStore: - def __init__(self, session: AsyncSession, node_domain: Domain) -> None: +class PostgresTableReadStore: + def __init__(self, session: AsyncSession) -> None: self.session = session - # Only the node's DNS domain is needed (to render SRNs in the catalog / - # manifest) — not the whole Config. - self.node_domain = node_domain # ------------------------------------------------------------------ # # Streaming primitive @@ -306,7 +266,7 @@ async def _stream_features(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, # A features. table is shared by every convention that registers # the hook name, across schemas — scope to the requested schema's # records via the records join. - conditions.extend(self._feature_schema_scope(plan.schema_id)) + conditions.extend(feature_schema_scope(plan.schema_id)) # Implicit columns (id, record_srn, created_at) precede the hook's # declared data columns — this is the CSV header order. @@ -334,7 +294,7 @@ async def _resolve_feature_table( it. Streaming a feature not registered on the schema is a 404, not a leak of another schema's table. 
""" - for hook_name, fschema in await self._schema_feature_tables(schema_id): + for hook_name, fschema in await schema_feature_tables(self.session, schema_id): if hook_name == feature_name: return build_feature_table(hook_name, fschema), fschema raise NotFoundError( @@ -343,48 +303,6 @@ async def _resolve_feature_table( code="feature_not_found", ) - async def _schema_hook_names(self, schema_id: SchemaId) -> set[str]: - """Hook names registered on the schema's conventions (the schema→feature link).""" - stmt = select(conventions_table.c.hooks).where( - conventions_table.c.schema_id == schema_id.id.root, - conventions_table.c.schema_version == schema_id.version.root, - ) - result = await self.session.execute(stmt) - names: set[str] = set() - for (hooks,) in result.all(): - for hook in hooks or []: - names.add(hook["name"]) - return names - - async def _schema_feature_tables(self, schema_id: SchemaId) -> list[tuple[str, FeatureSchema]]: - """(hook_name, FeatureSchema) for every materialized feature table on the schema.""" - hook_names = await self._schema_hook_names(schema_id) - if not hook_names: - return [] - stmt = select( - feature_tables_table.c.hook_name, feature_tables_table.c.feature_schema - ).where(feature_tables_table.c.hook_name.in_(hook_names)) - result = await self.session.execute(stmt) - return [ - (row["hook_name"], FeatureSchema.model_validate(row["feature_schema"])) - for row in result.mappings() - ] - - @staticmethod - def _feature_schema_scope(schema_id: SchemaId) -> list[Any]: - return [ - records_table.c.schema_id == schema_id.id.root, - records_table.c.schema_version == schema_id.version.root, - ] - - async def _feature_count(self, ft: sa.Table, schema_id: SchemaId) -> int: - stmt = ( - select(func.count()) - .select_from(ft.join(records_table, records_table.c.srn == ft.c.record_srn)) - .where(and_(*self._feature_schema_scope(schema_id))) - ) - return int((await self.session.execute(stmt)).scalar_one()) - def _features_sort(self, plan: 
QueryPlan, ft: sa.Table) -> tuple[list[Any], Any | None]: # plan.keyset owns tiebreak choice and sort=id aliasing; this method # only maps its column names onto SQL expressions. @@ -458,156 +376,7 @@ def _compile_feature_predicate( raise TypeError(f"Unexpected field ref type: {type(predicate.field).__name__}") # ------------------------------------------------------------------ # - # Single record by ID - # ------------------------------------------------------------------ # - - async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary | None: - # The records PK is the SRN ``urn:osa:{domain}:rec:{id}@{version}``. - # Match the id segment; resolve version (pin or latest published). - pattern = f"urn:osa:%:rec:{_escape_like(str(id))}@%" - t = records_table - stmt = ( - select(t.c.srn, t.c.schema_id, t.c.schema_version, t.c.published_at, t.c.metadata) - .where(t.c.srn.like(pattern, escape="\\")) - .order_by(t.c.published_at.desc()) - ) - result = await self.session.execute(stmt) - rows = result.mappings().all() - if not rows: - return None - - chosen = None - for row in rows: - srn = RecordSRN.parse(row["srn"]) - if srn.id.root != str(id): - continue - if version is None: - chosen = (srn, row) - break - if int(srn.version.root) == version: - chosen = (srn, row) - break - if chosen is None: - return None - srn, row = chosen - return RecordSummary( - id=RecordId(srn.id.root), - srn=srn, - schema_id=SchemaId.parse(f"{row['schema_id']}@{row['schema_version']}"), - version=int(srn.version.root), - metadata=row["metadata"] or {}, - created_at=row["published_at"], - ) - - # ------------------------------------------------------------------ # - # Catalog & manifest - # ------------------------------------------------------------------ # - - async def get_node_catalog(self) -> NodeCatalog: - stmt = select(schemas_table.c.id, schemas_table.c.version) - result = await self.session.execute(stmt) - schema_rows = [(row["id"], row["version"]) for row in 
result.mappings()] - entries: list[CatalogEntry] = [] - for short_id, version in schema_rows: - schema_id = SchemaId.parse(f"{short_id}@{version}") - resources = [TableResourceSummary(name="records", kind=TableKind.RECORDS)] - for hook_name, _ in await self._schema_feature_tables(schema_id): - resources.append(TableResourceSummary(name=hook_name, kind=TableKind.FEATURE)) - entries.append( - CatalogEntry( - id=short_id, - version=version, - srn=schema_id.to_srn(self.node_domain).render(), - table_resources=resources, - ) - ) - return NodeCatalog(node_domain=self.node_domain.root, schemas=entries) - - async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest | None: - stmt = select(schemas_table.c.fields).where( - schemas_table.c.id == schema_id.id.root, - schemas_table.c.version == schema_id.version.root, - ) - result = await self.session.execute(stmt) - row = result.mappings().first() - if row is None: - return None - - field_specs: list[FieldSpec] = [] - column_specs: list[ColumnSpec] = [] - for f in row["fields"]: - ftype = FieldType(f["type"]) - field_specs.append( - FieldSpec( - name=f["name"], - type=ftype, - ontology_id=f.get("ontology_id"), - ontology_version=f.get("ontology_version"), - ) - ) - column_specs.append(ColumnSpec(name=f["name"], type=ftype)) - - record_count = await self._records_count(schema_id) - records_resource = TableResource( - name="records", - kind=TableKind.RECORDS, - # Implicit columns (id, srn, schema_id, version, created_at) precede - # the schema's declared metadata fields — this is the CSV header order. 
- columns=[*IMPLICIT_RECORD_COLUMN_SPECS, *column_specs], - row_count=record_count, - formats=list(_ALL_FORMATS), - ) - feature_resources = await self._feature_resources(schema_id) - return SchemaManifest( - id=schema_id.id.root, - version=schema_id.version.root, - srn=schema_id.to_srn(self.node_domain).render(), - fields=field_specs, - table_resources=[records_resource, *feature_resources], - ) - - async def _feature_resources(self, schema_id: SchemaId) -> list[TableResource]: - """Build a TableResource for each feature table registered on the schema.""" - resources: list[TableResource] = [] - for hook_name, fschema in await self._schema_feature_tables(schema_id): - ft = build_feature_table(hook_name, fschema) - resources.append( - TableResource( - name=hook_name, - kind=TableKind.FEATURE, - # Implicit columns (id, record_srn, created_at) precede the - # hook's declared data columns — this is the CSV header order. - columns=[*IMPLICIT_FEATURE_COLUMN_SPECS, *_feature_column_specs(fschema)], - row_count=await self._feature_count(ft, schema_id), - formats=list(_ALL_FORMATS), - ) - ) - return resources - - async def get_latest_schema_id(self, schema_short_id: str) -> SchemaId | None: - stmt = select(schemas_table.c.version).where(schemas_table.c.id == schema_short_id) - result = await self.session.execute(stmt) - versions = [row[0] for row in result.all()] - if not versions: - return None - # Pick the highest SemVer (string sort is wrong for e.g. 1.10.0 vs 1.9.0). 
- latest = max(versions, key=lambda v: tuple(int(p) for p in v.split("-")[0].split("."))) - return SchemaId.parse(f"{schema_short_id}@{latest}") - - async def _records_count(self, schema_id: SchemaId) -> int: - t = records_table - stmt = ( - select(func.count()) - .select_from(t) - .where( - t.c.schema_id == schema_id.id.root, - t.c.schema_version == schema_id.version.root, - ) - ) - return int((await self.session.execute(stmt)).scalar_one()) - - # ------------------------------------------------------------------ # - # Helpers (ported from the discovery adapter, records-only) + # Records filter compilation # ------------------------------------------------------------------ # async def _metadata_catalog_for(self, schema_id: SchemaId) -> dict[str, Any] | None: @@ -653,13 +422,6 @@ def _compile_predicate(self, predicate: Predicate, *, metadata_t: Any) -> Any: raise TypeError(f"Unexpected field ref type: {type(predicate.field).__name__}") -def _feature_column_specs(fschema: FeatureSchema) -> list[ColumnSpec]: - """Map a feature table's declared columns to manifest ColumnSpecs.""" - return [ - ColumnSpec(name=c.name, type=_JSON_TYPE_TO_FIELD_TYPE[c.json_type]) for c in fschema.columns - ] - - def _apply_scalar_op(col: Any, op: FilterOperator, value: Any) -> Any: if op == FilterOperator.EQ: return col == value diff --git a/server/osa/infrastructure/data/schema_features.py b/server/osa/infrastructure/data/schema_features.py new file mode 100644 index 0000000..586402e --- /dev/null +++ b/server/osa/infrastructure/data/schema_features.py @@ -0,0 +1,71 @@ +"""Shared schema→feature-table resolution used by both /data/ read adapters. + +A ``features.`` table is global (UNIQUE(hook_name)) and shared by every +convention that registers the hook name, across schemas. These helpers answer +"which feature tables belong to this schema" (via its conventions) and provide +the records-join condition that scopes feature rows to one schema's records. 
+""" + +from __future__ import annotations + +from typing import Any + +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.shared.model.srn import SchemaId +from osa.infrastructure.persistence.feature_table import FeatureSchema +from osa.infrastructure.persistence.tables import ( + conventions_table, + feature_tables_table, + records_table, +) + + +async def schema_hook_names(session: AsyncSession, schema_id: SchemaId) -> set[str]: + """Hook names registered on the schema's conventions (the schema→feature link).""" + stmt = select(conventions_table.c.hooks).where( + conventions_table.c.schema_id == schema_id.id.root, + conventions_table.c.schema_version == schema_id.version.root, + ) + result = await session.execute(stmt) + names: set[str] = set() + for (hooks,) in result.all(): + for hook in hooks or []: + names.add(hook["name"]) + return names + + +async def schema_feature_tables( + session: AsyncSession, schema_id: SchemaId +) -> list[tuple[str, FeatureSchema]]: + """(hook_name, FeatureSchema) for every materialized feature table on the schema.""" + hook_names = await schema_hook_names(session, schema_id) + if not hook_names: + return [] + stmt = select(feature_tables_table.c.hook_name, feature_tables_table.c.feature_schema).where( + feature_tables_table.c.hook_name.in_(hook_names) + ) + result = await session.execute(stmt) + return [ + (row["hook_name"], FeatureSchema.model_validate(row["feature_schema"])) + for row in result.mappings() + ] + + +def feature_schema_scope(schema_id: SchemaId) -> list[Any]: + """Records-join conditions scoping shared feature tables to one schema.""" + return [ + records_table.c.schema_id == schema_id.id.root, + records_table.c.schema_version == schema_id.version.root, + ] + + +def feature_count_stmt(ft: sa.Table, schema_id: SchemaId) -> sa.Select[Any]: + """COUNT of a feature table's rows scoped to the schema's records.""" + return ( + select(sa.func.count()) 
+ .select_from(ft.join(records_table, records_table.c.srn == ft.c.record_srn)) + .where(sa.and_(*feature_schema_scope(schema_id))) + ) diff --git a/server/osa/infrastructure/persistence/di.py b/server/osa/infrastructure/persistence/di.py index 60fdc19..1c8eacd 100644 --- a/server/osa/infrastructure/persistence/di.py +++ b/server/osa/infrastructure/persistence/di.py @@ -28,8 +28,12 @@ from osa.domain.shared.port.event_repository import EventRepository from osa.domain.feature.port.feature_store import FeatureStore from osa.domain.validation.port.repository import ValidationRunRepository -from osa.domain.data.port.data_read_store import DataReadStore -from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore +from osa.domain.data.port.data_read_store import ( + DataCatalogReadStore, + DataTableReadStore, +) +from osa.infrastructure.data.postgres_catalog_read_store import PostgresCatalogReadStore +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore from osa.infrastructure.persistence.adapter.readers import ( OntologyReaderAdapter, SchemaReaderAdapter, @@ -171,9 +175,15 @@ def get_record_service( ) # Data read surface adapter (unified /data/* engine) - @provide(scope=Scope.UOW, provides=DataReadStore) - def get_data_read_store(self, session: AsyncSession, config: Config) -> PostgresDataReadStore: - return PostgresDataReadStore(session=session, node_domain=Domain(config.domain)) + @provide(scope=Scope.UOW, provides=DataTableReadStore) + def get_data_table_read_store(self, session: AsyncSession) -> PostgresTableReadStore: + return PostgresTableReadStore(session=session) + + @provide(scope=Scope.UOW, provides=DataCatalogReadStore) + def get_data_catalog_read_store( + self, session: AsyncSession, config: Config + ) -> PostgresCatalogReadStore: + return PostgresCatalogReadStore(session=session, node_domain=Domain(config.domain)) # Record query handlers get_record_handler = provide(GetRecordHandler, scope=Scope.UOW) 
diff --git a/server/tests/integration/test_data_features_postgres.py b/server/tests/integration/test_data_features_postgres.py index 8cca997..cc01784 100644 --- a/server/tests/integration/test_data_features_postgres.py +++ b/server/tests/integration/test_data_features_postgres.py @@ -1,4 +1,4 @@ -"""Integration tests for PostgresDataReadStore feature-table streaming (US5). +"""Integration tests for PostgresTableReadStore feature-table streaming (US5). Exercises the FEATURE branch of the unified ``/data/`` engine against a real Postgres: a schema with a registered hook produces a ``features.`` table; @@ -29,7 +29,8 @@ from osa.domain.shared.error import ConflictError, NotFoundError from osa.domain.shared.model.hook import ColumnDef from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId -from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore +from osa.infrastructure.data.postgres_catalog_read_store import PostgresCatalogReadStore +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore from osa.infrastructure.persistence.feature_store import PostgresFeatureStore from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore from osa.infrastructure.persistence.repository.schema import ( @@ -133,7 +134,7 @@ def _feature_plan(filter_expr=None, limit=50) -> QueryPlan: ) -async def _drain(store: PostgresDataReadStore, plan: QueryPlan) -> list[dict]: +async def _drain(store: PostgresTableReadStore, plan: QueryPlan) -> list[dict]: return [dict(row) async for row in store.stream_rows(plan)] @@ -151,7 +152,7 @@ async def test_streams_feature_rows_with_data_columns( HOOK, str(srn), [{"score": 0.9, "label": "high"}, {"score": 0.1, "label": "low"}] ) - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) rows = await _drain(rs, _feature_plan()) assert len(rows) == 2 @@ -174,7 +175,7 @@ async def test_feature_filter_narrows_results( HOOK, 
str(srn), [{"score": 0.9, "label": "high"}, {"score": 0.1, "label": "low"}] ) - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) plan = _feature_plan( filter_expr=Predicate( field=FeatureFieldRef(hook=HOOK, column="score"), @@ -190,7 +191,7 @@ async def test_unknown_feature_raises_not_found( ): await _setup_schema(pg_engine, pg_session) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) plan = QueryPlan( schema_id=SCHEMA, table_kind=TableKind.FEATURE, @@ -220,7 +221,7 @@ async def test_created_at_sort_cursor_round_trips( feature_store = PostgresFeatureStore(pg_engine, pg_session) await feature_store.insert_features(HOOK, str(srn), [{"score": s} for s in (0.1, 0.2, 0.3)]) - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) sort = [SortSpec(column="created_at", direction=SortDirection.ASC)] all_rows = await _drain( rs, @@ -271,7 +272,7 @@ async def test_stream_excludes_rows_of_other_schemas( ): srn_a, _ = await self._seed_shared_hook(pg_engine, pg_session) - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) rows = await _drain(rs, _feature_plan()) assert [r["record_srn"] for r in rows] == [str(srn_a)] @@ -281,7 +282,7 @@ async def test_manifest_row_count_scoped_to_schema( ): await self._seed_shared_hook(pg_engine, pg_session) - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) manifest = await rs.get_schema_manifest(SCHEMA) assert manifest is not None feature_res = next(t for t in manifest.table_resources if t.name == HOOK) @@ -300,7 +301,7 @@ async def test_manifest_includes_feature_resource( feature_store = PostgresFeatureStore(pg_engine, pg_session) await feature_store.insert_features(HOOK, str(srn), [{"score": 0.9, "label": "high"}]) - rs = 
PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) manifest = await rs.get_schema_manifest(SCHEMA) assert manifest is not None feature_res = next(t for t in manifest.table_resources if t.name == HOOK) @@ -318,7 +319,7 @@ async def test_catalog_lists_feature_resource( await pg_session.commit() await _register_hook(pg_engine, pg_session) - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) catalog = await rs.get_node_catalog() entry = next(e for e in catalog.schemas if e.id == "compound") names = {(t.name, t.kind) for t in entry.table_resources} diff --git a/server/tests/integration/test_data_read_store_postgres.py b/server/tests/integration/test_data_read_store_postgres.py index ebb1405..24bc706 100644 --- a/server/tests/integration/test_data_read_store_postgres.py +++ b/server/tests/integration/test_data_read_store_postgres.py @@ -1,4 +1,4 @@ -"""Integration tests for PostgresDataReadStore against a real Postgres. +"""Integration tests for PostgresTableReadStore against a real Postgres. 
Exercises the unified /data/ engine end-to-end: records streaming via the server-side cursor, filtered streaming, single-record-by-id, node catalog, and @@ -28,7 +28,8 @@ from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType from osa.domain.shared.model.ids import RecordId from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId -from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore +from osa.infrastructure.data.postgres_catalog_read_store import PostgresCatalogReadStore +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore from osa.infrastructure.persistence.metadata_store import PostgresMetadataStore from osa.infrastructure.persistence.repository.schema import ( PostgresSemanticsSchemaRepository, @@ -95,7 +96,7 @@ def _records_plan(filter_expr=None, limit=50, cursor=None) -> QueryPlan: ) -async def _drain(store: PostgresDataReadStore, plan: QueryPlan) -> list[dict]: +async def _drain(store: PostgresTableReadStore, plan: QueryPlan) -> list[dict]: return [dict(row) async for row in store.stream_rows(plan)] @@ -113,7 +114,7 @@ async def test_streams_all_rows_flattened( ) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) rows = await _drain(rs, _records_plan()) assert len(rows) == 2 @@ -139,7 +140,7 @@ async def test_metadata_filter_narrows_results( ) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) plan = _records_plan( filter_expr=Predicate( field=MetadataFieldRef(field="species"), @@ -155,7 +156,7 @@ async def test_empty_result_streams_nothing( ): await _setup_schema(pg_engine, pg_session) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) assert await _drain(rs, _records_plan()) == [] @@ -170,7 +171,7 @@ async def 
test_resolves_by_bare_id_with_srn( ) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) rec = await rs.get_record_by_id(RecordId("abc"), None) assert rec is not None assert str(rec.id) == "abc" @@ -182,7 +183,7 @@ async def test_resolves_by_bare_id_with_srn( async def test_unknown_id_returns_none(self, pg_engine: AsyncEngine, pg_session: AsyncSession): await _setup_schema(pg_engine, pg_session) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) assert await rs.get_record_by_id(RecordId("missing"), None) is None @@ -193,7 +194,7 @@ async def test_node_catalog_lists_schema( ): await _setup_schema(pg_engine, pg_session) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) catalog = await rs.get_node_catalog() assert catalog.node_domain == "localhost" ids = {(e.id, e.version) for e in catalog.schemas} @@ -206,7 +207,7 @@ async def test_manifest_shape(self, pg_engine: AsyncEngine, pg_session: AsyncSes ) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) manifest = await rs.get_schema_manifest(SCHEMA) assert manifest is not None assert manifest.id == "compound" @@ -224,7 +225,7 @@ async def test_manifest_shape(self, pg_engine: AsyncEngine, pg_session: AsyncSes async def test_get_latest_schema_id(self, pg_engine: AsyncEngine, pg_session: AsyncSession): await _setup_schema(pg_engine, pg_session) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresCatalogReadStore(pg_session, Domain("localhost")) resolved = await rs.get_latest_schema_id("compound") assert resolved is not None assert resolved.render() == "compound@1.0.0" @@ 
-247,7 +248,7 @@ async def test_cursor_advances_without_overlap( datetime(2026, 1, 1 + i, tzinfo=UTC), ) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) # Sort by created_at desc (default). Page 1: take 2, derive cursor from row 2. all_rows = await _drain(rs, _records_plan()) @@ -281,7 +282,7 @@ async def test_id_sort_cursor_round_trips( datetime(2026, 1, 1, tzinfo=UTC), ) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) sort = [SortSpec(column="id", direction=SortDirection.ASC)] all_rows = await _drain( @@ -342,7 +343,7 @@ async def test_date_sort_cursor_round_trips( await store.insert(ASSAY_SCHEMA, srn, {"assay_date": f"2026-01-0{i + 1}"}) await pg_session.commit() - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) sort = [SortSpec(column="assay_date", direction=SortDirection.ASC)] all_rows = await _drain( rs, QueryPlan(schema_id=ASSAY_SCHEMA, table_kind=TableKind.RECORDS, sort=sort) diff --git a/server/tests/integration/test_data_streaming_guarantees_postgres.py b/server/tests/integration/test_data_streaming_guarantees_postgres.py index 9b48e9e..5240562 100644 --- a/server/tests/integration/test_data_streaming_guarantees_postgres.py +++ b/server/tests/integration/test_data_streaming_guarantees_postgres.py @@ -31,8 +31,8 @@ from osa.domain.data.model.query_plan import PaginationParams, QueryPlan, TableKind from osa.domain.semantics.model.schema import Schema from osa.domain.semantics.model.value import Cardinality, FieldDefinition, FieldType -from osa.domain.shared.model.srn import Domain, SchemaId -from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore +from osa.domain.shared.model.srn import SchemaId +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore from 
osa.infrastructure.persistence.metadata_store import PostgresMetadataStore from osa.infrastructure.persistence.repository.schema import ( PostgresSemanticsSchemaRepository, @@ -117,7 +117,7 @@ async def test_large_dump_does_not_materialize_in_python_heap( n = 20_000 await _setup_schema(pg_engine, pg_session) await _bulk_seed(pg_engine, n) - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) tracemalloc.start() count = 0 @@ -137,7 +137,7 @@ async def test_early_generator_close_releases_server_cursor( ): await _setup_schema(pg_engine, pg_session) await _bulk_seed(pg_engine, 500) - rs = PostgresDataReadStore(pg_session, Domain("localhost")) + rs = PostgresTableReadStore(pg_session) gen = rs.stream_rows(_records_plan()) first = await gen.__anext__() # opens the server-side cursor diff --git a/server/tests/unit/infrastructure/data/test_cursor_validation.py b/server/tests/unit/infrastructure/data/test_cursor_validation.py index b8eb655..5f2f2cd 100644 --- a/server/tests/unit/infrastructure/data/test_cursor_validation.py +++ b/server/tests/unit/infrastructure/data/test_cursor_validation.py @@ -27,8 +27,8 @@ ) from osa.domain.shared.error import ValidationError from osa.domain.shared.model.hook import ColumnDef -from osa.domain.shared.model.srn import Domain, SchemaId -from osa.infrastructure.data.postgres_data_read_store import PostgresDataReadStore +from osa.domain.shared.model.srn import SchemaId +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore from osa.infrastructure.persistence.feature_table import ( FeatureSchema, build_feature_table, @@ -38,10 +38,10 @@ SCHEMA = SchemaId.parse("compound@1.0.0") -def _store() -> PostgresDataReadStore: +def _store() -> PostgresTableReadStore: # The sort helpers never touch self.session; decoding happens before any DB # access, so a placeholder session is sufficient for these unit tests. 
- return PostgresDataReadStore(None, Domain("localhost")) # type: ignore[arg-type] + return PostgresTableReadStore(None) # type: ignore[arg-type] def _records_plan(cursor: str) -> QueryPlan: @@ -84,13 +84,13 @@ def test_date_metadata_cursor_value_coerced_to_date(self) -> None: # A FieldType.DATE metadata column is a sa.Date in the dynamic table; # the cursor carries its value as an ISO string — asyncpg rejects a # str bound against DATE, so the decoder must coerce by column type. - value = PostgresDataReadStore._coerce_cursor_value( + value = PostgresTableReadStore._coerce_cursor_value( "2026-01-02", sa.Column("assay_date", sa.Date()) ) assert isinstance(value, date) def test_datetime_metadata_cursor_value_coerced_to_datetime(self) -> None: - value = PostgresDataReadStore._coerce_cursor_value( + value = PostgresTableReadStore._coerce_cursor_value( "2026-01-02T03:04:05+00:00", sa.Column("measured_at", sa.DateTime(timezone=True)) ) assert isinstance(value, datetime) @@ -115,7 +115,7 @@ def test_created_at_cursor_value_coerced_to_datetime(self) -> None: # created_at is an implicit feature column; the type-driven coercion # must turn its ISO string back into a datetime before binding. ft = _feature_table() - value = PostgresDataReadStore._coerce_cursor_value( + value = PostgresTableReadStore._coerce_cursor_value( "2026-01-02T03:04:05+00:00", ft.c.created_at ) assert isinstance(value, datetime) @@ -124,4 +124,4 @@ def test_integer_id_sort_value_coerced_to_int(self) -> None: # Sorting by id, the cursor sort value binds against BIGINT — a str # is coerced (and garbage raises ValueError → 400 via the sort path). 
ft = _feature_table() - assert PostgresDataReadStore._coerce_cursor_value("7", ft.c.id) == 7 + assert PostgresTableReadStore._coerce_cursor_value("7", ft.c.id) == 7 diff --git a/server/tests/unit/infrastructure/data/test_statement_timeout.py b/server/tests/unit/infrastructure/data/test_statement_timeout.py index 961fba6..c458f4f 100644 --- a/server/tests/unit/infrastructure/data/test_statement_timeout.py +++ b/server/tests/unit/infrastructure/data/test_statement_timeout.py @@ -2,7 +2,7 @@ from datetime import timedelta -from osa.infrastructure.data.postgres_data_read_store import statement_timeout_sql +from osa.infrastructure.data.postgres_table_read_store import statement_timeout_sql def test_statement_timeout_sql_renders_integer_milliseconds() -> None: From a9e5e50a3bc046d1eb27a08b9c92b3d1fe6a15be Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 14:13:01 +0100 Subject: [PATCH 22/23] test: use plain strings for HookName in resolve_table tests (HookName is Annotated[str], not callable) --- .../unit/domain/data/test_data_catalog_service.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/server/tests/unit/domain/data/test_data_catalog_service.py b/server/tests/unit/domain/data/test_data_catalog_service.py index 39f63e8..65dcfc8 100644 --- a/server/tests/unit/domain/data/test_data_catalog_service.py +++ b/server/tests/unit/domain/data/test_data_catalog_service.py @@ -22,7 +22,7 @@ from osa.domain.data.service.data_catalog import DataCatalogService from osa.domain.semantics.model.value import FieldType from osa.domain.shared.error import NotFoundError -from osa.domain.shared.model.ids import HookName, RecordId +from osa.domain.shared.model.ids import RecordId from osa.domain.shared.model.srn import SchemaId RECORDS_COLUMNS = [ @@ -95,7 +95,7 @@ async def test_resolve_table_records() -> None: @pytest.mark.asyncio async def test_resolve_table_feature() -> None: resolved = await _service().resolve_table( - "compound@1.0.0", 
TableKind.FEATURE, feature_name=HookName("chem_features") + "compound@1.0.0", TableKind.FEATURE, feature_name="chem_features" ) assert resolved.columns == FEATURE_COLUMNS @@ -103,9 +103,7 @@ async def test_resolve_table_feature() -> None: @pytest.mark.asyncio async def test_resolve_table_unknown_feature_404s() -> None: with pytest.raises(NotFoundError) as exc: - await _service().resolve_table( - "compound@1.0.0", TableKind.FEATURE, feature_name=HookName("nope") - ) + await _service().resolve_table("compound@1.0.0", TableKind.FEATURE, feature_name="nope") assert exc.value.code == "table_not_found" @@ -114,6 +112,4 @@ async def test_resolve_table_kind_must_match() -> None: # A resource named "records" exists, but with kind RECORDS — asking for a # FEATURE of that name must not match it. with pytest.raises(NotFoundError): - await _service().resolve_table( - "compound@1.0.0", TableKind.FEATURE, feature_name=HookName("records") - ) + await _service().resolve_table("compound@1.0.0", TableKind.FEATURE, feature_name="records") From 33f8c77d3f2eb518f696d911004c9c9610df4b3f Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 11 Jun 2026 14:20:48 +0100 Subject: [PATCH 23/23] refactor: replace data-adapter module functions with a SchemaFeatureReader and class methods No module-level functions in the /data/ adapters. The shared schema->feature resolution (schema_features.py's four loose functions, two doing session I/O) becomes a SchemaFeatureReader composed by both read stores. Per-file helpers (statement_timeout_sql, _escape_like, _invalid_cursor, _apply_scalar_op, _feature_column_specs) move onto their store classes as static methods. 
--- .../data/postgres_catalog_read_store.py | 39 +++--- .../data/postgres_table_read_store.py | 129 +++++++++--------- .../data/schema_feature_reader.py | 73 ++++++++++ .../infrastructure/data/schema_features.py | 71 ---------- .../data/test_statement_timeout.py | 15 +- 5 files changed, 164 insertions(+), 163 deletions(-) create mode 100644 server/osa/infrastructure/data/schema_feature_reader.py delete mode 100644 server/osa/infrastructure/data/schema_features.py diff --git a/server/osa/infrastructure/data/postgres_catalog_read_store.py b/server/osa/infrastructure/data/postgres_catalog_read_store.py index a396dc6..6ea0ed8 100644 --- a/server/osa/infrastructure/data/postgres_catalog_read_store.py +++ b/server/osa/infrastructure/data/postgres_catalog_read_store.py @@ -31,10 +31,7 @@ from osa.domain.semantics.model.value import FieldType from osa.domain.shared.model.ids import RecordId from osa.domain.shared.model.srn import Domain, RecordSRN, SchemaId -from osa.infrastructure.data.schema_features import ( - feature_count_stmt, - schema_feature_tables, -) +from osa.infrastructure.data.schema_feature_reader import SchemaFeatureReader from osa.infrastructure.persistence.feature_table import ( FeatureSchema, build_feature_table, @@ -57,16 +54,17 @@ _ALL_FORMATS = ["", "csv", "csv.gz"] -def _escape_like(value: str) -> str: - return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - - class PostgresCatalogReadStore: def __init__(self, session: AsyncSession, node_domain: Domain) -> None: self.session = session # Only the node's DNS domain is needed (to render SRNs in the catalog / # manifest) — not the whole Config. 
self.node_domain = node_domain + self._features = SchemaFeatureReader(session) + + @staticmethod + def _escape_like(value: str) -> str: + return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") # ------------------------------------------------------------------ # # Single record by ID @@ -75,7 +73,7 @@ def __init__(self, session: AsyncSession, node_domain: Domain) -> None: async def get_record_by_id(self, id: RecordId, version: int | None) -> RecordSummary | None: # The records PK is the SRN ``urn:osa:{domain}:rec:{id}@{version}``. # Match the id segment; resolve version (pin or latest published). - pattern = f"urn:osa:%:rec:{_escape_like(str(id))}@%" + pattern = f"urn:osa:%:rec:{self._escape_like(str(id))}@%" t = records_table stmt = ( select(t.c.srn, t.c.schema_id, t.c.schema_version, t.c.published_at, t.c.metadata) @@ -122,7 +120,7 @@ async def get_node_catalog(self) -> NodeCatalog: for short_id, version in schema_rows: schema_id = SchemaId.parse(f"{short_id}@{version}") resources = [TableResourceSummary(name="records", kind=TableKind.RECORDS)] - for hook_name, _ in await schema_feature_tables(self.session, schema_id): + for hook_name, _ in await self._features.feature_tables(schema_id): resources.append(TableResourceSummary(name=hook_name, kind=TableKind.FEATURE)) entries.append( CatalogEntry( @@ -180,18 +178,16 @@ async def get_schema_manifest(self, schema_id: SchemaId) -> SchemaManifest | Non async def _feature_resources(self, schema_id: SchemaId) -> list[TableResource]: """Build a TableResource for each feature table registered on the schema.""" resources: list[TableResource] = [] - for hook_name, fschema in await schema_feature_tables(self.session, schema_id): + for hook_name, fschema in await self._features.feature_tables(schema_id): ft = build_feature_table(hook_name, fschema) - count = int( - (await self.session.execute(feature_count_stmt(ft, schema_id))).scalar_one() - ) + count = await self._features.count_rows(ft, schema_id) 
resources.append( TableResource( name=hook_name, kind=TableKind.FEATURE, # Implicit columns (id, record_srn, created_at) precede the # hook's declared data columns — this is the CSV header order. - columns=[*IMPLICIT_FEATURE_COLUMN_SPECS, *_feature_column_specs(fschema)], + columns=[*IMPLICIT_FEATURE_COLUMN_SPECS, *self._feature_column_specs(fschema)], row_count=count, formats=list(_ALL_FORMATS), ) @@ -220,9 +216,10 @@ async def _records_count(self, schema_id: SchemaId) -> int: ) return int((await self.session.execute(stmt)).scalar_one()) - -def _feature_column_specs(fschema: FeatureSchema) -> list[ColumnSpec]: - """Map a feature table's declared columns to manifest ColumnSpecs.""" - return [ - ColumnSpec(name=c.name, type=_JSON_TYPE_TO_FIELD_TYPE[c.json_type]) for c in fschema.columns - ] + @staticmethod + def _feature_column_specs(fschema: FeatureSchema) -> list[ColumnSpec]: + """Map a feature table's declared columns to manifest ColumnSpecs.""" + return [ + ColumnSpec(name=c.name, type=_JSON_TYPE_TO_FIELD_TYPE[c.json_type]) + for c in fschema.columns + ] diff --git a/server/osa/infrastructure/data/postgres_table_read_store.py b/server/osa/infrastructure/data/postgres_table_read_store.py index aa177de..88b7568 100644 --- a/server/osa/infrastructure/data/postgres_table_read_store.py +++ b/server/osa/infrastructure/data/postgres_table_read_store.py @@ -42,10 +42,7 @@ from osa.domain.shared.error import NotFoundError, ValidationError from osa.domain.shared.model.ids import RecordId from osa.domain.shared.model.srn import RecordSRN, SchemaId -from osa.infrastructure.data.schema_features import ( - feature_schema_scope, - schema_feature_tables, -) +from osa.infrastructure.data.schema_feature_reader import SchemaFeatureReader from osa.infrastructure.persistence.feature_table import ( FeatureSchema, build_feature_table, @@ -64,36 +61,38 @@ logger = logging.getLogger(__name__) -def _escape_like(value: str) -> str: - return value.replace("\\", "\\\\").replace("%", 
"\\%").replace("_", "\\_") - - -def statement_timeout_sql(timeout: timedelta) -> str: - """Render the caller's execution budget as a ``SET LOCAL`` statement. - - Integer milliseconds, so no operator- or caller-supplied string ever - reaches raw SQL. ``SET LOCAL`` reverts on commit/rollback, so the value - never leaks to a later request reusing the pooled connection. - """ - return f"SET LOCAL statement_timeout = '{int(timeout.total_seconds() * 1000)}ms'" - +class PostgresTableReadStore: + def __init__(self, session: AsyncSession) -> None: + self.session = session + self._features = SchemaFeatureReader(session) -def _invalid_cursor(exc: Exception) -> ValidationError: - """Map a cursor decode/coerce ``ValueError`` to a 400, not a 500. + @staticmethod + def statement_timeout_sql(timeout: timedelta) -> str: + """Render the caller's execution budget as a ``SET LOCAL`` statement. - Decoding a corrupt cursor (bad base64, missing ``s``/``id`` keys) or coercing - a non-conforming value (e.g. a non-integer feature id) raises ``ValueError``, - which is not an ``OSAError`` and would otherwise surface as a generic 500. - The cursor is opaque client-supplied input, so a malformed one is a 400. - """ - return ValidationError( - f"Malformed pagination cursor: {exc}", field="cursor", code="invalid_cursor" - ) + Integer milliseconds, so no operator- or caller-supplied string ever + reaches raw SQL. ``SET LOCAL`` reverts on commit/rollback, so the value + never leaks to a later request reusing the pooled connection. + """ + return f"SET LOCAL statement_timeout = '{int(timeout.total_seconds() * 1000)}ms'" + @staticmethod + def _escape_like(value: str) -> str: + return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") -class PostgresTableReadStore: - def __init__(self, session: AsyncSession) -> None: - self.session = session + @staticmethod + def _invalid_cursor(exc: Exception) -> ValidationError: + """Map a cursor decode/coerce ``ValueError`` to a 400, not a 500. 
+ + Decoding a corrupt cursor (bad base64, missing ``s``/``id`` keys) or + coercing a non-conforming value (e.g. a non-integer feature id) raises + ``ValueError``, which is not an ``OSAError`` and would otherwise surface + as a generic 500. The cursor is opaque client-supplied input, so a + malformed one is a 400. + """ + return ValidationError( + f"Malformed pagination cursor: {exc}", field="cursor", code="invalid_cursor" + ) # ------------------------------------------------------------------ # # Streaming primitive @@ -105,7 +104,7 @@ async def stream_rows( # The generator body runs on the route's pre-flight ``__anext__`` pull, # so the timeout is in place before the SELECT opens its cursor. if timeout is not None: - await self.session.execute(sa.text(statement_timeout_sql(timeout))) + await self.session.execute(sa.text(self.statement_timeout_sql(timeout))) if plan.table_kind == TableKind.RECORDS: async for row in self._stream_records(plan): yield row @@ -221,7 +220,7 @@ def _cursor_after( ) ) except ValueError as exc: - raise _invalid_cursor(exc) from exc + raise self._invalid_cursor(exc) from exc @staticmethod def _coerce_cursor_value(value: Any, sort_expr: sa.ColumnElement[Any]) -> Any: @@ -266,7 +265,7 @@ async def _stream_features(self, plan: QueryPlan) -> AsyncIterator[Mapping[str, # A features. table is shared by every convention that registers # the hook name, across schemas — scope to the requested schema's # records via the records join. - conditions.extend(feature_schema_scope(plan.schema_id)) + conditions.extend(self._features.records_scope(plan.schema_id)) # Implicit columns (id, record_srn, created_at) precede the hook's # declared data columns — this is the CSV header order. @@ -294,7 +293,7 @@ async def _resolve_feature_table( it. Streaming a feature not registered on the schema is a 404, not a leak of another schema's table. 
""" - for hook_name, fschema in await schema_feature_tables(self.session, schema_id): + for hook_name, fschema in await self._features.feature_tables(schema_id): if hook_name == feature_name: return build_feature_table(hook_name, fschema), fschema raise NotFoundError( @@ -366,7 +365,9 @@ def _compile_feature_predicate( field=predicate.field.dotted(), code="unknown_feature_column", ) - return _apply_scalar_op(ft.c[predicate.field.column], predicate.op, predicate.value) + return self._apply_scalar_op( + ft.c[predicate.field.column], predicate.op, predicate.value + ) if isinstance(predicate.field, MetadataFieldRef): raise ValidationError( "Metadata-field predicates are not supported on a feature stream.", @@ -411,7 +412,7 @@ def _compile_predicate(self, predicate: Predicate, *, metadata_t: Any) -> Any: code="unknown_metadata_field", ) col = metadata_t.c[predicate.field.field] - return _apply_scalar_op(col, predicate.op, predicate.value) + return self._apply_scalar_op(col, predicate.op, predicate.value) if isinstance(predicate.field, FeatureFieldRef): # Feature predicates in the records stream are a US5 extension. 
raise ValidationError( @@ -421,30 +422,34 @@ def _compile_predicate(self, predicate: Predicate, *, metadata_t: Any) -> Any: ) raise TypeError(f"Unexpected field ref type: {type(predicate.field).__name__}") - -def _apply_scalar_op(col: Any, op: FilterOperator, value: Any) -> Any: - if op == FilterOperator.EQ: - return col == value - if op == FilterOperator.NEQ: - return or_(col != value, col.is_(None)) - if op == FilterOperator.GT: - return col > value - if op == FilterOperator.GTE: - return col >= value - if op == FilterOperator.LT: - return col < value - if op == FilterOperator.LTE: - return col <= value - if op == FilterOperator.IN: - if not isinstance(value, list): - raise ValidationError( - "Operator 'in' requires a list value.", field=col.key, code="invalid_value_for_op" + @staticmethod + def _apply_scalar_op(col: Any, op: FilterOperator, value: Any) -> Any: + if op == FilterOperator.EQ: + return col == value + if op == FilterOperator.NEQ: + return or_(col != value, col.is_(None)) + if op == FilterOperator.GT: + return col > value + if op == FilterOperator.GTE: + return col >= value + if op == FilterOperator.LT: + return col < value + if op == FilterOperator.LTE: + return col <= value + if op == FilterOperator.IN: + if not isinstance(value, list): + raise ValidationError( + "Operator 'in' requires a list value.", + field=col.key, + code="invalid_value_for_op", + ) + return col.in_(value) + if op == FilterOperator.CONTAINS: + return cast(col, String).ilike( + f"%{PostgresTableReadStore._escape_like(str(value))}%", escape="\\" ) - return col.in_(value) - if op == FilterOperator.CONTAINS: - return cast(col, String).ilike(f"%{_escape_like(str(value))}%", escape="\\") - if op == FilterOperator.IS_NULL: - return col.is_(None) - raise ValidationError( - f"Unsupported operator: {op}", field="filter", code="unsupported_operator" - ) + if op == FilterOperator.IS_NULL: + return col.is_(None) + raise ValidationError( + f"Unsupported operator: {op}", field="filter", 
code="unsupported_operator" + ) diff --git a/server/osa/infrastructure/data/schema_feature_reader.py b/server/osa/infrastructure/data/schema_feature_reader.py new file mode 100644 index 0000000..5280f65 --- /dev/null +++ b/server/osa/infrastructure/data/schema_feature_reader.py @@ -0,0 +1,73 @@ +"""Reads which feature tables belong to a schema (via its conventions). + +A ``features.`` table is global (UNIQUE(hook_name)) and shared by every +convention that registers the hook name, across schemas. This reader answers +"which feature tables does this schema expose" (through its conventions) and +counts a feature table's rows scoped to one schema's records. Composed by both +``/data/`` read adapters (table streaming + catalog/manifest). +""" + +from __future__ import annotations + +from typing import Any + +import sqlalchemy as sa +from sqlalchemy import and_, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.shared.model.srn import SchemaId +from osa.infrastructure.persistence.feature_table import FeatureSchema +from osa.infrastructure.persistence.tables import ( + conventions_table, + feature_tables_table, + records_table, +) + + +class SchemaFeatureReader: + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def feature_tables(self, schema_id: SchemaId) -> list[tuple[str, FeatureSchema]]: + """(hook_name, FeatureSchema) for every materialized feature table on the schema.""" + hook_names = await self._hook_names(schema_id) + if not hook_names: + return [] + stmt = select( + feature_tables_table.c.hook_name, feature_tables_table.c.feature_schema + ).where(feature_tables_table.c.hook_name.in_(hook_names)) + result = await self.session.execute(stmt) + return [ + (row["hook_name"], FeatureSchema.model_validate(row["feature_schema"])) + for row in result.mappings() + ] + + async def count_rows(self, ft: sa.Table, schema_id: SchemaId) -> int: + """Row count of a feature table scoped to the schema's 
records.""" + stmt = ( + select(func.count()) + .select_from(ft.join(records_table, records_table.c.srn == ft.c.record_srn)) + .where(and_(*self.records_scope(schema_id))) + ) + return int((await self.session.execute(stmt)).scalar_one()) + + @staticmethod + def records_scope(schema_id: SchemaId) -> list[Any]: + """Records-join conditions scoping a shared feature table to one schema.""" + return [ + records_table.c.schema_id == schema_id.id.root, + records_table.c.schema_version == schema_id.version.root, + ] + + async def _hook_names(self, schema_id: SchemaId) -> set[str]: + """Hook names registered on the schema's conventions (the schema→feature link).""" + stmt = select(conventions_table.c.hooks).where( + conventions_table.c.schema_id == schema_id.id.root, + conventions_table.c.schema_version == schema_id.version.root, + ) + result = await self.session.execute(stmt) + names: set[str] = set() + for (hooks,) in result.all(): + for hook in hooks or []: + names.add(hook["name"]) + return names diff --git a/server/osa/infrastructure/data/schema_features.py b/server/osa/infrastructure/data/schema_features.py deleted file mode 100644 index 586402e..0000000 --- a/server/osa/infrastructure/data/schema_features.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Shared schema→feature-table resolution used by both /data/ read adapters. - -A ``features.`` table is global (UNIQUE(hook_name)) and shared by every -convention that registers the hook name, across schemas. These helpers answer -"which feature tables belong to this schema" (via its conventions) and provide -the records-join condition that scopes feature rows to one schema's records. 
-""" - -from __future__ import annotations - -from typing import Any - -import sqlalchemy as sa -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from osa.domain.shared.model.srn import SchemaId -from osa.infrastructure.persistence.feature_table import FeatureSchema -from osa.infrastructure.persistence.tables import ( - conventions_table, - feature_tables_table, - records_table, -) - - -async def schema_hook_names(session: AsyncSession, schema_id: SchemaId) -> set[str]: - """Hook names registered on the schema's conventions (the schema→feature link).""" - stmt = select(conventions_table.c.hooks).where( - conventions_table.c.schema_id == schema_id.id.root, - conventions_table.c.schema_version == schema_id.version.root, - ) - result = await session.execute(stmt) - names: set[str] = set() - for (hooks,) in result.all(): - for hook in hooks or []: - names.add(hook["name"]) - return names - - -async def schema_feature_tables( - session: AsyncSession, schema_id: SchemaId -) -> list[tuple[str, FeatureSchema]]: - """(hook_name, FeatureSchema) for every materialized feature table on the schema.""" - hook_names = await schema_hook_names(session, schema_id) - if not hook_names: - return [] - stmt = select(feature_tables_table.c.hook_name, feature_tables_table.c.feature_schema).where( - feature_tables_table.c.hook_name.in_(hook_names) - ) - result = await session.execute(stmt) - return [ - (row["hook_name"], FeatureSchema.model_validate(row["feature_schema"])) - for row in result.mappings() - ] - - -def feature_schema_scope(schema_id: SchemaId) -> list[Any]: - """Records-join conditions scoping shared feature tables to one schema.""" - return [ - records_table.c.schema_id == schema_id.id.root, - records_table.c.schema_version == schema_id.version.root, - ] - - -def feature_count_stmt(ft: sa.Table, schema_id: SchemaId) -> sa.Select[Any]: - """COUNT of a feature table's rows scoped to the schema's records.""" - return ( - select(sa.func.count()) 
- .select_from(ft.join(records_table, records_table.c.srn == ft.c.record_srn)) - .where(sa.and_(*feature_schema_scope(schema_id))) - ) diff --git a/server/tests/unit/infrastructure/data/test_statement_timeout.py b/server/tests/unit/infrastructure/data/test_statement_timeout.py index c458f4f..de46ab2 100644 --- a/server/tests/unit/infrastructure/data/test_statement_timeout.py +++ b/server/tests/unit/infrastructure/data/test_statement_timeout.py @@ -2,19 +2,16 @@ from datetime import timedelta -from osa.infrastructure.data.postgres_table_read_store import statement_timeout_sql +from osa.infrastructure.data.postgres_table_read_store import PostgresTableReadStore + +_sql = PostgresTableReadStore.statement_timeout_sql def test_statement_timeout_sql_renders_integer_milliseconds() -> None: - assert statement_timeout_sql(timedelta(seconds=30)) == "SET LOCAL statement_timeout = '30000ms'" - assert ( - statement_timeout_sql(timedelta(minutes=30)) == "SET LOCAL statement_timeout = '1800000ms'" - ) + assert _sql(timedelta(seconds=30)) == "SET LOCAL statement_timeout = '30000ms'" + assert _sql(timedelta(minutes=30)) == "SET LOCAL statement_timeout = '1800000ms'" def test_statement_timeout_sql_truncates_subsecond() -> None: # int milliseconds — no operator-supplied string ever reaches raw SQL. - assert ( - statement_timeout_sql(timedelta(milliseconds=1500.7)) - == "SET LOCAL statement_timeout = '1500ms'" - ) + assert _sql(timedelta(milliseconds=1500.7)) == "SET LOCAL statement_timeout = '1500ms'"