Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ repos:
additional_dependencies:
- fastapi
- pytest
- sqlalchemy

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.15.18'
Expand Down
9 changes: 3 additions & 6 deletions src/core/formatting.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import html
from typing import TYPE_CHECKING

from config import get_config
from database.schema.base import UntypedRow
from schemas.datasets.openml import DatasetFileFormat

if TYPE_CHECKING:
from sqlalchemy.engine import Row


def _str_to_bool(string: str) -> bool:
if string.casefold() in ["true", "1", "yes", "y"]:
Expand All @@ -17,7 +14,7 @@ def _str_to_bool(string: str) -> bool:
raise ValueError(msg)


def _format_parquet_url(dataset: Row) -> str | None:
def _format_parquet_url(dataset: UntypedRow) -> str | None:
if dataset.format.lower() != DatasetFileFormat.ARFF:
return None

Expand All @@ -27,7 +24,7 @@ def _format_parquet_url(dataset: Row) -> str | None:
return f"{minio_base_url}datasets/{ten_thousands_prefix}/{padded_id}/dataset_{dataset.did}.pq"


def _format_dataset_url(dataset: Row) -> str:
def _format_dataset_url(dataset: UntypedRow) -> str:
base_url = get_config().routing.server_url
filename = f"{html.escape(dataset.name)}.{dataset.format.lower()}"
return f"{base_url}data/v1/download/{dataset.file_id}/{filename}"
Expand Down
14 changes: 8 additions & 6 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
DuplicatePrimaryKeyError,
ForeignKeyConstraintError,
)
from database.schema.base import UntypedRow
from routers.types import Identifier, TagString
from schemas.datasets.openml import DatasetStatus, Feature

if TYPE_CHECKING:
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection


async def get(id_: Identifier, connection: AsyncConnection) -> Row | None:
async def get(id_: Identifier, connection: AsyncConnection) -> UntypedRow | None:
row = await connection.execute(
text(
"""
Expand All @@ -35,7 +35,7 @@ async def get(id_: Identifier, connection: AsyncConnection) -> Row | None:
return row.one_or_none()


async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> Row | None:
async def get_file(*, file_id: Identifier, connection: AsyncConnection) -> UntypedRow | None:
row = await connection.execute(
text(
"""
Expand All @@ -53,7 +53,7 @@ async def get_tag(
dataset_id: Identifier,
tag: TagString,
connection: AsyncConnection,
) -> Row | None:
) -> UntypedRow | None:
return (
await connection.execute(
text(
Expand Down Expand Up @@ -111,6 +111,8 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
},
)
except IntegrityError as e:
if e.orig is None:
raise
code, msg = e.orig.args
if code == _FOREIGN_KEY_CONSTRAINT_FAILED:
raise ForeignKeyConstraintError(msg) from e
Expand All @@ -122,7 +124,7 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
async def get_description(
id_: Identifier,
connection: AsyncConnection,
) -> Row | None:
) -> UntypedRow | None:
"""Get the most recent description for the dataset."""
row = await connection.execute(
text(
Expand Down Expand Up @@ -160,7 +162,7 @@ async def get_status(id_: Identifier, connection: AsyncConnection) -> DatasetSta
async def get_latest_processing_update(
dataset_id: Identifier,
connection: AsyncConnection,
) -> Row | None:
) -> UntypedRow | None:
row = await connection.execute(
text(
"""
Expand Down
15 changes: 8 additions & 7 deletions src/database/evaluations.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

from sqlalchemy import Row, text
from sqlalchemy import text

from core.formatting import _str_to_bool
from database.schema.base import UntypedRow
from schemas.datasets.openml import EstimationProcedure

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]:
async def get_math_functions(
function_type: str,
connection: AsyncConnection,
) -> Sequence[UntypedRow]:
rows = await connection.execute(
text(
"""
Expand All @@ -21,10 +25,7 @@ async def get_math_functions(function_type: str, connection: AsyncConnection) ->
),
parameters={"function_type": function_type},
)
return cast(
"Sequence[Row]",
rows.all(),
)
return rows.all()


async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]:
Expand Down
27 changes: 13 additions & 14 deletions src/database/flows.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

from sqlalchemy import Row, text
from sqlalchemy import text

from database.schema.base import UntypedRow
from routers.types import Identifier

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection


async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
Expand All @@ -20,10 +21,7 @@ async def get_subflows(for_flow: Identifier, expdb: AsyncConnection) -> Sequence
),
parameters={"flow_id": for_flow},
)
return cast(
"Sequence[Row]",
rows.all(),
)
return rows.all()


async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]:
Expand All @@ -41,7 +39,7 @@ async def get_tags(flow_id: Identifier, expdb: AsyncConnection) -> list[str]:
return [tag.tag for tag in tag_rows]


