From 42ecc5566800387e374aa88a831f7b8286db46e6 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Fri, 26 Jun 2026 11:31:51 +0200
Subject: [PATCH 1/8] prototype using hybrid model for orm
---
.pre-commit-config.yaml | 1 +
src/database/setup.py | 30 +++++++++++++++++++++++++++
src/main.py | 4 +++-
tests/routers/openml/task_tag_test.py | 19 +++++++++++++++++
4 files changed, 53 insertions(+), 1 deletion(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1601d30..4d9aefd 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,6 +24,7 @@ repos:
additional_dependencies:
- fastapi
- pytest
+ - sqlalchemy
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.15.18'
diff --git a/src/database/setup.py b/src/database/setup.py
index cfd2306..5d37e79 100644
--- a/src/database/setup.py
+++ b/src/database/setup.py
@@ -3,10 +3,33 @@
from loguru import logger
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
+from sqlalchemy.ext.declarative import DeferredReflection
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from config import DatabaseConfiguration, get_config
+class Base(DeclarativeBase):
+ pass
+
+
+class ExpDBReflected(DeferredReflection):
+ __abstract__ = True
+
+
+class UserDBReflected(DeferredReflection):
+ __abstract__ = True
+
+
+class UserR(UserDBReflected, Base):
+ __tablename__ = "users"
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+
+class TaskTag(ExpDBReflected, Base):
+ __tablename__ = "task_tag"
+
+
def _create_engine(db_config: DatabaseConfiguration) -> AsyncEngine:
db_url = URL.create(
drivername=db_config.drivername,
@@ -35,6 +58,13 @@ def expdb_database() -> AsyncEngine:
return _create_engine(get_config().expdb_database)
+async def reflect_db_schemas() -> None:
+ async with user_database().connect() as connection:
+ await connection.run_sync(UserDBReflected.prepare) # type: ignore[arg-type] # run_sync expects positional-only arg but `prepare` does not have it.
+ async with expdb_database().connect() as connection:
+ await connection.run_sync(ExpDBReflected.prepare) # type: ignore[arg-type] # as above.
+
+
async def close_databases() -> None:
"""Close all database connections."""
for db in (user_database, expdb_database):
diff --git a/src/main.py b/src/main.py
index b4f7dc8..eb0a495 100644
--- a/src/main.py
+++ b/src/main.py
@@ -26,7 +26,7 @@
request_response_logger,
setup_log_sinks,
)
-from database.setup import close_databases
+from database.setup import close_databases, reflect_db_schemas
from routers.openml.datasets import router as datasets_router
from routers.openml.estimation_procedure import router as estimationprocedure_router
from routers.openml.evaluations import router as evaluationmeasures_router
@@ -45,6 +45,8 @@ async def lifespan(
app: FastAPI | None, # noqa: ARG001 # parameter required by FastAPI/Starlette
) -> AsyncIterator[None]:
"""Manage application lifespan - startup and shutdown events."""
+ logger.info("Reflecting database schemas")
+ await reflect_db_schemas()
yield
await asyncio.gather(
logger.complete(),
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
index 087f76e..be392bd 100644
--- a/tests/routers/openml/task_tag_test.py
+++ b/tests/routers/openml/task_tag_test.py
@@ -2,8 +2,11 @@
from typing import TYPE_CHECKING
import pytest
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
from core.errors import TASK_NOT_FOUND_DURING_TAG, TagAlreadyExistsError, TaskNotFoundError
+from database.setup import TaskTag, UserR
from database.tasks import get_tags
from database.users import User
from routers.openml.tasks import tag_task
@@ -84,6 +87,22 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection
assert e.value.code == task_not_found_in_tag_endpoint
+async def test_if_reflection_works(
+ py_api: httpx.Client, expdb_test: AsyncConnection, user_test: AsyncConnection
+) -> None:
+ _ = py_api
+ user_stmt = select(UserR).where(UserR.id == 1)
+ async with AsyncSession(bind=user_test) as session:
+ user = (await session.scalars(user_stmt)).first()
+ assert user.username == "emily12"
+
+ tag_stmt = select(TaskTag).where(TaskTag.tag == "OpenML100")
+ async with AsyncSession(bind=expdb_test) as session:
+ tags = (await session.scalars(tag_stmt)).all()
+ tag_count = 100
+ assert len(tags) == tag_count
+
+
# -- migration tests --
From 339390a2e4bd3d386f35103738ce59e471738b13 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Tue, 30 Jun 2026 13:23:15 +0200
Subject: [PATCH 2/8] Create the TaskTag ORM class, re-organize the reflection
code
---
src/database/schema/__init__.py | 1 +
src/database/schema/base.py | 47 +++++++++++++++++++++++++++
src/database/schema/tags.py | 30 +++++++++++++++++
src/database/setup.py | 30 -----------------
src/main.py | 3 +-
tests/routers/openml/task_tag_test.py | 10 ++----
6 files changed, 82 insertions(+), 39 deletions(-)
create mode 100644 src/database/schema/__init__.py
create mode 100644 src/database/schema/base.py
create mode 100644 src/database/schema/tags.py
diff --git a/src/database/schema/__init__.py b/src/database/schema/__init__.py
new file mode 100644
index 0000000..6e6260c
--- /dev/null
+++ b/src/database/schema/__init__.py
@@ -0,0 +1 @@
+"""Defines Object-Relational Mappings (ORM)."""
diff --git a/src/database/schema/base.py b/src/database/schema/base.py
new file mode 100644
index 0000000..498dfb5
--- /dev/null
+++ b/src/database/schema/base.py
@@ -0,0 +1,47 @@
+"""Base classes for all ORM classes.
+
+When defining a new ORM class, use both `Base` and one of the `DeferredReflection` subclasses to
+make sure that the class is populated with attributes that may not be defined explicitly.
+For example, when creating a new mapping for a table from the `openml_expdb` database, use:
+
+class ClassName(ExpDBReflected, Base):
+ __tablename__ = "class_names"
+
+ # any columns you wanted mapped explicitly
+ ...
+
+"""
+
+from sqlalchemy.ext.declarative import DeferredReflection
+from sqlalchemy.orm import DeclarativeBase
+
+from database.setup import expdb_database, user_database
+
+
+class Base(DeclarativeBase):
+ """Base class for all ORM classes."""
+
+
+class ExpDBReflected(DeferredReflection):
+ """Base class for ORM classes to map onto a table in the `openml_expdb` database."""
+
+ __abstract__ = True
+
+
+class UserDBReflected(DeferredReflection):
+ """Base class for ORM classes to map onto a table in the `openml` database."""
+
+ __abstract__ = True
+
+
+async def reflect_db_schemas() -> None:
+ """Populate defined ORM classes with attributes defined from columns in the database.
+
+ For example, the `dataset` class would automatically get a `collection_date` attribute,
+ even if it wasn't explicitly declared in the class definition,
+ because the `openml_expdb.dataset` table has a column `collection_date`.
+ """
+ async with user_database().connect() as connection:
+ await connection.run_sync(UserDBReflected.prepare) # type: ignore[arg-type] # run_sync expects positional-only arg but `prepare` does not have it.
+ async with expdb_database().connect() as connection:
+ await connection.run_sync(ExpDBReflected.prepare) # type: ignore[arg-type] # as above.
diff --git a/src/database/schema/tags.py b/src/database/schema/tags.py
new file mode 100644
index 0000000..a6b6a78
--- /dev/null
+++ b/src/database/schema/tags.py
@@ -0,0 +1,30 @@
+"""ORM classes for the *_tag tables (task_tag, ...)."""
+
+from datetime import datetime
+
+from sqlalchemy import FetchedValue
+from sqlalchemy.orm import Mapped, mapped_column
+
+from database.schema.base import Base, ExpDBReflected
+from routers.types import Identifier, TagString
+
+
+class Tag:
+ """Base class for all of the *_tag tables."""
+
+ # The identifier of the entity that is tagged (e.g., dataset id, task id)
+ entity_id: Mapped[Identifier] = mapped_column("id", primary_key=True)
+ tag: Mapped[TagString] = mapped_column(primary_key=True)
+ uploader_id: Mapped[Identifier] = mapped_column("uploader")
+ creation_date: Mapped[datetime] = mapped_column("date", server_default=FetchedValue())
+
+
+class TaskTag(ExpDBReflected, Tag, Base):
+ """Tags belonging to a task."""
+
+ __tablename__ = "task_tag"
+
+ @property
+ def task_id(self) -> Identifier:
+ """Identifier of the task which is tagged by this tag."""
+ return self.entity_id
diff --git a/src/database/setup.py b/src/database/setup.py
index 5d37e79..cfd2306 100644
--- a/src/database/setup.py
+++ b/src/database/setup.py
@@ -3,33 +3,10 @@
from loguru import logger
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
-from sqlalchemy.ext.declarative import DeferredReflection
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from config import DatabaseConfiguration, get_config
-class Base(DeclarativeBase):
- pass
-
-
-class ExpDBReflected(DeferredReflection):
- __abstract__ = True
-
-
-class UserDBReflected(DeferredReflection):
- __abstract__ = True
-
-
-class UserR(UserDBReflected, Base):
- __tablename__ = "users"
- id: Mapped[int] = mapped_column(primary_key=True)
-
-
-class TaskTag(ExpDBReflected, Base):
- __tablename__ = "task_tag"
-
-
def _create_engine(db_config: DatabaseConfiguration) -> AsyncEngine:
db_url = URL.create(
drivername=db_config.drivername,
@@ -58,13 +35,6 @@ def expdb_database() -> AsyncEngine:
return _create_engine(get_config().expdb_database)
-async def reflect_db_schemas() -> None:
- async with user_database().connect() as connection:
- await connection.run_sync(UserDBReflected.prepare) # type: ignore[arg-type] # run_sync expects positional-only arg but `prepare` does not have it.
- async with expdb_database().connect() as connection:
- await connection.run_sync(ExpDBReflected.prepare) # type: ignore[arg-type] # as above.
-
-
async def close_databases() -> None:
"""Close all database connections."""
for db in (user_database, expdb_database):
diff --git a/src/main.py b/src/main.py
index eb0a495..6207b15 100644
--- a/src/main.py
+++ b/src/main.py
@@ -26,7 +26,8 @@
request_response_logger,
setup_log_sinks,
)
-from database.setup import close_databases, reflect_db_schemas
+from database.schema.base import reflect_db_schemas
+from database.setup import close_databases
from routers.openml.datasets import router as datasets_router
from routers.openml.estimation_procedure import router as estimationprocedure_router
from routers.openml.evaluations import router as evaluationmeasures_router
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
index be392bd..cd58964 100644
--- a/tests/routers/openml/task_tag_test.py
+++ b/tests/routers/openml/task_tag_test.py
@@ -6,7 +6,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from core.errors import TASK_NOT_FOUND_DURING_TAG, TagAlreadyExistsError, TaskNotFoundError
-from database.setup import TaskTag, UserR
+from database.schema.tags import TaskTag
from database.tasks import get_tags
from database.users import User
from routers.openml.tasks import tag_task
@@ -87,14 +87,8 @@ async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection
assert e.value.code == task_not_found_in_tag_endpoint
-async def test_if_reflection_works(
- py_api: httpx.Client, expdb_test: AsyncConnection, user_test: AsyncConnection
-) -> None:
+async def test_if_reflection_works(py_api: httpx.Client, expdb_test: AsyncConnection) -> None:
_ = py_api
- user_stmt = select(UserR).where(UserR.id == 1)
- async with AsyncSession(bind=user_test) as session:
- user = (await session.scalars(user_stmt)).first()
- assert user.username == "emily12"
tag_stmt = select(TaskTag).where(TaskTag.tag == "OpenML100")
async with AsyncSession(bind=expdb_test) as session:
From cc00677e3af36161aca0b4316f61f70fb5861101 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Tue, 30 Jun 2026 13:54:48 +0200
Subject: [PATCH 3/8] Fix typing of _database property
---
src/database/users.py | 5 ++++-
tests/users.py | 8 ++++----
2 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/src/database/users.py b/src/database/users.py
index 381f01d..26a5f3d 100644
--- a/src/database/users.py
+++ b/src/database/users.py
@@ -106,9 +106,9 @@ async def get_user_groups_for(
@dataclasses.dataclass
class User:
user_id: Identifier
- _database: AsyncConnection
first_name: str = ""
last_name: str = ""
+ _database: AsyncConnection | None = None
_groups: list[UserGroup] | None = None
def __post_init__(self) -> None:
@@ -136,6 +136,9 @@ async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None:
return None
async def get_groups(self) -> list[UserGroup]:
+ if self._database is None:
+ msg = "`get_groups` can only be used when `connection` is provided on instantiation."
+ raise RuntimeError(msg)
if self._groups is None:
self._groups = await get_user_groups_for(
user_id=self.user_id,
diff --git a/tests/users.py b/tests/users.py
index c98ffb0..07d3c66 100644
--- a/tests/users.py
+++ b/tests/users.py
@@ -3,10 +3,10 @@
from database.users import User, UserGroup
NO_USER = None
-SOME_USER = User(user_id=2, _database=None, _groups=[UserGroup.READ_WRITE])
-OWNER_USER = User(user_id=3229, _database=None, _groups=[UserGroup.READ_WRITE])
-DATASET_130_OWNER = User(user_id=16, _database=None, _groups=[UserGroup.READ_WRITE])
-ADMIN_USER = User(user_id=1159, _database=None, _groups=[UserGroup.ADMIN, UserGroup.READ_WRITE])
+SOME_USER = User(user_id=2, _groups=[UserGroup.READ_WRITE])
+OWNER_USER = User(user_id=3229, _groups=[UserGroup.READ_WRITE])
+DATASET_130_OWNER = User(user_id=16, _groups=[UserGroup.READ_WRITE])
+ADMIN_USER = User(user_id=1159, _groups=[UserGroup.ADMIN, UserGroup.READ_WRITE])
class ApiKey(StrEnum):
From 0b201ee35d346e6c78685e6e790654640bbaf112 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Tue, 30 Jun 2026 15:17:38 +0200
Subject: [PATCH 4/8] Fix type errors
---
src/core/formatting.py | 9 ++----
src/database/datasets.py | 14 +++++----
src/database/evaluations.py | 15 +++++-----
src/database/flows.py | 27 +++++++++--------
src/database/runs.py | 24 +++++++--------
src/database/schema/base.py | 5 ++++
src/database/setups.py | 7 +++--
src/database/studies.py | 20 +++++--------
src/database/tasks.py | 42 ++++++++++++---------------
src/database/users.py | 13 +++++----
src/routers/openml/datasets.py | 9 +++---
src/routers/openml/runs.py | 23 +++++++--------
src/routers/openml/study.py | 29 ++++++++++--------
src/routers/openml/tasks.py | 16 ++++++----
tests/routers/openml/runs_get_test.py | 2 +-
15 files changed, 128 insertions(+), 127 deletions(-)
diff --git a/src/core/formatting.py b/src/core/formatting.py
index faf6d42..a4b09e8 100644
--- a/src/core/formatting.py
+++ b/src/core/formatting.py
@@ -1,12 +1,9 @@
import html
-from typing import TYPE_CHECKING
from config import get_config
+from database.schema.base import UntypedRow
from schemas.datasets.openml import DatasetFileFormat
-if TYPE_CHECKING:
- from sqlalchemy.engine import Row
-
def _str_to_bool(string: str) -> bool:
if string.casefold() in ["true", "1", "yes", "y"]:
@@ -17,7 +14,7 @@ def _str_to_bool(string: str) -> bool:
raise ValueError(msg)
-def _format_parquet_url(dataset: Row) -> str | None:
+def _format_parquet_url(dataset: UntypedRow) -> str | None:
if dataset.format.lower() != DatasetFileFormat.ARFF:
return None
@@ -27,7 +24,7 @@ def _format_parquet_url(dataset: Row) -> str | None:
return f"{minio_base_url}datasets/{ten_thousands_prefix}/{padded_id}/dataset_{dataset.did}.pq"
-def _format_dataset_url(dataset: Row) -> str:
+def _format_dataset_url(dataset: UntypedRow) -> str:
base_url = get_config().routing.server_url
filename = f"{html.escape(dataset.name)}.{dataset.format.lower()}"
return f"{base_url}data/v1/download/{dataset.file_id}/{filename}"
diff --git a/src/database/datasets.py b/src/database/datasets.py
index 5f9c0e5..5cb99d0 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -13,15 +13,15 @@
DuplicatePrimaryKeyError,
ForeignKeyConstraintError,
)
+from database.schema.base import UntypedRow
from routers.types import Identifier, TagString
from schemas.datasets.openml import DatasetStatus, Feature
if TYPE_CHECKING:
- from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get(id_: Identifier, connection: AsyncConnection) -> Row | None:
+async def get(id_: Identifier, connection: AsyncConnection) -> UntypedRow | None:
row = await connection.execute(
text(
"""
@@ -35,7 +35,7 @@ async def get(id_: Identifier, connection: AsyncConnection) -> Row | None:
return row.one_or_none()
-async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> Row | None:
+async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> UntypedRow | None:
row = await connection.execute(
text(
"""
@@ -53,7 +53,7 @@ async def get_tag(
dataset_id: Identifier,
tag: TagString,
connection: AsyncConnection,
-) -> Row | None:
+) -> UntypedRow | None:
return (
await connection.execute(
text(
@@ -111,6 +111,8 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
},
)
except IntegrityError as e:
+ if e.orig is None:
+ raise
code, msg = e.orig.args
if code == _FOREIGN_KEY_CONSTRAINT_FAILED:
raise ForeignKeyConstraintError(msg) from e
@@ -122,7 +124,7 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
async def get_description(
id_: Identifier,
connection: AsyncConnection,
-) -> Row | None:
+) -> UntypedRow | None:
"""Get the most recent description for the dataset."""
row = await connection.execute(
text(
@@ -160,7 +162,7 @@ async def get_status(id_: Identifier, connection: AsyncConnection) -> DatasetSta
async def get_latest_processing_update(
dataset_id: Identifier,
connection: AsyncConnection,
-) -> Row | None:
+) -> UntypedRow | None:
row = await connection.execute(
text(
"""
diff --git a/src/database/evaluations.py b/src/database/evaluations.py
index 382653f..14bec4d 100644
--- a/src/database/evaluations.py
+++ b/src/database/evaluations.py
@@ -1,16 +1,20 @@
from collections.abc import Sequence
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING
-from sqlalchemy import Row, text
+from sqlalchemy import text
from core.formatting import _str_to_bool
+from database.schema.base import UntypedRow
from schemas.datasets.openml import EstimationProcedure
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]:
+async def get_math_functions(
+ function_type: str,
+ connection: AsyncConnection,
+) -> Sequence[UntypedRow]:
rows = await connection.execute(
text(
"""
@@ -21,10 +25,7 @@ async def get_math_functions(function_type: str, connection: AsyncConnection) ->
),
parameters={"function_type": function_type},
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]:
diff --git a/src/database/flows.py b/src/database/flows.py
index 7b1da5e..830a386 100644
--- a/src/database/flows.py
+++ b/src/database/flows.py
@@ -1,15 +1,16 @@
from collections.abc import Sequence
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING
-from sqlalchemy import Row, text
+from sqlalchemy import text
+from database.schema.base import UntypedRow
from routers.types import Identifier
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
@@ -20,10 +21,7 @@ async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence
),
parameters={"flow_id": for_flow},
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]:
@@ -41,7 +39,7 @@ async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]:
return [tag.tag for tag in tag_rows]
-async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
@@ -52,13 +50,14 @@ async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequenc
),
parameters={"flow_id": flow_id},
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
-async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None:
+async def get_by_name(
+ name: str,
+ external_version: str,
+ expdb: AsyncConnection,
+) -> UntypedRow | None:
"""Get flow by name and external version."""
row = await expdb.execute(
text(
@@ -73,7 +72,7 @@ async def get_by_name(name: str, external_version: str, expdb: AsyncConnection)
return row.one_or_none()
-async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None:
+async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None:
row = await expdb.execute(
text(
"""
diff --git a/src/database/runs.py b/src/database/runs.py
index f0bc419..0c5c93c 100644
--- a/src/database/runs.py
+++ b/src/database/runs.py
@@ -3,8 +3,9 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
-from sqlalchemy import Row, bindparam, text
+from sqlalchemy import bindparam, text
+from database.schema.base import UntypedRow
from routers.types import Identifier
if TYPE_CHECKING:
@@ -26,7 +27,7 @@ async def exist(id_: Identifier, expdb: AsyncConnection) -> bool:
return bool(row.one_or_none())
-async def get(run_id: Identifier, expdb: AsyncConnection) -> Row | None:
+async def get(run_id: Identifier, expdb: AsyncConnection) -> UntypedRow | None:
"""Fetch the core run row from the `run` table.
Returns the row if found, or None if no run with `run_id` exists.
@@ -63,7 +64,7 @@ async def get_tags(run_id: int, expdb: AsyncConnection) -> list[str]:
return [row.tag for row in rows.all()]
-async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]:
+async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]:
"""Fetch the dataset(s) used as input for a run, with name and url.
Joins `input_data` with `dataset` to include the dataset name and ARFF URL.
@@ -79,10 +80,10 @@ async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]:
),
parameters={"run_id": run_id},
)
- return cast("list[Row]", rows.all())
+ return cast("list[UntypedRow]", rows.all())
-async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]:
+async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]:
"""Fetch output files attached to a run from the `runfile` table.
Typical entries include the description XML and predictions ARFF.
@@ -98,7 +99,7 @@ async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]:
),
parameters={"run_id": run_id},
)
- return cast("list[Row]", rows.all())
+ return cast("list[UntypedRow]", rows.all())
async def get_evaluations(
@@ -106,7 +107,7 @@ async def get_evaluations(
expdb: AsyncConnection,
*,
evaluation_engine_ids: list[int],
-) -> list[Row]:
+) -> list[UntypedRow]:
"""Fetch evaluation metric results for a run.
Joins `evaluation` with `math_function` to resolve the metric name
@@ -138,10 +139,10 @@ async def get_evaluations(
query,
parameters={"run_id": run_id, "engine_ids": evaluation_engine_ids},
)
- return cast("list[Row]", rows.all())
+ return cast("list[UntypedRow]", rows.all())
-async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[UntypedRow]:
"""Get trace rows for a run from the trace table."""
rows = await expdb.execute(
text(
@@ -153,7 +154,4 @@ async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
),
parameters={"run_id": run_id},
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
diff --git a/src/database/schema/base.py b/src/database/schema/base.py
index 498dfb5..0a75055 100644
--- a/src/database/schema/base.py
+++ b/src/database/schema/base.py
@@ -12,11 +12,16 @@ class ClassName(ExpDBReflected, Base):
"""
+from typing import Any
+
+from sqlalchemy import Row
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm import DeclarativeBase
from database.setup import expdb_database, user_database
+UntypedRow = Row[Any]
+
class Base(DeclarativeBase):
"""Base class for all ORM classes."""
diff --git a/src/database/setups.py b/src/database/setups.py
index 1c959b3..1cf6c22 100644
--- a/src/database/setups.py
+++ b/src/database/setups.py
@@ -4,14 +4,15 @@
from sqlalchemy import text
+from database.schema.base import UntypedRow
from routers.types import Identifier, TagString
if TYPE_CHECKING:
- from sqlalchemy.engine import Row, RowMapping
+ from sqlalchemy.engine import RowMapping
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get(setup_id: Identifier, connection: AsyncConnection) -> Row | None:
+async def get(setup_id: Identifier, connection: AsyncConnection) -> UntypedRow | None:
"""Get the setup with id `setup_id` from the database."""
row = await connection.execute(
text(
@@ -53,7 +54,7 @@ async def get_parameters(setup_id: Identifier, connection: AsyncConnection) -> l
return list(rows.mappings().all())
-async def get_tags(setup_id: Identifier, connection: AsyncConnection) -> list[Row]:
+async def get_tags(setup_id: Identifier, connection: AsyncConnection) -> list[UntypedRow]:
"""Get all tags for setup with `setup_id` from the database."""
rows = await connection.execute(
text(
diff --git a/src/database/studies.py b/src/database/studies.py
index 1939452..b4a6936 100644
--- a/src/database/studies.py
+++ b/src/database/studies.py
@@ -3,8 +3,9 @@
from datetime import UTC, datetime
from typing import TYPE_CHECKING, cast
-from sqlalchemy import Row, text
+from sqlalchemy import text
+from database.schema.base import UntypedRow
from database.users import User
from routers.types import Identifier
from schemas.study import CreateStudy, StudyType
@@ -13,7 +14,7 @@
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get_by_id(id_: Identifier, connection: AsyncConnection) -> Row | None:
+async def get_by_id(id_: Identifier, connection: AsyncConnection) -> UntypedRow | None:
row = await connection.execute(
text(
"""
@@ -27,7 +28,7 @@ async def get_by_id(id_: Identifier, connection: AsyncConnection) -> Row | None:
return row.one_or_none()
-async def get_by_alias(alias: str, connection: AsyncConnection) -> Row | None:
+async def get_by_alias(alias: str, connection: AsyncConnection) -> UntypedRow | None:
row = await connection.execute(
text(
"""
@@ -41,7 +42,7 @@ async def get_by_alias(alias: str, connection: AsyncConnection) -> Row | None:
return row.one_or_none()
-async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_study_data(study: UntypedRow, expdb: AsyncConnection) -> Sequence[UntypedRow]:
"""Return data related to the study, content depends on the study type.
For task studies: (task id, dataset id)
@@ -58,10 +59,8 @@ async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]:
),
parameters={"study_id": study.id},
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
+
rows = await expdb.execute(
text(
"""
@@ -80,10 +79,7 @@ async def get_study_data(study: Row, expdb: AsyncConnection) -> Sequence[Row]:
),
parameters={"study_id": study.id},
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
async def create(study: CreateStudy, user: User, expdb: AsyncConnection) -> int:
diff --git a/src/database/tasks.py b/src/database/tasks.py
index b0f6010..0910de4 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -1,7 +1,7 @@
from collections.abc import Sequence
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING
-from sqlalchemy import Row, text
+from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from database.exceptions import (
@@ -10,13 +10,14 @@
DuplicatePrimaryKeyError,
ForeignKeyConstraintError,
)
+from database.schema.base import UntypedRow
from routers.types import Identifier, TagString
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
-async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None:
+async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None:
row = await expdb.execute(
text(
"""
@@ -30,7 +31,7 @@ async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None:
return row.one_or_none()
-async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]:
+async def get_task_types(expdb: AsyncConnection) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
@@ -39,13 +40,10 @@ async def get_task_types(expdb: AsyncConnection) -> Sequence[Row]:
""",
),
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
-async def get_task_type(task_type_id: Identifier, expdb: AsyncConnection) -> Row | None:
+async def get_task_type(task_type_id: Identifier, expdb: AsyncConnection) -> UntypedRow | None:
row = await expdb.execute(
text(
"""
@@ -102,7 +100,10 @@ async def get_task_evaluation_measure(task_id: int, expdb: AsyncConnection) -> s
return result.value if result else None
-async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_input_for_task_type(
+ task_type_id: int,
+ expdb: AsyncConnection,
+) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
@@ -113,13 +114,10 @@ async def get_input_for_task_type(task_type_id: int, expdb: AsyncConnection) ->
),
parameters={"ttid": task_type_id},
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
-async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
+async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
@@ -130,16 +128,13 @@ async def get_input_for_task(id_: Identifier, expdb: AsyncConnection) -> Sequenc
),
parameters={"task_id": id_},
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
async def get_task_type_inout_with_template(
task_type: Identifier,
expdb: AsyncConnection,
-) -> Sequence[Row]:
+) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
@@ -150,10 +145,7 @@ async def get_task_type_inout_with_template(
),
parameters={"ttid": task_type},
)
- return cast(
- "Sequence[Row]",
- rows.all(),
- )
+ return rows.all()
async def get_tags(id_: Identifier, connection: AsyncConnection) -> list[str]:
@@ -193,6 +185,8 @@ async def tag(
},
)
except IntegrityError as e:
+ if e.orig is None:
+ raise
code, msg = e.orig.args
if code == _FOREIGN_KEY_CONSTRAINT_FAILED:
raise ForeignKeyConstraintError(msg) from e
diff --git a/src/database/users.py b/src/database/users.py
index 26a5f3d..379c6f8 100644
--- a/src/database/users.py
+++ b/src/database/users.py
@@ -136,14 +136,17 @@ async def fetch(cls, api_key: APIKey, user_db: AsyncConnection) -> Self | None:
return None
async def get_groups(self) -> list[UserGroup]:
+ if self._groups:
+ return self._groups
+
if self._database is None:
msg = "`get_groups` can only be used when `connection` is provided on instantiation."
raise RuntimeError(msg)
- if self._groups is None:
- self._groups = await get_user_groups_for(
- user_id=self.user_id,
- connection=self._database,
- )
+
+ self._groups = await get_user_groups_for(
+ user_id=self.user_id,
+ connection=self._database,
+ )
return self._groups
async def is_admin(self) -> bool:
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index 270c71b..72c94de 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -156,6 +156,7 @@ def _quality_clause(quality: str, range_: str | None) -> str:
@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
@router.get(path="/list")
async def list_datasets( # noqa: PLR0913, C901
+ expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
data_name: Annotated[CasualString128 | None, Body()] = None,
tag: Annotated[TagString | None, Body()] = None,
@@ -180,9 +181,7 @@ async def list_datasets( # noqa: PLR0913, C901
number_missing_values: Annotated[IntegerRange | None, Body()] = None,
status: Annotated[DatasetStatusFilter, Body()] = DatasetStatusFilter.ACTIVE,
user: Annotated[User | None, Depends(fetch_user)] = None,
- expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> list[dict[str, Any]]:
- assert expdb_db is not None # noqa: S101
status_subquery = text(
"""
SELECT ds1.`did`, ds1.`status`
@@ -356,8 +355,8 @@ async def _get_dataset_raise_otherwise(
@router.get("/features/{dataset_id}", response_model_exclude_none=True)
async def get_dataset_features(
dataset_id: Identifier,
+ expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
user: Annotated[User | None, Depends(fetch_user)] = None,
- expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> list[Feature]:
assert expdb is not None # noqa: S101
await _get_dataset_raise_otherwise(dataset_id, user, expdb)
@@ -453,9 +452,9 @@ async def update_dataset_status(
)
async def get_dataset(
dataset_id: Identifier,
+ user_db: Annotated[AsyncConnection, Depends(userdb_connection)],
+ expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
user: Annotated[User | None, Depends(fetch_user)] = None,
- user_db: Annotated[AsyncConnection, Depends(userdb_connection)] = None,
- expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> DatasetMetadata:
assert user_db is not None # noqa: S101
assert expdb_db is not None # noqa: S101
diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py
index 0abf40e..88296cf 100644
--- a/src/routers/openml/runs.py
+++ b/src/routers/openml/runs.py
@@ -6,9 +6,6 @@
from fastapi import APIRouter, Depends
-if TYPE_CHECKING:
- from sqlalchemy import Row
-
import config
import database.flows
import database.runs
@@ -16,6 +13,7 @@
import database.tasks
import database.users
from core.errors import RunNotFoundError, RunTraceNotFoundError
+from database.schema.base import UntypedRow
from routers.dependencies import expdb_connection, userdb_connection
from routers.types import Identifier
from schemas.runs import (
@@ -72,17 +70,17 @@ class RunContext:
uploader_name: str | None
tags: list[str]
- input_data_rows: list[Row]
- output_file_rows: list[Row]
- evaluation_rows: list[Row]
+ input_data_rows: list[UntypedRow]
+ output_file_rows: list[UntypedRow]
+ evaluation_rows: list[UntypedRow]
task_type: str | None
task_evaluation_measure: str | None
- setup: Row | None
- parameter_rows: list[Row]
+ setup: UntypedRow | None
+ parameter_rows: list[UntypedRow]
async def _load_run_context(
- run: Row,
+ run: UntypedRow,
run_id: int,
expdb: AsyncConnection,
userdb: AsyncConnection,
@@ -99,8 +97,7 @@ async def _load_run_context(
setup,
parameter_rows,
) = cast(
- "tuple[Any, list[str], list[Row], list[Row], list[Row], str | None, str |"
- "None, Row | None, list[Row]]",
+ "tuple[Any, list[str], list[UntypedRow], list[UntypedRow], list[UntypedRow], str | None, str | None, UntypedRow | None, list[UntypedRow]]", # noqa: E501
await asyncio.gather(
database.users.get_user(user_id=run.uploader, connection=userdb),
database.runs.get_tags(run_id, expdb),
@@ -126,7 +123,7 @@ async def _load_run_context(
)
-def _build_evaluations(rows: list[Row]) -> list[EvaluationScore]:
+def _build_evaluations(rows: list[UntypedRow]) -> list[EvaluationScore]:
def _normalise_value(v: object) -> object:
if isinstance(v, (int, float)):
return int(v) if float(v).is_integer() else float(v)
@@ -181,7 +178,7 @@ async def get_run(
setup_id=run.setup,
setup_string=ctx.setup.setup_string if ctx.setup else None,
parameter_setting=[
- ParameterSetting(name=p["name"], value=p["value"], component=p["flow_id"])
+ ParameterSetting(name=p.name, value=p.value, component=p.flow_id)
for p in ctx.parameter_rows
],
error_message=error_messages,
diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py
index ccdf9ff..dbb6581 100644
--- a/src/routers/openml/study.py
+++ b/src/routers/openml/study.py
@@ -16,6 +16,7 @@
StudyPrivateError,
)
from core.formatting import _str_to_bool
+from database.schema.base import UntypedRow
from database.users import User
from routers.dependencies import expdb_connection, fetch_user, fetch_user_or_raise
from routers.types import Identifier
@@ -23,7 +24,6 @@
from schemas.study import CreateStudy, Study, StudyStatus, StudyType
if TYPE_CHECKING:
- from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection
router = APIRouter(prefix="/studies", tags=["studies"])
@@ -33,7 +33,7 @@ async def _get_study_raise_otherwise(
id_or_alias: Identifier | str,
user: User | None,
expdb: AsyncConnection,
-) -> Row:
+) -> UntypedRow:
search_by_id = isinstance(id_or_alias, int) or id_or_alias.isdigit()
if search_by_id:
study = await database.studies.get_by_id(int(id_or_alias), expdb)
@@ -67,7 +67,7 @@ async def attach_to_study(
study_id: Annotated[Identifier, Body()],
entity_ids: Annotated[list[Identifier], Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
- expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
+ expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> AttachDetachResponse:
assert expdb is not None # noqa: S101
if user is None:
@@ -90,16 +90,21 @@ async def attach_to_study(
# We let the database handle the constraints on whether
# the entity is already attached or if it even exists.
- attach_kwargs = {
- "study_id": study_id,
- "user": user,
- "connection": expdb,
- }
try:
if study.type_ == StudyType.TASK:
- await database.studies.attach_tasks(task_ids=entity_ids, **attach_kwargs)
+ await database.studies.attach_tasks(
+ task_ids=entity_ids,
+ study_id=study_id,
+ user=user,
+ connection=expdb,
+ )
else:
- await database.studies.attach_runs(run_ids=entity_ids, **attach_kwargs)
+ await database.studies.attach_runs(
+ run_ids=entity_ids,
+ study_id=study_id,
+ user=user,
+ connection=expdb,
+ )
except ValueError as e:
msg = str(e)
raise StudyConflictError(msg) from e
@@ -116,7 +121,7 @@ async def attach_to_study(
async def create_study(
study: CreateStudy,
user: Annotated[User, Depends(fetch_user_or_raise)],
- expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
+ expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["study_id"], int]:
assert expdb is not None # noqa: S101
if study.main_entity_type == StudyType.RUN and study.tasks:
@@ -152,8 +157,8 @@ async def create_study(
@router.get("/{alias_or_id}")
async def get_study(
alias_or_id: Identifier | str,
+ expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
user: Annotated[User | None, Depends(fetch_user)] = None,
- expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> Study:
assert expdb is not None # noqa: S101
study = await _get_study_raise_otherwise(alias_or_id, user, expdb)
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index f68c4b8..9a2b0df 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -26,7 +26,8 @@
from schemas.datasets.openml import Task
if TYPE_CHECKING:
- from sqlalchemy.engine import RowMapping
+ from collections.abc import Mapping
+
from sqlalchemy.ext.asyncio import AsyncConnection
router = APIRouter(prefix="/tasks", tags=["tasks"])
@@ -70,7 +71,7 @@ def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
async def fill_template(
template: str,
- task: RowMapping,
+ task: Mapping[str, Any],
task_inputs: dict[str, str | int],
connection: AsyncConnection,
) -> dict[str, JSON]:
@@ -137,7 +138,7 @@ async def fill_template(
async def _fill_json_template( # noqa: C901
template: JSON,
- task: RowMapping,
+ task: Mapping[str, Any],
task_inputs: dict[str, str | int],
fetched_data: dict[str, str],
connection: AsyncConnection,
@@ -192,7 +193,7 @@ async def _fill_json_template( # noqa: C901
template = template.replace(match.group(), fetched_data[field])
# I believe that the operations below are always part of string output, so
# we don't need to be careful to avoid losing typedness
- template = template.replace("[TASK:id]", str(task.task_id))
+ template = template.replace("[TASK:id]", str(task["task_id"]))
url = get_config().routing.server_url
server_url = f"{url.scheme}://{url.host}:{url.port}/"
return template.replace("[CONSTANT:base_url]", server_url)
@@ -257,6 +258,7 @@ def _quality_clause(quality: str, range_: str | None) -> str:
@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
@router.get(path="/list")
async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915
+ expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
task_type_id: Annotated[Identifier | None, Body(description="Filter by task type id.")] = None,
tag: Annotated[TagString | None, Body()] = None,
@@ -275,7 +277,6 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915
number_features: Annotated[IntegerRange | None, Body()] = None,
number_classes: Annotated[IntegerRange | None, Body()] = None,
number_missing_values: Annotated[IntegerRange | None, Body()] = None,
- expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> list[dict[str, Any]]:
"""List tasks, optionally filtered by type, tag, status, dataset properties, and more."""
assert expdb is not None # noqa: S101
@@ -472,7 +473,10 @@ async def get_task(
(name, template) for name, io, required, template in templates if io == "input"
]
filled_templates = await asyncio.gather(
- *[fill_template(template, task, task_inputs, expdb) for name, template in input_templates],
+ *[
+ fill_template(template, dict(task), task_inputs, expdb)
+ for name, template in input_templates
+ ],
)
inputs = [
filled | {"name": name}
diff --git a/tests/routers/openml/runs_get_test.py b/tests/routers/openml/runs_get_test.py
index d566403..af841f6 100644
--- a/tests/routers/openml/runs_get_test.py
+++ b/tests/routers/openml/runs_get_test.py
@@ -299,7 +299,7 @@ def __init__(
self.fold = fold
rows = [MockRow("test_metric", input_value, repeat=repeat, fold=fold)]
- evals = _build_evaluations(rows)
+ evals = _build_evaluations(rows) # type: ignore[arg-type]
assert len(evals) == 1
assert evals[0].value == expected_value
From a116199927ec0705df7686f0f7b7d432da0ddc09 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 2 Jul 2026 08:32:53 +0200
Subject: [PATCH 5/8] Use ORM for TaskTag
---
src/database/tasks.py | 40 ++++++++-------------------
src/routers/dependencies.py | 8 ++++++
src/routers/openml/tasks.py | 27 ++++++++----------
tests/conftest.py | 22 +++++++++++++--
tests/routers/openml/task_tag_test.py | 39 ++++++++++----------------
5 files changed, 64 insertions(+), 72 deletions(-)
diff --git a/src/database/tasks.py b/src/database/tasks.py
index 0910de4..15d6589 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -1,7 +1,7 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from sqlalchemy import text
+from sqlalchemy import select, text
from sqlalchemy.exc import IntegrityError
from database.exceptions import (
@@ -11,10 +11,11 @@
ForeignKeyConstraintError,
)
from database.schema.base import UntypedRow
+from database.schema.tags import TaskTag
from routers.types import Identifier, TagString
if TYPE_CHECKING:
- from sqlalchemy.ext.asyncio import AsyncConnection
+ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None:
@@ -148,19 +149,10 @@ async def get_task_type_inout_with_template(
return rows.all()
-async def get_tags(id_: Identifier, connection: AsyncConnection) -> list[str]:
- rows = await connection.execute(
- text(
- """
- SELECT `tag`
- FROM task_tag
- WHERE `id` = :task_id
- """,
- ),
- parameters={"task_id": id_},
- )
- tag_rows = rows.all()
- return [row.tag for row in tag_rows]
+async def get_tags(task_id: Identifier, session: AsyncSession) -> list[TagString]:
+ stmt = select(TaskTag).where(TaskTag.entity_id == task_id)
+ tags = (await session.scalars(stmt)).all()
+ return [t.tag for t in tags]
async def tag(
@@ -168,22 +160,12 @@ async def tag(
tag_: TagString,
*,
user_id: Identifier,
- connection: AsyncConnection,
+ session: AsyncSession,
) -> None:
try:
- await connection.execute(
- text(
- """
- INSERT INTO task_tag(`id`, `tag`, `uploader`)
- VALUES (:task_id, :tag, :user_id)
- """,
- ),
- parameters={
- "task_id": id_,
- "user_id": user_id,
- "tag": tag_,
- },
- )
+ tag = TaskTag(entity_id=id_, uploader_id=user_id, tag=tag_)
+ session.add(tag)
+ await session.flush()
except IntegrityError as e:
if e.orig is None:
raise
diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py
index ca4a965..832c958 100644
--- a/src/routers/dependencies.py
+++ b/src/routers/dependencies.py
@@ -1,9 +1,11 @@
+import contextlib
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Annotated
from fastapi import Depends
from loguru import logger
from pydantic import BaseModel, Field
+from sqlalchemy.ext.asyncio import AsyncSession
from core.errors import AuthenticationFailedError, AuthenticationRequiredError
from database.setup import expdb_database, user_database
@@ -25,6 +27,12 @@ async def userdb_connection() -> AsyncIterator[AsyncConnection]:
yield connection
+async def expdb_session() -> AsyncIterator[AsyncSession]:
+ conn = contextlib.asynccontextmanager(expdb_connection)
+ async with conn() as connection, AsyncSession(connection) as session, session.begin():
+ yield session
+
+
async def fetch_user(
api_key: APIKey | None = None,
user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None,
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index 9a2b0df..68aa47f 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -14,8 +14,9 @@
from config import get_config
from core.errors import InternalError, NoResultsError, TagAlreadyExistsError, TaskNotFoundError
from database.exceptions import DuplicatePrimaryKeyError, ForeignKeyConstraintError
+from database.schema.base import UntypedRow
from database.users import User
-from routers.dependencies import Pagination, expdb_connection, fetch_user_or_raise
+from routers.dependencies import Pagination, expdb_connection, expdb_session, fetch_user_or_raise
from routers.types import (
CasualString128,
Identifier,
@@ -26,9 +27,7 @@
from schemas.datasets.openml import Task
if TYPE_CHECKING:
- from collections.abc import Mapping
-
- from sqlalchemy.ext.asyncio import AsyncConnection
+ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
router = APIRouter(prefix="/tasks", tags=["tasks"])
@@ -40,10 +39,10 @@ async def tag_task(
task_id: Annotated[Identifier, Body()],
tag: Annotated[TagString, Body()],
user: Annotated[User, Depends(fetch_user_or_raise)],
- expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
+ expdb_session: Annotated[AsyncSession, Depends(expdb_session)],
) -> dict[str, dict[str, Any]]:
try:
- await database.tasks.tag(task_id, tag, user_id=user.user_id, connection=expdb_db)
+ await database.tasks.tag(task_id, tag, user_id=user.user_id, session=expdb_session)
except ForeignKeyConstraintError:
msg = f"Task {task_id} not found."
raise TaskNotFoundError(msg, code=472) from None
@@ -53,7 +52,7 @@ async def tag_task(
logger.info("Task {task_id} tagged '{tag}'.", task_id=task_id, tag=tag)
- tags = await database.tasks.get_tags(task_id, expdb_db)
+ tags = await database.tasks.get_tags(task_id, expdb_session)
return {
"task_tag": {"id": str(task_id), "tag": tags},
@@ -71,7 +70,7 @@ def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
async def fill_template(
template: str,
- task: Mapping[str, Any],
+ task: UntypedRow,
task_inputs: dict[str, str | int],
connection: AsyncConnection,
) -> dict[str, JSON]:
@@ -138,7 +137,7 @@ async def fill_template(
async def _fill_json_template( # noqa: C901
template: JSON,
- task: Mapping[str, Any],
+ task: UntypedRow,
task_inputs: dict[str, str | int],
fetched_data: dict[str, str],
connection: AsyncConnection,
@@ -193,7 +192,7 @@ async def _fill_json_template( # noqa: C901
template = template.replace(match.group(), fetched_data[field])
# I believe that the operations below are always part of string output, so
# we don't need to be careful to avoid losing typedness
- template = template.replace("[TASK:id]", str(task["task_id"]))
+ template = template.replace("[TASK:id]", str(task.task_id))
url = get_config().routing.server_url
server_url = f"{url.scheme}://{url.host}:{url.port}/"
return template.replace("[CONSTANT:base_url]", server_url)
@@ -452,6 +451,7 @@ async def list_tasks( # noqa: PLR0913, PLR0912, C901, PLR0915
async def get_task(
task_id: int,
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
+ expdb_session: Annotated[AsyncSession, Depends(expdb_session)],
) -> Task:
if not (task := await database.tasks.get(task_id, expdb)):
msg = f"Task {task_id} not found."
@@ -463,7 +463,7 @@ async def get_task(
task_input_rows, ttios, tags = await asyncio.gather(
database.tasks.get_input_for_task(task_id, expdb),
database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb),
- database.tasks.get_tags(task_id, expdb),
+ database.tasks.get_tags(task_id, expdb_session),
)
task_inputs = {
row.input: int(row.value) if row.value.isdigit() else row.value for row in task_input_rows
@@ -473,10 +473,7 @@ async def get_task(
(name, template) for name, io, required, template in templates if io == "input"
]
filled_templates = await asyncio.gather(
- *[
- fill_template(template, dict(task), task_inputs, expdb)
- for name, template in input_templates
- ],
+ *[fill_template(template, task, task_inputs, expdb) for name, template in input_templates],
)
inputs = [
filled | {"name": name}
diff --git a/tests/conftest.py b/tests/conftest.py
index 4191c47..7ba01fd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,10 +9,11 @@
import httpx
import pytest
from _pytest.config import Config # noqa: TC002 used during collection by Pytest
-from _pytest.nodes import Item # noqa: TC002 used during collection by Pytest
from asgi_lifespan import LifespanManager
from sqlalchemy import text
+from sqlalchemy.ext.asyncio import AsyncSession
+import routers.dependencies
from config import (
Configuration,
DatabaseConfiguration,
@@ -27,6 +28,7 @@
from tests.users import OWNER_USER
if TYPE_CHECKING:
+ from _pytest.nodes import Item
from fastapi import FastAPI
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
@@ -71,6 +73,16 @@ async def expdb_test() -> AsyncIterator[AsyncConnection]:
yield connection
+@pytest.fixture
+async def expdb_session(expdb_test: AsyncConnection) -> AsyncIterator[AsyncSession]:
+ # It is possible that this session `commits`, does the connection
+ # rollback then still take effect? Probably not.
+ async with AsyncSession(expdb_test) as session:
+ yield session
+ # Do we here again need to do some check on whether there is an active transation?
+ await session.rollback()
+
+
@pytest.fixture
async def user_test() -> AsyncIterator[AsyncConnection]:
async with automatic_rollback(user_database()) as connection:
@@ -85,7 +97,7 @@ async def php_api() -> AsyncIterator[httpx.AsyncClient]:
yield client
-@pytest.fixture(scope="session")
+@pytest.fixture(scope="session", autouse=True)
async def app() -> AsyncIterator[FastAPI]:
config = Configuration(
openml_database=DatabaseConfiguration(database="openml"),
@@ -103,7 +115,9 @@ async def app() -> AsyncIterator[FastAPI]:
@pytest.fixture
async def py_api(
- expdb_test: AsyncConnection, user_test: AsyncConnection, app: FastAPI
+ expdb_test: AsyncConnection,
+ user_test: AsyncConnection,
+ app: FastAPI,
) -> AsyncIterator[httpx.AsyncClient]:
"""Create test client which automatically rolls back database updates on teardown."""
# Using the function-scoped database fixtures automatically benefits the
@@ -123,6 +137,8 @@ async def override_userdb() -> AsyncIterator[AsyncConnection]:
app.dependency_overrides[expdb_connection] = override_expdb
app.dependency_overrides[userdb_connection] = override_userdb
+ routers.dependencies.expdb_connection = override_expdb
+
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://test",
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
index cd58964..d83619d 100644
--- a/tests/routers/openml/task_tag_test.py
+++ b/tests/routers/openml/task_tag_test.py
@@ -2,11 +2,8 @@
from typing import TYPE_CHECKING
import pytest
-from sqlalchemy import select
-from sqlalchemy.ext.asyncio import AsyncSession
from core.errors import TASK_NOT_FOUND_DURING_TAG, TagAlreadyExistsError, TaskNotFoundError
-from database.schema.tags import TaskTag
from database.tasks import get_tags
from database.users import User
from routers.openml.tasks import tag_task
@@ -17,7 +14,7 @@
if TYPE_CHECKING:
import httpx
- from sqlalchemy.ext.asyncio import AsyncConnection
+ from sqlalchemy.ext.asyncio import AsyncSession
@pytest.mark.parametrize(
@@ -44,59 +41,51 @@ async def test_task_tag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncCli
[ADMIN_USER, SOME_USER, OWNER_USER],
ids=["administrator", "non-owner", "owner"],
)
-async def test_task_tag(user: User, expdb_test: AsyncConnection, task_factory: TaskFactory) -> None:
+async def test_task_tag(user: User, expdb_session: AsyncSession, task_factory: TaskFactory) -> None:
tag = "test_task_tag"
task = await task_factory()
- result = await tag_task(task_id=task.id, tag=tag, user=user, expdb_db=expdb_test)
+ result = await tag_task(task_id=task.id, tag=tag, user=user, expdb_session=expdb_session)
assert result == {"task_tag": {"id": str(task.id), "tag": [tag]}}
- tags = await get_tags(id_=task.id, connection=expdb_test)
+ tags = await get_tags(task_id=task.id, session=expdb_session)
assert tag in tags
@pytest.mark.mut
async def test_task_tag_returns_existing_tags(
- task_factory: TaskFactory, expdb_test: AsyncConnection
+ task_factory: TaskFactory, expdb_session: AsyncSession
) -> None:
task = await task_factory()
- await tag_task(task_id=task.id, tag="first", user=ADMIN_USER, expdb_db=expdb_test)
- result = await tag_task(task_id=task.id, tag="second", user=ADMIN_USER, expdb_db=expdb_test)
+ await tag_task(task_id=task.id, tag="first", user=ADMIN_USER, expdb_session=expdb_session)
+ result = await tag_task(
+ task_id=task.id, tag="second", user=ADMIN_USER, expdb_session=expdb_session
+ )
assert result == {"task_tag": {"id": str(task.id), "tag": ["first", "second"]}}
@pytest.mark.mut
async def test_task_tag_fails_if_tag_exists(
- expdb_test: AsyncConnection, task_factory: TaskFactory
+ expdb_session: AsyncSession, task_factory: TaskFactory
) -> None:
tag = "fails_if_exist"
task = await task_factory()
- await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test)
+ await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_session=expdb_session)
with pytest.raises(TagAlreadyExistsError) as e:
- await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_db=expdb_test)
+ await tag_task(task_id=task.id, tag=tag, user=ADMIN_USER, expdb_session=expdb_session)
assert str(task.id) in e.value.detail
assert tag in e.value.detail
-async def test_task_tag_fails_if_task_does_not_exist(expdb_test: AsyncConnection) -> None:
+async def test_task_tag_fails_if_task_does_not_exist(expdb_session: AsyncSession) -> None:
task_id = 1_000_000
with pytest.raises(TaskNotFoundError) as e:
- await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_db=expdb_test)
+ await tag_task(task_id=task_id, tag="foo", user=ADMIN_USER, expdb_session=expdb_session)
assert str(task_id) in e.value.detail
task_not_found_in_tag_endpoint = TASK_NOT_FOUND_DURING_TAG
assert e.value.code == task_not_found_in_tag_endpoint
-async def test_if_reflection_works(py_api: httpx.Client, expdb_test: AsyncConnection) -> None:
- _ = py_api
-
- tag_stmt = select(TaskTag).where(TaskTag.tag == "OpenML100")
- async with AsyncSession(bind=expdb_test) as session:
- tags = (await session.scalars(tag_stmt)).all()
- tag_count = 100
- assert len(tags) == tag_count
-
-
# -- migration tests --
From 993cbd7a2c09515fc92ccfc23c93e875d0d2c8de Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 2 Jul 2026 08:39:23 +0200
Subject: [PATCH 6/8] Return ORM object
---
src/database/tasks.py | 5 ++---
src/routers/openml/tasks.py | 4 ++--
tests/conftest.py | 2 +-
tests/routers/openml/task_tag_test.py | 2 +-
4 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/src/database/tasks.py b/src/database/tasks.py
index 15d6589..9d1d120 100644
--- a/src/database/tasks.py
+++ b/src/database/tasks.py
@@ -149,10 +149,9 @@ async def get_task_type_inout_with_template(
return rows.all()
-async def get_tags(task_id: Identifier, session: AsyncSession) -> list[TagString]:
+async def get_tags(task_id: Identifier, session: AsyncSession) -> Sequence[TaskTag]:
stmt = select(TaskTag).where(TaskTag.entity_id == task_id)
- tags = (await session.scalars(stmt)).all()
- return [t.tag for t in tags]
+ return (await session.scalars(stmt)).all()
async def tag(
diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py
index 68aa47f..55df619 100644
--- a/src/routers/openml/tasks.py
+++ b/src/routers/openml/tasks.py
@@ -55,7 +55,7 @@ async def tag_task(
tags = await database.tasks.get_tags(task_id, expdb_session)
return {
- "task_tag": {"id": str(task_id), "tag": tags},
+ "task_tag": {"id": str(task_id), "tag": [t.tag for t in tags]},
}
@@ -496,5 +496,5 @@ async def get_task(
task_type=task_type.name,
input_=inputs,
output=outputs,
- tags=tags,
+ tags=[t.tag for t in tags],
)
diff --git a/tests/conftest.py b/tests/conftest.py
index 7ba01fd..c5b6a4e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,6 +9,7 @@
import httpx
import pytest
from _pytest.config import Config # noqa: TC002 used during collection by Pytest
+from _pytest.nodes import Item # noqa: TC002 used during collection by Pytest
from asgi_lifespan import LifespanManager
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
@@ -28,7 +29,6 @@
from tests.users import OWNER_USER
if TYPE_CHECKING:
- from _pytest.nodes import Item
from fastapi import FastAPI
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py
index d83619d..f6e04b1 100644
--- a/tests/routers/openml/task_tag_test.py
+++ b/tests/routers/openml/task_tag_test.py
@@ -48,7 +48,7 @@ async def test_task_tag(user: User, expdb_session: AsyncSession, task_factory: T
assert result == {"task_tag": {"id": str(task.id), "tag": [tag]}}
tags = await get_tags(task_id=task.id, session=expdb_session)
- assert tag in tags
+ assert tag in [t.tag for t in tags]
@pytest.mark.mut
From 24bf0308830d99d4f685dbb7763e72d8ba2eb1dc Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 2 Jul 2026 06:41:58 +0000
Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---
src/database/setups.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/database/setups.py b/src/database/setups.py
index bd3cf33..1eddef6 100644
--- a/src/database/setups.py
+++ b/src/database/setups.py
@@ -5,13 +5,13 @@
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
-from database.schema.base import UntypedRow
from database.exceptions import (
_DUPLICATE_ENTRY,
_FOREIGN_KEY_CONSTRAINT_FAILED,
DuplicatePrimaryKeyError,
ForeignKeyConstraintError,
)
+from database.schema.base import UntypedRow
from routers.types import Identifier, TagString
if TYPE_CHECKING:
From 9541aab0121749eced9a0c42c4f97b3b6564a8db Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Thu, 2 Jul 2026 08:46:44 +0200
Subject: [PATCH 8/8] Use dependency injection instead of monkey patching
---
src/database/setups.py | 2 ++
src/routers/dependencies.py | 8 ++++----
tests/conftest.py | 3 ---
3 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/src/database/setups.py b/src/database/setups.py
index 1eddef6..c78986c 100644
--- a/src/database/setups.py
+++ b/src/database/setups.py
@@ -107,6 +107,8 @@ async def tag(
parameters={"setup_id": setup_id, "tag": tag, "user_id": user_id},
)
except IntegrityError as e:
+ if e.orig is None:
+ raise
code, msg = e.orig.args
if code == _FOREIGN_KEY_CONSTRAINT_FAILED:
raise ForeignKeyConstraintError(msg) from e
diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py
index 832c958..2ee8548 100644
--- a/src/routers/dependencies.py
+++ b/src/routers/dependencies.py
@@ -1,4 +1,3 @@
-import contextlib
from collections.abc import AsyncGenerator, AsyncIterator
from typing import TYPE_CHECKING, Annotated
@@ -27,9 +26,10 @@ async def userdb_connection() -> AsyncIterator[AsyncConnection]:
yield connection
-async def expdb_session() -> AsyncIterator[AsyncSession]:
- conn = contextlib.asynccontextmanager(expdb_connection)
- async with conn() as connection, AsyncSession(connection) as session, session.begin():
+async def expdb_session(
+ connection: Annotated[AsyncConnection, Depends(expdb_connection)],
+) -> AsyncIterator[AsyncSession]:
+ async with AsyncSession(connection) as session, session.begin():
yield session
diff --git a/tests/conftest.py b/tests/conftest.py
index c5b6a4e..68a8e1c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,7 +14,6 @@
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
-import routers.dependencies
from config import (
Configuration,
DatabaseConfiguration,
@@ -137,8 +136,6 @@ async def override_userdb() -> AsyncIterator[AsyncConnection]:
app.dependency_overrides[expdb_connection] = override_expdb
app.dependency_overrides[userdb_connection] = override_userdb
- routers.dependencies.expdb_connection = override_expdb
-
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://test",