diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py index 1cf80ac..1c118ae 100644 --- a/src/ldlite/__init__.py +++ b/src/ldlite/__init__.py @@ -337,7 +337,6 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915 records, desc="downloading", total=total_records, - leave=False, mininterval=5, disable=self._quiet, unit=table.split(".")[-1], @@ -354,7 +353,11 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915 newtables = self._database.expand_prefix(table, json_depth, keep_raw) if keep_raw: newtables = [table, *newtables] - indexable_attrs = [] + transform_elapsed = datetime.now(timezone.utc) - transform_started + + with tqdm(desc="indexing", disable=self._quiet) as progress: + index_started = datetime.now(timezone.utc) + self._database.index_prefix(table, progress) else: try: @@ -392,40 +395,46 @@ def query( # noqa: C901, PLR0912, PLR0913, PLR0915 finally: autocommit(self.db, self.dbtype, True) - transform_elapsed = datetime.now(timezone.utc) - transform_started - # Create indexes on id columns (for postgres) - index_started = datetime.now(timezone.utc) - if self.dbtype == DBType.POSTGRES: - - class PbarNoop: - def update(self, _: int) -> None: ... - def close(self) -> None: ... - - pbar: tqdm | PbarNoop = PbarNoop() # type:ignore[type-arg] - - index_total = len(indexable_attrs) - if not self._quiet: - pbar = tqdm( - desc="indexing", - total=index_total, - leave=False, - mininterval=3, - smoothing=0, - colour="#A9A9A9", - bar_format="{desc} {bar}{postfix}", - ) - for t, attr in indexable_attrs: - cur = self.db.cursor() - try: - cur.execute( - "CREATE INDEX ON " + sqlid(t) + " (" + sqlid(attr.name) + ")", + transform_elapsed = datetime.now(timezone.utc) - transform_started + + # Create indexes on id columns (for postgres) + index_started = datetime.now(timezone.utc) + if self.dbtype == DBType.POSTGRES: + + class PbarNoop: + def update(self, _: int) -> None: ... + def close(self) -> None: ... + + pbar: tqdm | PbarNoop = PbarNoop() # type:ignore[type-arg] + + index_total = len(indexable_attrs) + if not self._quiet: + pbar = tqdm( + desc="indexing", + total=index_total, + leave=False, + mininterval=3, + smoothing=0, + colour="#A9A9A9", + bar_format="{desc} {bar}{postfix}", ) - except (RuntimeError, psycopg.Error): - pass - finally: - cur.close() - pbar.update(1) - pbar.close() + for t, attr in indexable_attrs: + cur = self.db.cursor() + try: + cur.execute( + "CREATE INDEX ON " + + sqlid(t) + + " (" + + sqlid(attr.name) + + ")", + ) + except (RuntimeError, psycopg.Error): + pass + finally: + cur.close() + pbar.update(1) + pbar.close() + index_elapsed = datetime.now(timezone.utc) - index_started self._database.record_history( LoadHistory( diff --git a/src/ldlite/database/__init__.py b/src/ldlite/database/__init__.py index 719656c..9acd7b8 100644 --- a/src/ldlite/database/__init__.py +++ b/src/ldlite/database/__init__.py @@ -1,9 +1,16 @@ """A module for implementing ldlite database targets.""" -import datetime +from __future__ import annotations + from abc import ABC, abstractmethod -from collections.abc import Iterator from dataclasses import dataclass +from typing import TYPE_CHECKING, NoReturn + +if TYPE_CHECKING: + import datetime + from collections.abc import Iterator + + from tqdm import tqdm @dataclass(frozen=True) @@ -50,6 +57,10 @@ def ingest_records(self, prefix: str, records: Iterator[bytes]) -> int: def expand_prefix(self, prefix: str, json_depth: int, keep_raw: bool) -> list[str]: """Unnests and explodes the raw data at the given prefix.""" + @abstractmethod + def index_prefix(self, prefix: str, progress: tqdm[NoReturn] | None = None) -> None: + """Finds and indexes all tables at the given prefix.""" + @abstractmethod def record_history(self, history: LoadHistory) -> None: """Records the statistics and history of a single ldlite operation.""" diff --git a/src/ldlite/database/_typed_database.py b/src/ldlite/database/_typed_database.py index b67d367..eb4a1f3 100644 --- a/src/ldlite/database/_typed_database.py +++ b/src/ldlite/database/_typed_database.py @@ -1,9 +1,11 @@ # pyright: reportArgumentType=false +from __future__ import annotations + from abc import abstractmethod -from collections.abc import Callable, Sequence from contextlib import closing from datetime import timezone -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, NoReturn, TypeVar, cast +from uuid import uuid4 import psycopg from psycopg import sql @@ -13,7 +15,10 @@ from ._prefix import Prefix if TYPE_CHECKING: + from collections.abc import Callable, Sequence + import duckdb + from tqdm import tqdm DB = TypeVar("DB", bound="duckdb.DuckDBPyConnection | psycopg.Connection") @@ -247,6 +252,63 @@ def expand_prefix(self, prefix: str, json_depth: int, keep_raw: bool) -> list[st return created_tables + def index_prefix(self, prefix: str, progress: tqdm[NoReturn] | None = None) -> None: + pfx = Prefix(prefix) + with closing(self._conn_factory()) as conn: + with closing(conn.cursor()) as cur: + cur.execute( + """ +SELECT table_name FROM information_schema.tables +WHERE table_schema = $1 and table_name = $2;""", + ( + pfx.schema or self._default_schema, + pfx.catalog_table.name, + ), + ) + if len(cur.fetchall()) < 1: + return + + with closing(conn.cursor()) as cur: + cur.execute( + sql.SQL( + r""" +SELECT TABLE_NAME, COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS +WHERE + TABLE_SCHEMA = $1 AND + TABLE_NAME IN (SELECT TABLE_NAME FROM {catalog}) AND + ( + DATA_TYPE IN ('UUID', 'uuid') OR + COLUMN_NAME = 'id' OR + (COLUMN_NAME LIKE '%\_id' AND COLUMN_NAME <> '__id') + ); +""", + ) + .format(catalog=pfx.catalog_table.id) + .as_string(), + (pfx.schema or self._default_schema,), + ) + indexes = cur.fetchall() + + if progress is not None: + progress.total = len(indexes) + progress.refresh() + + for index in indexes: + with closing(conn.cursor()) as cur: + cur.execute( + sql.SQL("CREATE INDEX {name} ON {table} ({column});") + .format( + name=sql.Identifier(str(uuid4()).split("-")[0]), + table=sql.Identifier(*index[0].split(".")), + column=sql.Identifier(index[1]), + ) + .as_string(), + ) + if progress is not None: + progress.update(1) + + conn.commit() + def record_history(self, history: LoadHistory) -> None: with closing(self._conn_factory()) as conn, conn.cursor() as cur: cur.execute( diff --git a/tests/test_query.py b/tests/test_query.py index 96ce1e6..41ad44e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -51,8 +51,6 @@ def case_one_table(json_depth: int) -> QueryTC: "prefix__tcatalog": (["table_name"], [("prefix__t",)]), }, expected_indexes=[ - ("prefix", "__id"), - ("prefix__t", "__id"), ("prefix__t", "id"), ], ) @@ -115,12 +113,8 @@ def case_two_tables(json_depth: int) -> QueryTC: ), }, expected_indexes=[ - ("prefix", "__id"), - ("prefix__t", "__id"), ("prefix__t", "id"), - ("prefix__t__sub_objects", "__id"), ("prefix__t__sub_objects", "id"), - ("prefix__t__sub_objects", "sub_objects__o"), ("prefix__t__sub_objects", "sub_objects__id"), ], ) @@ -267,21 +261,11 @@ def case_three_tables(json_depth: int) -> QueryTC: ), }, expected_indexes=[ - ("prefix", "__id"), - ("prefix__t", "__id"), ("prefix__t", "id"), - ("prefix__t__sub_objects", "__id"), ("prefix__t__sub_objects", "id"), - ("prefix__t__sub_objects", "sub_objects__o"), ("prefix__t__sub_objects", "sub_objects__id"), - ("prefix__t__sub_objects__sub_sub_objects", "__id"), ("prefix__t__sub_objects__sub_sub_objects", "id"), - ("prefix__t__sub_objects__sub_sub_objects", "sub_objects__o"), ("prefix__t__sub_objects__sub_sub_objects", "sub_objects__id"), - ( - "prefix__t__sub_objects__sub_sub_objects", - "sub_objects__sub_sub_objects__o", - ), ( "prefix__t__sub_objects__sub_sub_objects", "sub_objects__sub_sub_objects__id", @@ -327,8 +311,6 @@ def case_nested_object() -> QueryTC: ), }, expected_indexes=[ - ("prefix", "__id"), - ("prefix__t", "__id"), ("prefix__t", "id"), ("prefix__t", "sub_object__id"), ], @@ -383,8 +365,6 @@ def case_doubly_nested_object() -> QueryTC: ), }, expected_indexes=[ - ("prefix", "__id"), - ("prefix__t", "__id"), ("prefix__t", "id"), ("prefix__t", "sub_object__id"), ("prefix__t", "sub_object__sub_sub_object__id"), @@ -632,8 +612,6 @@ def case_indexing_id_like() -> QueryTC: ], expected_values={}, expected_indexes=[ - ("prefix", "__id"), - ("prefix__t", "__id"), ("prefix__t", "id"), ("prefix__t", "other_id"), ("prefix__t", "an_id_but_with_a_different_ending"), @@ -666,7 +644,6 @@ def case_drop_raw(json_depth: int) -> QueryTC: "prefix__tcatalog": (["table_name"], [("prefix__t",)]), }, expected_indexes=[ - ("prefix__t", "__id"), ("prefix__t", "id"), ], ) @@ -692,8 +669,6 @@ def case_null_records() -> QueryTC: expected_tables=["prefix", "prefix__t", "prefix__tcatalog"], expected_values={}, expected_indexes=[ - ("prefix", "__id"), - ("prefix__t", "__id"), ("prefix__t", "id"), ], ) @@ -727,8 +702,6 @@ def case_erm_keys() -> QueryTC: "prefix__tcatalog": (["table_name"], [("prefix__t",)]), }, expected_indexes=[ - ("prefix", "__id"), - ("prefix__t", "__id"), ("prefix__t", "id"), ], ) @@ -790,7 +763,7 @@ def _assert( .as_string(), ) for v in values: - assert cur.fetchone() == v + assert cur.fetchone() == v, str(v) assert cur.fetchone() is None @@ -804,7 +777,32 @@ def _assert( """, exp, ) - assert cur.fetchone() == (0,) + assert cur.fetchone() == (0,), str(exp) + + if tc.expected_indexes is not None: + for exp in tc.expected_indexes: + cur.execute( + """ + SELECT * FROM pg_indexes + WHERE tablename = $1 and indexdef LIKE $2; + """, + (exp[0], "%" + exp[1] + "%"), + ) + assert cur.fetchone() is not None, str(exp) + + indexed_tables = {exp[0] for exp in tc.expected_indexes} + where = ",".join([f"${n}" for n in range(1, len(indexed_tables) + 1)]) + cur.execute( + f""" + SELECT tablename, indexdef FROM pg_indexes + WHERE tablename IN ({where}) + ORDER BY tablename; + """, + list(indexed_tables), + ) + actual_indexes = cur.fetchall() + expected_indexes = tc.expected_indexes + assert len(actual_indexes) == len(expected_indexes) @mock.patch("httpx_folio.auth.httpx.post")