async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[Row]:
async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequence[UntypedRow]:
rows = await expdb.execute(
text(
"""
Expand All @@ -52,13 +50,14 @@ async def get_parameters(flow_id: Identifier, expdb: AsyncConnection) -> Sequenc
),
parameters={"flow_id": flow_id},
)
return cast(
"Sequence[Row]",
rows.all(),
)
return rows.all()


async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None:
async def get_by_name(
name: str,
external_version: str,
expdb: AsyncConnection,
) -> UntypedRow | None:
"""Get flow by name and external version."""
row = await expdb.execute(
text(
Expand All @@ -73,7 +72,7 @@ async def get_by_name(name: str, external_version: str, expdb: AsyncConnection)
return row.one_or_none()


async def get(id_: Identifier, expdb: AsyncConnection) -> Row | None:
async def get(id_: Identifier, expdb: AsyncConnection) -> UntypedRow | None:
row = await expdb.execute(
text(
"""
Expand Down
24 changes: 11 additions & 13 deletions src/database/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, cast

from sqlalchemy import Row, bindparam, text
from sqlalchemy import bindparam, text

from database.schema.base import UntypedRow
from routers.types import Identifier

if TYPE_CHECKING:
Expand All @@ -26,7 +27,7 @@ async def exist(id_: Identifier, expdb: AsyncConnection) -> bool:
return bool(row.one_or_none())


async def get(run_id: Identifier, expdb: AsyncConnection) -> Row | None:
async def get(run_id: Identifier, expdb: AsyncConnection) -> UntypedRow | None:
"""Fetch the core run row from the `run` table.

Returns the row if found, or None if no run with `run_id` exists.
Expand Down Expand Up @@ -63,7 +64,7 @@ async def get_tags(run_id: int, expdb: AsyncConnection) -> list[str]:
return [row.tag for row in rows.all()]


async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]:
async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]:
"""Fetch the dataset(s) used as input for a run, with name and url.

Joins `input_data` with `dataset` to include the dataset name and ARFF URL.
Expand All @@ -79,10 +80,10 @@ async def get_input_data(run_id: int, expdb: AsyncConnection) -> list[Row]:
),
parameters={"run_id": run_id},
)
return cast("list[Row]", rows.all())
return cast("list[UntypedRow]", rows.all())


async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]:
async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[UntypedRow]:
"""Fetch output files attached to a run from the `runfile` table.

Typical entries include the description XML and predictions ARFF.
Expand All @@ -98,15 +99,15 @@ async def get_output_files(run_id: int, expdb: AsyncConnection) -> list[Row]:
),
parameters={"run_id": run_id},
)
return cast("list[Row]", rows.all())
return cast("list[UntypedRow]", rows.all())


async def get_evaluations(
run_id: int,
expdb: AsyncConnection,
*,
evaluation_engine_ids: list[int],
) -> list[Row]:
) -> list[UntypedRow]:
"""Fetch evaluation metric results for a run.

Joins `evaluation` with `math_function` to resolve the metric name
Expand Down Expand Up @@ -138,10 +139,10 @@ async def get_evaluations(
query,
parameters={"run_id": run_id, "engine_ids": evaluation_engine_ids},
)
return cast("list[Row]", rows.all())
return cast("list[UntypedRow]", rows.all())


async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[UntypedRow]:
"""Get trace rows for a run from the trace table."""
rows = await expdb.execute(
text(
Expand All @@ -153,7 +154,4 @@ async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]:
),
parameters={"run_id": run_id},
)
return cast(
"Sequence[Row]",
rows.all(),
)
return rows.all()
1 change: 1 addition & 0 deletions src/database/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Defines Object-Relational Mappings (ORM)."""
52 changes: 52 additions & 0 deletions src/database/schema/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Base classes for all ORM classes.

When defining a new ORM class, use both `Base` and one of the `DeferredReflection` subclasses to
make sure that the class is populated with attributes that may not be defined explicitly.
For example, when creating a new mapping for a table from the `openml_expdb` database, use:

class ClassName(ExpDBReflected, Base):
__tablename__ = "class_names"

# any columns you wanted mapped explicitly
...

"""

from typing import Any

from sqlalchemy import Row
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm import DeclarativeBase

from database.setup import expdb_database, user_database

UntypedRow = Row[Any]


class Base(DeclarativeBase):
"""Base class for all ORM classes."""


class ExpDBReflected(DeferredReflection):
"""Base class for ORM classes to map onto a table in the `openml_expdb` database."""

__abstract__ = True


class UserDBReflected(DeferredReflection):
"""Base class for ORM classes to map onto a table in the `openml` database."""

__abstract__ = True


async def reflect_db_schemas() -> None:
"""Populate defined ORM classes with attributes defined from columns in the database.

For example, the `dataset` class would automatically get a `collection_date` attribute,
even if it wasn't explicitly declared in the class definition,
because the `openml_expdb.dataset` table has a column `collection_date`.
"""
async with user_database().connect() as connection:
await connection.run_sync(UserDBReflected.prepare) # type: ignore[arg-type] # run_sync expects positional-only arg but `prepare` does not have it.
async with expdb_database().connect() as connection:
await connection.run_sync(ExpDBReflected.prepare) # type: ignore[arg-type] # as above.
30 changes: 30 additions & 0 deletions src/database/schema/tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""ORM classes for the *_tag tables (task_tag, ...)."""

from datetime import datetime

from sqlalchemy import FetchedValue
from sqlalchemy.orm import Mapped, mapped_column

from database.schema.base import Base, ExpDBReflected
from routers.types import Identifier, TagString


class Tag:
"""Base class for all of the *_tag tables."""

# The identifier of the entity that is tagged (e.g., dataset id, task id)
entity_id: Mapped[Identifier] = mapped_column("id", primary_key=True)
tag: Mapped[TagString] = mapped_column(primary_key=True)
uploader_id: Mapped[Identifier] = mapped_column("uploader")
creation_date: Mapped[datetime] = mapped_column("date", server_default=FetchedValue())


class TaskTag(ExpDBReflected, Tag, Base):
"""Tags belonging to a task."""

__tablename__ = "task_tag"

@property
def task_id(self) -> Identifier:
"""Identifier of the task which is tagged by this tag."""
return self.entity_id
Loading
Loading