From 42ecc5566800387e374aa88a831f7b8286db46e6 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Fri, 26 Jun 2026 11:31:51 +0200 Subject: [PATCH 1/8] prototype using hybrid model for orm --- .pre-commit-config.yaml | 1 + src/database/setup.py | 30 +++++++++++++++++++++++++++ src/main.py | 4 +++- tests/routers/openml/task_tag_test.py | 19 +++++++++++++++++ 4 files changed, 53 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1601d30..4d9aefd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,7 @@ repos: additional_dependencies: - fastapi - pytest + - sqlalchemy - repo: https://github.com/astral-sh/ruff-pre-commit rev: 'v0.15.18' diff --git a/src/database/setup.py b/src/database/setup.py index cfd2306..5d37e79 100644 --- a/src/database/setup.py +++ b/src/database/setup.py @@ -3,10 +3,33 @@ from loguru import logger from sqlalchemy.engine import URL from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlalchemy.ext.declarative import DeferredReflection +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from config import DatabaseConfiguration, get_config +class Base(DeclarativeBase): + pass + + +class ExpDBReflected(DeferredReflection): + __abstract__ = True + + +class UserDBReflected(DeferredReflection): + __abstract__ = True + + +class UserR(UserDBReflected, Base): + __tablename__ = "users" + id: Mapped[int] = mapped_column(primary_key=True) + + +class TaskTag(ExpDBReflected, Base): + __tablename__ = "task_tag" + + def _create_engine(db_config: DatabaseConfiguration) -> AsyncEngine: db_url = URL.create( drivername=db_config.drivername, @@ -35,6 +58,13 @@ def expdb_database() -> AsyncEngine: return _create_engine(get_config().expdb_database) +async def reflect_db_schemas() -> None: + async with user_database().connect() as connection: + await connection.run_sync(UserDBReflected.prepare) # type: ignore[arg-type] # run_sync expects positional-only arg but `prepare` does not have it. + async with expdb_database().connect() as connection: + await connection.run_sync(ExpDBReflected.prepare) # type: ignore[arg-type] # as above. + + async def close_databases() -> None: """Close all database connections.""" for db in (user_database, expdb_database): diff --git a/src/main.py b/src/main.py index b4f7dc8..eb0a495 100644 --- a/src/main.py +++ b/src/main.py @@ -26,7 +26,7 @@ request_response_logger, setup_log_sinks, ) -from database.setup import close_databases +from database.setup import close_databases, reflect_db_schemas from routers.openml.datasets import router as datasets_router from routers.openml.estimation_procedure import router as estimationprocedure_router from routers.openml.evaluations import router as evaluationmeasures_router @@ -45,6 +45,8 @@ async def lifespan( app: FastAPI | None, # noqa: ARG001 # parameter required by FastAPI/Starlette ) -> AsyncIterator[None]: """Manage application lifespan - startup and shutdown events.""" + logger.info("Reflecting database schemas") + await reflect_db_schemas() yield await asyncio.gather( logger.complete(), diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index 087f76e..be392bd 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -2,8 +2,11 @@ from typing import TYPE_CHECKING import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from core.errors import TASK_NOT_FOUND_DURING_TAG, TagAlreadyExistsError, TaskNotFoundError +from database.setup import TaskTag, UserR from database.tasks import get_tags from database.users import User from routers.openml.tasks import tag_task @@ -84,6 +87,22 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection assert e.value.code == task_not_found_in_tag_endpoint +async def test_if_reflection_works( + py_api: httpx.Client, expdb_test: AsyncConnection, user_test: AsyncConnection +) -> None: + _ = py_api + user_stmt = select(UserR).where(UserR.id == 1) + async with AsyncSession(bind=user_test) as session: + user = (await session.scalars(user_stmt)).first() + assert user.username == "emily12" + + tag_stmt = select(TaskTag).where(TaskTag.tag == "OpenML100") + async with AsyncSession(bind=expdb_test) as session: + tags = (await session.scalars(tag_stmt)).all() + tag_count = 100 + assert len(tags) == tag_count + + # -- migration tests -- From 339390a2e4bd3d386f35103738ce59e471738b13 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Tue, 30 Jun 2026 13:23:15 +0200 Subject: [PATCH 2/8] Create the TaskTag ORM class, re-organize the reflection code --- src/database/schema/__init__.py | 1 + src/database/schema/base.py | 47 +++++++++++++++++++++++++++ src/database/schema/tags.py | 30 +++++++++++++++++ src/database/setup.py | 30 ----------------- src/main.py | 3 +- tests/routers/openml/task_tag_test.py | 10 ++---- 6 files changed, 82 insertions(+), 39 deletions(-) create mode 100644 src/database/schema/__init__.py create mode 100644 src/database/schema/base.py create mode 100644 src/database/schema/tags.py diff --git a/src/database/schema/__init__.py b/src/database/schema/__init__.py new file mode 100644 index 0000000..6e6260c --- /dev/null +++ b/src/database/schema/__init__.py @@ -0,0 +1 @@ +"""Defines Object-Relational Mappings (ORM).""" diff --git a/src/database/schema/base.py b/src/database/schema/base.py new file mode 100644 index 0000000..498dfb5 --- /dev/null +++ b/src/database/schema/base.py @@ -0,0 +1,47 @@ +"""Base classes for all ORM classes. + +When defining a new ORM class, use both `Base` and one of the `DeferredReflection` subclasses to +make sure that the class is populated with attributes that may not be defined explicitly. +For example, when creating a new mapping for a table from the `openml_expdb` database, use: + +class ClassName(ExpDBReflected, Base): + __tablename__ = "class_names" + + # any columns you wanted mapped explicitly + ... + +""" + +from sqlalchemy.ext.declarative import DeferredReflection +from sqlalchemy.orm import DeclarativeBase + +from database.setup import expdb_database, user_database + + +class Base(DeclarativeBase): + """Base class for all ORM classes.""" + + +class ExpDBReflected(DeferredReflection): + """Base class for ORM classes to map onto a table in the `openml_expdb` database.""" + + __abstract__ = True + + +class UserDBReflected(DeferredReflection): + """Base class for ORM classes to map onto a table in the `openml` database.""" + + __abstract__ = True + + +async def reflect_db_schemas() -> None: + """Populate defined ORM classes with attributes defined from columns in the database. + + For example, the `dataset` class would automatically get a `collection_date` attribute, + even if it wasn't explicitly declared in the class definition, + because the `openml_expdb.dataset` table has a column `collection_date`. + """ + async with user_database().connect() as connection: + await connection.run_sync(UserDBReflected.prepare) # type: ignore[arg-type] # run_sync expects positional-only arg but `prepare` does not have it. + async with expdb_database().connect() as connection: + await connection.run_sync(ExpDBReflected.prepare) # type: ignore[arg-type] # as above. diff --git a/src/database/schema/tags.py b/src/database/schema/tags.py new file mode 100644 index 0000000..a6b6a78 --- /dev/null +++ b/src/database/schema/tags.py @@ -0,0 +1,30 @@ +"""ORM classes for the *_tag tables (task_tag, ...).""" + +from datetime import datetime + +from sqlalchemy import FetchedValue +from sqlalchemy.orm import Mapped, mapped_column + +from database.schema.base import Base, ExpDBReflected +from routers.types import Identifier, TagString + + +class Tag: + """Base class for all of the *_tag tables.""" + + # The identifier of the entity that is tagged (e.g., dataset id, task id) + entity_id: Mapped[Identifier] = mapped_column("id", primary_key=True) + tag: Mapped[TagString] = mapped_column(primary_key=True) + uploader_id: Mapped[Identifier] = mapped_column("uploader") + creation_date: Mapped[datetime] = mapped_column("date", server_default=FetchedValue()) + + +class TaskTag(ExpDBReflected, Tag, Base): + """Tags belonging to a task.""" + + __tablename__ = "task_tag" + + @property + def task_id(self) -> Identifier: + """Identifier of the task which is tagged by this tag.""" + return self.entity_id diff --git a/src/database/setup.py b/src/database/setup.py index 5d37e79..cfd2306 100644 --- a/src/database/setup.py +++ b/src/database/setup.py @@ -3,33 +3,10 @@ from loguru import logger from sqlalchemy.engine import URL from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from sqlalchemy.ext.declarative import DeferredReflection -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from config import DatabaseConfiguration, get_config -class Base(DeclarativeBase): - pass - - -class ExpDBReflected(DeferredReflection): - __abstract__ = True - - -class UserDBReflected(DeferredReflection): - __abstract__ = True - - -class UserR(UserDBReflected, Base): - __tablename__ = "users" - id: Mapped[int] = mapped_column(primary_key=True) - - -class TaskTag(ExpDBReflected, Base): - __tablename__ = "task_tag" - - def _create_engine(db_config: DatabaseConfiguration) -> AsyncEngine: db_url = URL.create( drivername=db_config.drivername, @@ -58,13 +35,6 @@ def expdb_database() -> AsyncEngine: return _create_engine(get_config().expdb_database) -async def reflect_db_schemas() -> None: - async with user_database().connect() as connection: - await connection.run_sync(UserDBReflected.prepare) # type: ignore[arg-type] # run_sync expects positional-only arg but `prepare` does not have it. - async with expdb_database().connect() as connection: - await connection.run_sync(ExpDBReflected.prepare) # type: ignore[arg-type] # as above. - - async def close_databases() -> None: """Close all database connections.""" for db in (user_database, expdb_database): diff --git a/src/main.py b/src/main.py index eb0a495..6207b15 100644 --- a/src/main.py +++ b/src/main.py @@ -26,7 +26,8 @@ request_response_logger, setup_log_sinks, ) -from database.setup import close_databases, reflect_db_schemas +from database.schema.base import reflect_db_schemas +from database.setup import close_databases from routers.openml.datasets import router as datasets_router from routers.openml.estimation_procedure import router as estimationprocedure_router from routers.openml.evaluations import router as evaluationmeasures_router diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index be392bd..cd58964 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from core.errors import TASK_NOT_FOUND_DURING_TAG, TagAlreadyExistsError, TaskNotFoundError -from database.setup import TaskTag, UserR +from database.schema.tags import TaskTag from database.tasks import get_tags from database.users import User from routers.openml.tasks import tag_task @@ -87,14 +87,8 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection assert e.value.code == task_not_found_in_tag_endpoint -async def test_if_reflection_works( - py_api: httpx.Client, expdb_test: AsyncConnection, user_test: AsyncConnection -) -> None: +async def test_if_reflection_works(py_api: httpx.Client, expdb_test: AsyncConnection) -> None: _ = py_api - user_stmt = select(UserR).where(UserR.id == 1) - async with AsyncSession(bind=user_test) as session: - user = (await session.scalars(user_stmt)).first() - assert user.username == "emily12" tag_stmt = select(TaskTag).where(TaskTag.tag == "OpenML100") async with AsyncSession(bind=expdb_test) as session: From cc00677e3af36161aca0b4316f61f70fb5861101 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Tue, 30 Jun 2026 13:54:48 +0200 Subject: [PATCH 3/8] Fix typing of _database property --- src/database/users.py | 5 ++++- tests/users.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/database/users.py b/src/database/users.py index 381f01d..26a5f3d 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -106,9 +106,9 @@ async def get_user_groups_for( @dataclasses.dataclass class User: user_id: Identifier - _database: AsyncConnection first_name: str = "" last_name: str = "" + _database: AsyncConnection | None = None _groups: list[UserGroup] | None = None def __post_init__(self) -> None: @@ -136,6 +136,9 @@ async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None: return None async def get_groups(self) -> list[UserGroup]: + if self._database is None: + msg = "`get_groups` can only be used when `connection` is provided on instantiation." + raise RuntimeError(msg) if self._groups is None: self._groups = await get_user_groups_for( user_id=self.user_id, diff --git a/tests/users.py b/tests/users.py index c98ffb0..07d3c66 100644 --- a/tests/users.py +++ b/tests/users.py @@ -3,10 +3,10 @@ from database.users import User, UserGroup NO_USER = None -SOME_USER = User(user_id=2, _database=None, _groups=[UserGroup.READ_WRITE]) -OWNER_USER = User(user_id=3229, _database=None, _groups=[UserGroup.READ_WRITE]) -DATASET_130_OWNER = User(user_id=16, _database=None, _groups=[UserGroup.READ_WRITE]) -ADMIN_USER = User(user_id=1159, _database=None, _groups=[UserGroup.ADMIN, UserGroup.READ_WRITE]) +SOME_USER = User(user_id=2, _groups=[UserGroup.READ_WRITE]) +OWNER_USER = User(user_id=3229, _groups=[UserGroup.READ_WRITE]) +DATASET_130_OWNER = User(user_id=16, _groups=[UserGroup.READ_WRITE]) +ADMIN_USER = User(user_id=1159, _groups=[UserGroup.ADMIN, UserGroup.READ_WRITE]) class ApiKey(StrEnum): From 0b201ee35d346e6c78685e6e790654640bbaf112 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Tue, 30 Jun 2026 15:17:38 +0200 Subject: [PATCH 4/8] Fix type errors --- src/core/formatting.py | 9 ++---- src/database/datasets.py | 14 +++++---- src/database/evaluations.py | 15 +++++----- src/database/flows.py | 27 +++++++++-------- src/database/runs.py | 24 +++++++-------- src/database/schema/base.py | 5 ++++ src/database/setups.py | 7 +++-- src/database/studies.py | 20 +++++-------- src/database/tasks.py | 42 ++++++++++++--------------- src/database/users.py | 13 +++++---- src/routers/openml/datasets.py | 9 +++--- src/routers/openml/runs.py | 23 +++++++-------- src/routers/openml/study.py | 29 ++++++++++-------- src/routers/openml/tasks.py | 16 ++++++---- tests/routers/openml/runs_get_test.py | 2 +- 15 files changed, 128 insertions(+), 127 deletions(-) diff --git a/src/core/formatting.py b/src/core/formatting.py index faf6d42..a4b09e8 100644 --- a/src/core/formatting.py +++ b/src/core/formatting.py @@ -1,12 +1,9 @@ import html -from typing import TYPE_CHECKING from config import get_config +from database.schema.base import UntypedRow from schemas.datasets.openml import DatasetFileFormat -if TYPE_CHECKING: - from sqlalchemy.engine import Row - def _str_to_bool(string: str) -> bool: if string.casefold() in ["true", "1", "yes", "y"]: @@ -17,7 +14,7 @@ def _str_to_bool(string: str) -> bool: raise ValueError(msg) -def _format_parquet_url(dataset: Row) -> str | None: +def _format_parquet_url(dataset: UntypedRow) -> str | None: if dataset.format.lower() != DatasetFileFormat.ARFF: return None @@ -27,7 +24,7 @@ def _format_parquet_url(dataset: Row) -> str | None: return f"{minio_base_url}datasets/{ten_thousands_prefix}/{padded_id}/dataset_{dataset.did}.pq" -def _format_dataset_url(dataset: Row) -> str: +def _format_dataset_url(dataset: UntypedRow) -> str: base_url = get_config().routing.server_url filename = f"{html.escape(dataset.name)}.{dataset.format.lower()}" return f"{base_url}data/v1/download/{dataset.file_id}/{filename}" diff --git a/src/database/datasets.py b/src/database/datasets.py index 5f9c0e5..5cb99d0 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -13,15 +13,15 @@ DuplicatePrimaryKeyError, ForeignKeyConstraintError, ) +from database.schema.base import UntypedRow from routers.types import Identifier, TagString from schemas.datasets.openml import DatasetStatus, Feature if TYPE_CHECKING: - from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection -async def get(id_: Identifier, connection: AsyncConnection) -> Row | None: +async def get(id_: Identifier, connection: AsyncConnection) -> UntypedRow | None: row = await connection.execute( text( """ @@ -35,7 +35,7 @@ async def get(id_: Identifier, connection: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> Row | None: +async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> UntypedRow | None: row = await connection.execute( text( """ @@ -53,7 +53,7 @@ async def get_tag( dataset_id: Identifier, tag: TagString, connection: AsyncConnection, -) -> Row | None: +) -> UntypedRow | None: return ( await connection.execute( text( @@ -111,6 +111,8 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) }, ) except IntegrityError as e: + if e.orig is None: + raise code, msg = e.orig.args if code == _FOREIGN_KEY_CONSTRAINT_FAILED: raise ForeignKeyConstraintError(msg) from e @@ -122,7 +124,7 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) async def get_description( id_: Identifier, connection: AsyncConnection, -) -> Row | None: +) -> UntypedRow | None: """Get the most recent description for the dataset.""" row = await connection.execute( text( @@ -160,7 +162,7 @@ async def get_status(id_: Identifier, connection: AsyncConnection) -> DatasetSta async def get_latest_processing_update( dataset_id: Identifier, connection: AsyncConnection, -) -> Row | None: +) -> UntypedRow | None: row = await connection.execute( text( """ diff --git a/src/database/evaluations.py b/src/database/evaluations.py index 382653f..14bec4d 100644 --- a/src/database/evaluations.py +++ b/src/database/evaluations.py @@ -1,16 +1,20 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from sqlalchemy import Row, text +from sqlalchemy import text from core.formatting import _str_to_bool +from database.schema.base import UntypedRow from schemas.datasets.openml import EstimationProcedure if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]: +async def get_math_functions( + function_type: str, + connection: AsyncConnection, +) -> Sequence[UntypedRow]: rows = await connection.execute( text( """ @@ -21,10 +25,7 @@ async def get_math_functions(function_type: str, connection: AsyncConnection) -> ), parameters={"function_type": function_type}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]: diff --git a/src/database/flows.py b/src/database/flows.py index 7b1da5e..830a386 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -1,15 +1,16 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from sqlalchemy import Row, text +from sqlalchemy import text +from database.schema.base import UntypedRow from routers.types import Identifier if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[Row]: +async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -20,10 +21,7 @@ async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence ), parameters={"flow_id": for_flow}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]: @@ -41,7 +39,7 @@ async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]: return [tag.tag for tag in tag_rows] -async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]: +async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -52,13 +50,14 @@ async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequenc ), parameters={"flow_id": flow_id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() -async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None: +async def get_by_name( + name: str, + external_version: str, + expdb: AsyncConnection, +) -> UntypedRow | None: """Get flow by name and external version.""" row = await expdb.execute( text( @@ -73,7 +72,7 @@ async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) return row.one_or_none() -async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None: +async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None: row = await expdb.execute( text( """ diff --git a/src/database/runs.py b/src/database/runs.py index f0bc419..0c5c93c 100644 --- a/src/database/runs.py +++ b/src/database/runs.py @@ -3,8 +3,9 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, cast -from sqlalchemy import Row, bindparam, text +from sqlalchemy import bindparam, text +from database.schema.base import UntypedRow from routers.types import Identifier if TYPE_CHECKING: @@ -26,7 +27,7 @@ async def exist(id_: Identifier, expdb: AsyncConnection) -> bool: return bool(row.one_or_none()) -async def get(run_id: Identifier, expdb: AsyncConnection) -> Row | None: +async def get(run_id: Identifier, expdb: AsyncConnection) -> UntypedRow | None: """Fetch the core run row from the `run` table. Returns the row if found, or None if no run with `run_id` exists. @@ -63,7 +64,7 @@ async def get_tags(run_id: int, expdb: AsyncConnection) -> list[str]: return [row.tag for row in rows.all()] -async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]: +async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]: """Fetch the dataset(s) used as input for a run, with name and url. Joins `input_data` with `dataset` to include the dataset name and ARFF URL. @@ -79,10 +80,10 @@ async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]: ), parameters={"run_id": run_id}, ) - return cast("list[Row]", rows.all()) + return cast("list[UntypedRow]", rows.all()) -async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]: +async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]: """Fetch output files attached to a run from the `runfile` table. Typical entries include the description XML and predictions ARFF. @@ -98,7 +99,7 @@ async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]: ), parameters={"run_id": run_id}, ) - return cast("list[Row]", rows.all()) + return cast("list[UntypedRow]", rows.all()) async def get_evaluations( @@ -106,7 +107,7 @@ async def get_evaluations( expdb: AsyncConnection, *, evaluation_engine_ids: list[int], -) -> list[Row]: +) -> list[UntypedRow]: """Fetch evaluation metric results for a run. Joins `evaluation` with `math_function` to resolve the metric name @@ -138,10 +139,10 @@ async def get_evaluations( query, parameters={"run_id": run_id, "engine_ids": evaluation_engine_ids}, ) - return cast("list[Row]", rows.all()) + return cast("list[UntypedRow]", rows.all()) -async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]: +async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[UntypedRow]: """Get trace rows for a run from the trace table.""" rows = await expdb.execute( text( @@ -153,7 +154,4 @@ async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]: ), parameters={"run_id": run_id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() diff --git a/src/database/schema/base.py b/src/database/schema/base.py index 498dfb5..0a75055 100644 --- a/src/database/schema/base.py +++ b/src/database/schema/base.py @@ -12,11 +12,16 @@ class ClassName(ExpDBReflected, Base): """ +from typing import Any + +from sqlalchemy import Row from sqlalchemy.ext.declarative import DeferredReflection from sqlalchemy.orm import DeclarativeBase from database.setup import expdb_database, user_database +UntypedRow = Row[Any] + class Base(DeclarativeBase): """Base class for all ORM classes.""" diff --git a/src/database/setups.py b/src/database/setups.py index 1c959b3..1cf6c22 100644 --- a/src/database/setups.py +++ b/src/database/setups.py @@ -4,14 +4,15 @@ from sqlalchemy import text +from database.schema.base import UntypedRow from routers.types import Identifier, TagString if TYPE_CHECKING: - from sqlalchemy.engine import Row, RowMapping + from sqlalchemy.engine import RowMapping from sqlalchemy.ext.asyncio import AsyncConnection -async def get(setup_id: Identifier, connection: AsyncConnection) -> Row | None: +async def get(setup_id: Identifier, connection: AsyncConnection) -> UntypedRow | None: """Get the setup with id `setup_id` from the database.""" row = await connection.execute( text( @@ -53,7 +54,7 @@ async def get_parameters(setup_id: Identifier, connection: AsyncConnection) -> l return list(rows.mappings().all()) -async def get_tags(setup_id: Identifier, connection: AsyncConnection) -> list[Row]: +async def get_tags(setup_id: Identifier, connection: AsyncConnection) -> list[UntypedRow]: """Get all tags for setup with `setup_id` from the database.""" rows = await connection.execute( text( diff --git a/src/database/studies.py b/src/database/studies.py index 1939452..b4a6936 100644 --- a/src/database/studies.py +++ b/src/database/studies.py @@ -3,8 +3,9 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, cast -from sqlalchemy import Row, text +from sqlalchemy import text +from database.schema.base import UntypedRow from database.users import User from routers.types import Identifier from schemas.study import CreateStudy, StudyType @@ -13,7 +14,7 @@ from sqlalchemy.ext.asyncio import AsyncConnection -async def get_by_id(id_: Identifier, connection: AsyncConnection) -> Row | None: +async def get_by_id(id_: Identifier, connection: AsyncConnection) -> UntypedRow | None: row = await connection.execute( text( """ @@ -27,7 +28,7 @@ async def get_by_id(id_: Identifier, connection: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_by_alias(alias: str, connection: AsyncConnection) -> Row | None: +async def get_by_alias(alias: str, connection: AsyncConnection) -> UntypedRow | None: row = await connection.execute( text( """ @@ -41,7 +42,7 @@ async def get_by_alias(alias: str, connection: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]: +async def get_study_data(study: UntypedRow, expdb: AsyncConnection) -> Sequence[UntypedRow]: """Return data related to the study, content depends on the study type. For task studies: (task id, dataset id) @@ -58,10 +59,8 @@ async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]: ), parameters={"study_id": study.id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() + rows = await expdb.execute( text( """ @@ -80,10 +79,7 @@ async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]: ), parameters={"study_id": study.id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() async def create(study: CreateStudy, user: User, expdb: AsyncConnection) -> int: diff --git a/src/database/tasks.py b/src/database/tasks.py index b0f6010..0910de4 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,7 +1,7 @@ from collections.abc import Sequence -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING -from sqlalchemy import Row, text +from sqlalchemy import text from sqlalchemy.exc import IntegrityError from database.exceptions import ( @@ -10,13 +10,14 @@ DuplicatePrimaryKeyError, ForeignKeyConstraintError, ) +from database.schema.base import UntypedRow from routers.types import Identifier, TagString if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncConnection -async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None: +async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None: row = await expdb.execute( text( """ @@ -30,7 +31,7 @@ async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None: return row.one_or_none() -async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]: +async def get_task_types(expdb: AsyncConnection) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -39,13 +40,10 @@ async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]: """, ), ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() -async def get_task_type(task_type_id: Identifier, expdb: AsyncConnection) -> Row | None: +async def get_task_type(task_type_id: Identifier, expdb: AsyncConnection) -> UntypedRow | None: row = await expdb.execute( text( """ @@ -102,7 +100,10 @@ async def get_task_evaluation_measure(task_id: int, expdb: AsyncConnection) -> s return result.value if result else None -async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> Sequence[Row]: +async def get_input_for_task_type( + task_type_id: int, + expdb: AsyncConnection, +) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -113,13 +114,10 @@ async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> ), parameters={"ttid": task_type_id}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() -async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequence[Row]: +async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -130,16 +128,13 @@ async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequenc ), parameters={"task_id": id_}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() async def get_task_type_inout_with_template( task_type: Identifier, expdb: AsyncConnection, -) -> Sequence[Row]: +) -> Sequence[UntypedRow]: rows = await expdb.execute( text( """ @@ -150,10 +145,7 @@ async def get_task_type_inout_with_template( ), parameters={"ttid": task_type}, ) - return cast( - "Sequence[Row]", - rows.all(), - ) + return rows.all() async def get_tags(id_: Identifier, connection: AsyncConnection) -> list[str]: @@ -193,6 +185,8 @@ async def tag( }, ) except IntegrityError as e: + if e.orig is None: + raise code, msg = e.orig.args if code == _FOREIGN_KEY_CONSTRAINT_FAILED: raise ForeignKeyConstraintError(msg) from e diff --git a/src/database/users.py b/src/database/users.py index 26a5f3d..379c6f8 100644 --- a/src/database/users.py +++ b/src/database/users.py @@ -136,14 +136,17 @@ async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None: return None async def get_groups(self) -> list[UserGroup]: + if self._groups: + return self._groups + if self._database is None: msg = "`get_groups` can only be used when `connection` is provided on instantiation." raise RuntimeError(msg) - if self._groups is None: - self._groups = await get_user_groups_for( - user_id=self.user_id, - connection=self._database, - ) + + self._groups = await get_user_groups_for( + user_id=self.user_id, + connection=self._database, + ) return self._groups async def is_admin(self) -> bool: diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 270c71b..72c94de 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -156,6 +156,7 @@ def _quality_clause(quality: str, range_: str | None) -> str: @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") @router.get(path="/list") async def list_datasets( # noqa: PLR0913, C901 + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], pagination: Annotated[Pagination, Body(default_factory=Pagination)], data_name: Annotated[CasualString128 | None, Body()] = None, tag: Annotated[TagString | None, Body()] = None, @@ -180,9 +181,7 @@ async def list_datasets( # noqa: PLR0913, C901 number_missing_values: Annotated[IntegerRange | None, Body()] = None, status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE, user: Annotated[User | None, Depends(fetch_user)] = None, - expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: - assert expdb_db is not None # noqa: S101 status_subquery = text( """ SELECT ds1.`did`, ds1.`status` @@ -356,8 +355,8 @@ async def _get_dataset_raise_otherwise( @router.get("/features/{dataset_id}", response_model_exclude_none=True) async def get_dataset_features( dataset_id: Identifier, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[Feature]: assert expdb is not None # noqa: S101 await _get_dataset_raise_otherwise(dataset_id, user, expdb) @@ -453,9 +452,9 @@ async def update_dataset_status( ) async def get_dataset( dataset_id: Identifier, + user_db: Annotated[AsyncConnection, Depends(userdb_connection)], + expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], user: Annotated[User | None, Depends(fetch_user)] = None, - user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None, - expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> DatasetMetadata: assert user_db is not None # noqa: S101 assert expdb_db is not None # noqa: S101 diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 0abf40e..88296cf 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -6,9 +6,6 @@ from fastapi import APIRouter, Depends -if TYPE_CHECKING: - from sqlalchemy import Row - import config import database.flows import database.runs @@ -16,6 +13,7 @@ import database.tasks import database.users from core.errors import RunNotFoundError, RunTraceNotFoundError +from database.schema.base import UntypedRow from routers.dependencies import expdb_connection, userdb_connection from routers.types import Identifier from schemas.runs import ( @@ -72,17 +70,17 @@ class RunContext: uploader_name: str | None tags: list[str] - input_data_rows: list[Row] - output_file_rows: list[Row] - evaluation_rows: list[Row] + input_data_rows: list[UntypedRow] + output_file_rows: list[UntypedRow] + evaluation_rows: list[UntypedRow] task_type: str | None task_evaluation_measure: str | None - setup: Row | None - parameter_rows: list[Row] + setup: UntypedRow | None + parameter_rows: list[UntypedRow] async def _load_run_context( - run: Row, + run: UntypedRow, run_id: int, expdb: AsyncConnection, userdb: AsyncConnection, @@ -99,8 +97,7 @@ async def _load_run_context( setup, parameter_rows, ) = cast( - "tuple[Any, list[str], list[Row], list[Row], list[Row], str | None, str |" - "None, Row | None, list[Row]]", + "tuple[Any, list[str], list[UntypedRow], list[UntypedRow], list[UntypedRow], str | None, str | None, UntypedRow | None, list[UntypedRow]]", # noqa: E501 await asyncio.gather( database.users.get_user(user_id=run.uploader, connection=userdb), database.runs.get_tags(run_id, expdb), @@ -126,7 +123,7 @@ async def _load_run_context( ) -def _build_evaluations(rows: list[Row]) -> list[EvaluationScore]: +def _build_evaluations(rows: list[UntypedRow]) -> list[EvaluationScore]: def _normalise_value(v: object) -> object: if isinstance(v, (int, float)): return int(v) if float(v).is_integer() else float(v) @@ -181,7 +178,7 @@ async def get_run( setup_id=run.setup, setup_string=ctx.setup.setup_string if ctx.setup else None, parameter_setting=[ - ParameterSetting(name=p["name"], value=p["value"], component=p["flow_id"]) + ParameterSetting(name=p.name, value=p.value, component=p.flow_id) for p in ctx.parameter_rows ], error_message=error_messages, diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index ccdf9ff..dbb6581 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -16,6 +16,7 @@ StudyPrivateError, ) from core.formatting import _str_to_bool +from database.schema.base import UntypedRow from database.users import User from routers.dependencies import expdb_connection, fetch_user, fetch_user_or_raise from routers.types import Identifier @@ -23,7 +24,6 @@ from schemas.study import CreateStudy, Study, StudyStatus, StudyType if TYPE_CHECKING: - from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection router = APIRouter(prefix="/studies", tags=["studies"]) @@ -33,7 +33,7 @@ async def _get_study_raise_otherwise( id_or_alias: Identifier | str, user: User | None, expdb: AsyncConnection, -) -> Row: +) -> UntypedRow: search_by_id = isinstance(id_or_alias, int) or id_or_alias.isdigit() if search_by_id: study = await database.studies.get_by_id(int(id_or_alias), expdb) @@ -67,7 +67,7 @@ async def attach_to_study( study_id: Annotated[Identifier, Body()], entity_ids: Annotated[list[Identifier], Body()], user: Annotated[User, Depends(fetch_user_or_raise)], - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> AttachDetachResponse: assert expdb is not None # noqa: S101 if user is None: @@ -90,16 +90,21 @@ async def attach_to_study( # We let the database handle the constraints on whether # the entity is already attached or if it even exists. - attach_kwargs = { - "study_id": study_id, - "user": user, - "connection": expdb, - } try: if study.type_ == StudyType.TASK: - await database.studies.attach_tasks(task_ids=entity_ids, **attach_kwargs) + await database.studies.attach_tasks( + task_ids=entity_ids, + study_id=study_id, + user=user, + connection=expdb, + ) else: - await database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs) + await database.studies.attach_runs( + run_ids=entity_ids, + study_id=study_id, + user=user, + connection=expdb, + ) except ValueError as e: msg = str(e) raise StudyConflictError(msg) from e @@ -116,7 +121,7 @@ async def attach_to_study( async def create_study( study: CreateStudy, user: Annotated[User, Depends(fetch_user_or_raise)], - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["study_id"], int]: assert expdb is not None # noqa: S101 if study.main_entity_type == StudyType.RUN and study.tasks: @@ -152,8 +157,8 @@ async def create_study( @router.get("/{alias_or_id}") async def get_study( alias_or_id: Identifier | str, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], user: Annotated[User | None, Depends(fetch_user)] = None, - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> Study: assert expdb is not None # noqa: S101 study = await _get_study_raise_otherwise(alias_or_id, user, expdb) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index f68c4b8..9a2b0df 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -26,7 +26,8 @@ from schemas.datasets.openml import Task if TYPE_CHECKING: - from sqlalchemy.engine import RowMapping + from collections.abc import Mapping + from sqlalchemy.ext.asyncio import AsyncConnection router = APIRouter(prefix="/tasks", tags=["tasks"]) @@ -70,7 +71,7 @@ def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: async def fill_template( template: str, - task: RowMapping, + task: Mapping[str, Any], task_inputs: dict[str, str | int], connection: AsyncConnection, ) -> dict[str, JSON]: @@ -137,7 +138,7 @@ async def fill_template( async def _fill_json_template( # noqa: C901 template: JSON, - task: RowMapping, + task: Mapping[str, Any], task_inputs: dict[str, str | int], fetched_data: dict[str, str], connection: AsyncConnection, @@ -192,7 +193,7 @@ async def _fill_json_template( # noqa: C901 template = template.replace(match.group(), fetched_data[field]) # I believe that the operations below are always part of string output, so # we don't need to be careful to avoid losing typedness - template = template.replace("[TASK:id]", str(task.task_id)) + template = template.replace("[TASK:id]", str(task["task_id"])) url = get_config().routing.server_url server_url = f"{url.scheme}://{url.host}:{url.port}/" return template.replace("[CONSTANT:base_url]", server_url) @@ -257,6 +258,7 @@ def _quality_clause(quality: str, range_: str | None) -> str: @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") @router.get(path="/list") async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], pagination: Annotated[Pagination, Body(default_factory=Pagination)], task_type_id: Annotated[Identifier | None, Body(description="Filter by task type id.")] = None, tag: Annotated[TagString | None, Body()] = None, @@ -275,7 +277,6 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 number_features: Annotated[IntegerRange | None, Body()] = None, number_classes: Annotated[IntegerRange | None, Body()] = None, number_missing_values: Annotated[IntegerRange | None, Body()] = None, - expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: """List tasks, optionally filtered by type, tag, status, dataset properties, and more.""" assert expdb is not None # noqa: S101 @@ -472,7 +473,10 @@ async def get_task( (name, template) for name, io, required, template in templates if io == "input" ] filled_templates = await asyncio.gather( - *[fill_template(template, task, task_inputs, expdb) for name, template in input_templates], + *[ + fill_template(template, dict(task), task_inputs, expdb) + for name, template in input_templates + ], ) inputs = [ filled | {"name": name} diff --git a/tests/routers/openml/runs_get_test.py b/tests/routers/openml/runs_get_test.py index d566403..af841f6 100644 --- a/tests/routers/openml/runs_get_test.py +++ b/tests/routers/openml/runs_get_test.py @@ -299,7 +299,7 @@ def __init__( self.fold = fold rows = [MockRow("test_metric", input_value, repeat=repeat, fold=fold)] - evals = _build_evaluations(rows) + evals = _build_evaluations(rows) # type: ignore[arg-type] assert len(evals) == 1 assert evals[0].value == expected_value From a116199927ec0705df7686f0f7b7d432da0ddc09 Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 2 Jul 2026 08:32:53 +0200 Subject: [PATCH 5/8] Use ORM for TaskTag --- src/database/tasks.py | 40 ++++++++------------------- src/routers/dependencies.py | 8 ++++++ src/routers/openml/tasks.py | 27 ++++++++---------- tests/conftest.py | 22 +++++++++++++-- tests/routers/openml/task_tag_test.py | 39 ++++++++++---------------- 5 files changed, 64 insertions(+), 72 deletions(-) diff --git a/src/database/tasks.py b/src/database/tasks.py index 0910de4..15d6589 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from sqlalchemy import text +from sqlalchemy import select, text from sqlalchemy.exc import IntegrityError from database.exceptions import ( @@ -11,10 +11,11 @@ ForeignKeyConstraintError, ) from database.schema.base import UntypedRow +from database.schema.tags import TaskTag from routers.types import Identifier, TagString if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncConnection + from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None: @@ -148,19 +149,10 @@ async def get_task_type_inout_with_template( return rows.all() -async def get_tags(id_: Identifier, connection: AsyncConnection) -> list[str]: - rows = await connection.execute( - text( - """ - SELECT `tag` - FROM task_tag - WHERE `id` = :task_id - """, - ), - parameters={"task_id": id_}, - ) - tag_rows = rows.all() - return [row.tag for row in tag_rows] +async def get_tags(task_id: Identifier, session: AsyncSession) -> list[TagString]: + stmt = select(TaskTag).where(TaskTag.entity_id == task_id) + tags = (await session.scalars(stmt)).all() + return [t.tag for t in tags] async def tag( @@ -168,22 +160,12 @@ async def tag( tag_: TagString, *, user_id: Identifier, - connection: AsyncConnection, + session: AsyncSession, ) -> None: try: - await connection.execute( - text( - """ - INSERT INTO task_tag(`id`, `tag`, `uploader`) - VALUES (:task_id, :tag, :user_id) - """, - ), - parameters={ - "task_id": id_, - "user_id": user_id, - "tag": tag_, - }, - ) + tag = TaskTag(entity_id=id_, uploader_id=user_id, tag=tag_) + session.add(tag) + await session.flush() except IntegrityError as e: if e.orig is None: raise diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index ca4a965..832c958 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -1,9 +1,11 @@ +import contextlib from collections.abc import AsyncGenerator, AsyncIterator from typing import TYPE_CHECKING, Annotated from fastapi import Depends from loguru import logger from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession from core.errors import AuthenticationFailedError, AuthenticationRequiredError from database.setup import expdb_database, user_database @@ -25,6 +27,12 @@ async def userdb_connection() -> AsyncIterator[AsyncConnection]: yield connection +async def expdb_session() -> AsyncIterator[AsyncSession]: + conn = contextlib.asynccontextmanager(expdb_connection) + async with conn() as connection, AsyncSession(connection) as session, session.begin(): + yield session + + async def fetch_user( api_key: APIKey | None = None, user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None, diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 9a2b0df..68aa47f 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -14,8 +14,9 @@ from config import get_config from core.errors import InternalError, NoResultsError, TagAlreadyExistsError, TaskNotFoundError from database.exceptions import DuplicatePrimaryKeyError, ForeignKeyConstraintError +from database.schema.base import UntypedRow from database.users import User -from routers.dependencies import Pagination, expdb_connection, fetch_user_or_raise +from routers.dependencies import Pagination, expdb_connection, expdb_session, fetch_user_or_raise from routers.types import ( CasualString128, Identifier, @@ -26,9 +27,7 @@ from schemas.datasets.openml import Task if TYPE_CHECKING: - from collections.abc import Mapping - - from sqlalchemy.ext.asyncio import AsyncConnection + from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession router = APIRouter(prefix="/tasks", tags=["tasks"]) @@ -40,10 +39,10 @@ async def tag_task( task_id: Annotated[Identifier, Body()], tag: Annotated[TagString, Body()], user: Annotated[User, Depends(fetch_user_or_raise)], - expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)], + expdb_session: Annotated[AsyncSession, Depends(expdb_session)], ) -> dict[str, dict[str, Any]]: try: - await database.tasks.tag(task_id, tag, user_id=user.user_id, connection=expdb_db) + await database.tasks.tag(task_id, tag, user_id=user.user_id, session=expdb_session) except ForeignKeyConstraintError: msg = f"Task {task_id} not found." raise TaskNotFoundError(msg, code=472) from None @@ -53,7 +52,7 @@ async def tag_task( logger.info("Task {task_id} tagged '{tag}'.", task_id=task_id, tag=tag) - tags = await database.tasks.get_tags(task_id, expdb_db) + tags = await database.tasks.get_tags(task_id, expdb_session) return { "task_tag": {"id": str(task_id), "tag": tags}, @@ -71,7 +70,7 @@ def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: async def fill_template( template: str, - task: Mapping[str, Any], + task: UntypedRow, task_inputs: dict[str, str | int], connection: AsyncConnection, ) -> dict[str, JSON]: @@ -138,7 +137,7 @@ async def fill_template( async def _fill_json_template( # noqa: C901 template: JSON, - task: Mapping[str, Any], + task: UntypedRow, task_inputs: dict[str, str | int], fetched_data: dict[str, str], connection: AsyncConnection, @@ -193,7 +192,7 @@ async def _fill_json_template( # noqa: C901 template = template.replace(match.group(), fetched_data[field]) # I believe that the operations below are always part of string output, so # we don't need to be careful to avoid losing typedness - template = template.replace("[TASK:id]", str(task["task_id"])) + template = template.replace("[TASK:id]", str(task.task_id)) url = get_config().routing.server_url server_url = f"{url.scheme}://{url.host}:{url.port}/" return template.replace("[CONSTANT:base_url]", server_url) @@ -452,6 +451,7 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915 async def get_task( task_id: int, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], + expdb_session: Annotated[AsyncSession, Depends(expdb_session)], ) -> Task: if not (task := await database.tasks.get(task_id, expdb)): msg = f"Task {task_id} not found." @@ -463,7 +463,7 @@ async def get_task( task_input_rows, ttios, tags = await asyncio.gather( database.tasks.get_input_for_task(task_id, expdb), database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb), - database.tasks.get_tags(task_id, expdb), + database.tasks.get_tags(task_id, expdb_session), ) task_inputs = { row.input: int(row.value) if row.value.isdigit() else row.value for row in task_input_rows @@ -473,10 +473,7 @@ async def get_task( (name, template) for name, io, required, template in templates if io == "input" ] filled_templates = await asyncio.gather( - *[ - fill_template(template, dict(task), task_inputs, expdb) - for name, template in input_templates - ], + *[fill_template(template, task, task_inputs, expdb) for name, template in input_templates], ) inputs = [ filled | {"name": name} diff --git a/tests/conftest.py b/tests/conftest.py index 4191c47..7ba01fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,10 +9,11 @@ import httpx import pytest from _pytest.config import Config # noqa: TC002 used during collection by Pytest -from _pytest.nodes import Item # noqa: TC002 used during collection by Pytest from asgi_lifespan import LifespanManager from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession +import routers.dependencies from config import ( Configuration, DatabaseConfiguration, @@ -27,6 +28,7 @@ from tests.users import OWNER_USER if TYPE_CHECKING: + from _pytest.nodes import Item from fastapi import FastAPI from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine @@ -71,6 +73,16 @@ async def expdb_test() -> AsyncIterator[AsyncConnection]: yield connection +@pytest.fixture +async def expdb_session(expdb_test: AsyncConnection) -> AsyncIterator[AsyncSession]: + # It is possible that this session `commits`, does the connection + # rollback then still take effect? Probably not. + async with AsyncSession(expdb_test) as session: + yield session + # Do we here again need to do some check on whether there is an active transation? + await session.rollback() + + @pytest.fixture async def user_test() -> AsyncIterator[AsyncConnection]: async with automatic_rollback(user_database()) as connection: @@ -85,7 +97,7 @@ async def php_api() -> AsyncIterator[httpx.AsyncClient]: yield client -@pytest.fixture(scope="session") +@pytest.fixture(scope="session", autouse=True) async def app() -> AsyncIterator[FastAPI]: config = Configuration( openml_database=DatabaseConfiguration(database="openml"), @@ -103,7 +115,9 @@ async def app() -> AsyncIterator[FastAPI]: @pytest.fixture async def py_api( - expdb_test: AsyncConnection, user_test: AsyncConnection, app: FastAPI + expdb_test: AsyncConnection, + user_test: AsyncConnection, + app: FastAPI, ) -> AsyncIterator[httpx.AsyncClient]: """Create test client which automatically rolls back database updates on teardown.""" # Using the function-scoped database fixtures automatically benefits the @@ -123,6 +137,8 @@ async def override_userdb() -> AsyncIterator[AsyncConnection]: app.dependency_overrides[expdb_connection] = override_expdb app.dependency_overrides[userdb_connection] = override_userdb + routers.dependencies.expdb_connection = override_expdb + async with httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://test", diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index cd58964..d83619d 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -2,11 +2,8 @@ from typing import TYPE_CHECKING import pytest -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from core.errors import TASK_NOT_FOUND_DURING_TAG, TagAlreadyExistsError, TaskNotFoundError -from database.schema.tags import TaskTag from database.tasks import get_tags from database.users import User from routers.openml.tasks import tag_task @@ -17,7 +14,7 @@ if TYPE_CHECKING: import httpx - from sqlalchemy.ext.asyncio import AsyncConnection + from sqlalchemy.ext.asyncio import AsyncSession @pytest.mark.parametrize( @@ -44,59 +41,51 @@ async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncCli [ADMIN_USER, SOME_USER, OWNER_USER], ids=["administrator", "non-owner", "owner"], ) -async def test_task_tag(user: User, expdb_test: AsyncConnection, task_factory: TaskFactory) -> None: +async def test_task_tag(user: User, expdb_session: AsyncSession, task_factory: TaskFactory) -> None: tag = "test_task_tag" task = await task_factory() - result = await tag_task(task_id=task.id, tag=tag, user=user, expdb_db=expdb_test) + result = await tag_task(task_id=task.id, tag=tag, user=user, expdb_session=expdb_session) assert result == {"task_tag": {"id": str(task.id), "tag": [tag]}} - tags = await get_tags(id_=task.id, connection=expdb_test) + tags = await get_tags(task_id=task.id, session=expdb_session) assert tag in tags @pytest.mark.mut async def test_task_tag_returns_existing_tags( - task_factory: TaskFactory, expdb_test: AsyncConnection + task_factory: TaskFactory, expdb_session: AsyncSession ) -> None: task = await task_factory() - await tag_task(task_id=task.id, tag="first", user=ADMIN_USER, expdb_db=expdb_test) - result = await tag_task(task_id=task.id, tag="second", user=ADMIN_USER, expdb_db=expdb_test) + await tag_task(task_id=task.id, tag="first", user=ADMIN_USER, expdb_session=expdb_session) + result = await tag_task( + task_id=task.id, tag="second", user=ADMIN_USER, expdb_session=expdb_session + ) assert result == {"task_tag": {"id": str(task.id), "tag": ["first", "second"]}} @pytest.mark.mut async def test_task_tag_fails_if_tag_exists( - expdb_test: AsyncConnection, task_factory: TaskFactory + expdb_session: AsyncSession, task_factory: TaskFactory ) -> None: tag = "fails_if_exist" task = await task_factory() - await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) + await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_session=expdb_session) with pytest.raises(TagAlreadyExistsError) as e: - await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test) + await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_session=expdb_session) assert str(task.id) in e.value.detail assert tag in e.value.detail -async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection) -> None: +async def test_task_tag_fails_if_task_does_not_exist(expdb_session: AsyncSession) -> None: task_id = 1_000_000 with pytest.raises(TaskNotFoundError) as e: - await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_db=expdb_test) + await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_session=expdb_session) assert str(task_id) in e.value.detail task_not_found_in_tag_endpoint = TASK_NOT_FOUND_DURING_TAG assert e.value.code == task_not_found_in_tag_endpoint -async def test_if_reflection_works(py_api: httpx.Client, expdb_test: AsyncConnection) -> None: - _ = py_api - - tag_stmt = select(TaskTag).where(TaskTag.tag == "OpenML100") - async with AsyncSession(bind=expdb_test) as session: - tags = (await session.scalars(tag_stmt)).all() - tag_count = 100 - assert len(tags) == tag_count - - # -- migration tests -- From 993cbd7a2c09515fc92ccfc23c93e875d0d2c8de Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 2 Jul 2026 08:39:23 +0200 Subject: [PATCH 6/8] Return ORM object --- src/database/tasks.py | 5 ++--- src/routers/openml/tasks.py | 4 ++-- tests/conftest.py | 2 +- tests/routers/openml/task_tag_test.py | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/database/tasks.py b/src/database/tasks.py index 15d6589..9d1d120 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -149,10 +149,9 @@ async def get_task_type_inout_with_template( return rows.all() -async def get_tags(task_id: Identifier, session: AsyncSession) -> list[TagString]: +async def get_tags(task_id: Identifier, session: AsyncSession) -> Sequence[TaskTag]: stmt = select(TaskTag).where(TaskTag.entity_id == task_id) - tags = (await session.scalars(stmt)).all() - return [t.tag for t in tags] + return (await session.scalars(stmt)).all() async def tag( diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 68aa47f..55df619 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -55,7 +55,7 @@ async def tag_task( tags = await database.tasks.get_tags(task_id, expdb_session) return { - "task_tag": {"id": str(task_id), "tag": tags}, + "task_tag": {"id": str(task_id), "tag": [t.tag for t in tags]}, } @@ -496,5 +496,5 @@ async def get_task( task_type=task_type.name, input_=inputs, output=outputs, - tags=tags, + tags=[t.tag for t in tags], ) diff --git a/tests/conftest.py b/tests/conftest.py index 7ba01fd..c5b6a4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ import httpx import pytest from _pytest.config import Config # noqa: TC002 used during collection by Pytest +from _pytest.nodes import Item # noqa: TC002 used during collection by Pytest from asgi_lifespan import LifespanManager from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession @@ -28,7 +29,6 @@ from tests.users import OWNER_USER if TYPE_CHECKING: - from _pytest.nodes import Item from fastapi import FastAPI from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py index d83619d..f6e04b1 100644 --- a/tests/routers/openml/task_tag_test.py +++ b/tests/routers/openml/task_tag_test.py @@ -48,7 +48,7 @@ async def test_task_tag(user: User, expdb_session: AsyncSession, task_factory: T assert result == {"task_tag": {"id": str(task.id), "tag": [tag]}} tags = await get_tags(task_id=task.id, session=expdb_session) - assert tag in tags + assert tag in [t.tag for t in tags] @pytest.mark.mut From 24bf0308830d99d4f685dbb7763e72d8ba2eb1dc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jul 2026 06:41:58 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/database/setups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/database/setups.py b/src/database/setups.py index bd3cf33..1eddef6 100644 --- a/src/database/setups.py +++ b/src/database/setups.py @@ -5,13 +5,13 @@ from sqlalchemy import text from sqlalchemy.exc import IntegrityError -from database.schema.base import UntypedRow from database.exceptions import ( _DUPLICATE_ENTRY, _FOREIGN_KEY_CONSTRAINT_FAILED, DuplicatePrimaryKeyError, ForeignKeyConstraintError, ) +from database.schema.base import UntypedRow from routers.types import Identifier, TagString if TYPE_CHECKING: From 9541aab0121749eced9a0c42c4f97b3b6564a8db Mon Sep 17 00:00:00 2001 From: PGijsbers Date: Thu, 2 Jul 2026 08:46:44 +0200 Subject: [PATCH 8/8] Use dependency injection instead of monkey patching --- src/database/setups.py | 2 ++ src/routers/dependencies.py | 8 ++++---- tests/conftest.py | 3 --- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/database/setups.py b/src/database/setups.py index 1eddef6..c78986c 100644 --- a/src/database/setups.py +++ b/src/database/setups.py @@ -107,6 +107,8 @@ async def tag( parameters={"setup_id": setup_id, "tag": tag, "user_id": user_id}, ) except IntegrityError as e: + if e.orig is None: + raise code, msg = e.orig.args if code == _FOREIGN_KEY_CONSTRAINT_FAILED: raise ForeignKeyConstraintError(msg) from e diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index 832c958..2ee8548 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -1,4 +1,3 @@ -import contextlib from collections.abc import AsyncGenerator, AsyncIterator from typing import TYPE_CHECKING, Annotated @@ -27,9 +26,10 @@ async def userdb_connection() -> AsyncIterator[AsyncConnection]: yield connection -async def expdb_session() -> AsyncIterator[AsyncSession]: - conn = contextlib.asynccontextmanager(expdb_connection) - async with conn() as connection, AsyncSession(connection) as session, session.begin(): +async def expdb_session( + connection: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> AsyncIterator[AsyncSession]: + async with AsyncSession(connection) as session, session.begin(): yield session diff --git a/tests/conftest.py b/tests/conftest.py index c5b6a4e..68a8e1c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,6 @@ from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession -import routers.dependencies from config import ( Configuration, DatabaseConfiguration, @@ -137,8 +136,6 @@ async def override_userdb() -> AsyncIterator[AsyncConnection]: app.dependency_overrides[expdb_connection] = override_expdb app.dependency_overrides[userdb_connection] = override_userdb - routers.dependencies.expdb_connection = override_expdb - async with httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://test",