diff --git a/cardinal_pythonlib/sqlalchemy/alembic_func.py b/cardinal_pythonlib/sqlalchemy/alembic_func.py index 1896ce6..7d7507c 100644 --- a/cardinal_pythonlib/sqlalchemy/alembic_func.py +++ b/cardinal_pythonlib/sqlalchemy/alembic_func.py @@ -76,13 +76,15 @@ def get_head_revision_from_alembic( ) -> str: """ Ask Alembic what its head revision is (i.e. where the Python code would - like the database to be at). + like the database to be at). This does not read the database. Arguments: - alembic_config_filename: config filename - alembic_base_dir: directory to start in, so relative paths in the - config file work. - version_table: table name for Alembic versions + alembic_config_filename: + config filename + alembic_base_dir: + directory to start in, so relative paths in the config file work. + version_table: + table name for Alembic versions """ if alembic_base_dir is None: alembic_base_dir = os.path.dirname(alembic_config_filename) @@ -148,6 +150,7 @@ def get_current_and_head_revision( @preserve_cwd def upgrade_database( alembic_config_filename: str, + db_url: str = None, alembic_base_dir: str = None, starting_revision: str = None, destination_revision: str = "head", @@ -164,6 +167,9 @@ def upgrade_database( alembic_config_filename: config filename + db_url: + Optional database URL to use, by way of override. 
+ alembic_base_dir: directory to start in, so relative paths in the config file work @@ -187,6 +193,8 @@ def upgrade_database( alembic_base_dir = os.path.dirname(alembic_config_filename) os.chdir(alembic_base_dir) # so the directory in the config file works config = Config(alembic_config_filename) + if db_url: + config.set_main_option("sqlalchemy.url", db_url) script = ScriptDirectory.from_config(config) # noinspection PyUnusedLocal,PyProtectedMember @@ -217,6 +225,7 @@ def upgrade(rev, context): def downgrade_database( alembic_config_filename: str, destination_revision: str, + db_url: str = None, alembic_base_dir: str = None, starting_revision: str = None, version_table: str = DEFAULT_ALEMBIC_VERSION_TABLE, @@ -233,6 +242,9 @@ def downgrade_database( alembic_config_filename: config filename + db_url: + Optional database URL to use, by way of override. + alembic_base_dir: directory to start in, so relative paths in the config file work @@ -255,6 +267,8 @@ def downgrade_database( alembic_base_dir = os.path.dirname(alembic_config_filename) os.chdir(alembic_base_dir) # so the directory in the config file works config = Config(alembic_config_filename) + if db_url: + config.set_main_option("sqlalchemy.url", db_url) script = ScriptDirectory.from_config(config) # noinspection PyUnusedLocal,PyProtectedMember @@ -403,6 +417,9 @@ def stamp_allowing_unusual_version_table( This function is a clone of ``alembic.command.stamp()``, but allowing ``version_table`` to change. See https://alembic.zzzcomputing.com/en/latest/api/commands.html#alembic.command.stamp + + Note that the Config object can include the database URL; use + ``config.set_main_option("sqlalchemy.url", db_url)``. 
""" script = ScriptDirectory.from_config(config) diff --git a/cardinal_pythonlib/sqlalchemy/core_query.py b/cardinal_pythonlib/sqlalchemy/core_query.py index 8b37dc8..f25710f 100644 --- a/cardinal_pythonlib/sqlalchemy/core_query.py +++ b/cardinal_pythonlib/sqlalchemy/core_query.py @@ -24,6 +24,31 @@ **Query helper functions using the SQLAlchemy Core.** +Example of result types in SQLAlchemy 1.4+ and higher: + +.. code-block:: python + + from typing import List + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.engine.result import MappingResult, Result + from sqlalchemy.engine.row import Row, RowMapping + + query = ( + select(text("*")) + .select_from(table(some_tablename)) + ) + + # As tuples: + result_1: CursorResult = session.execute(query) + # ... or, more generically, of type Result + like_unnamed_tuples: List[Row] = result_1.fetchall() + + # Or: + result_2: Result = session.execute(query) + mapping_result: Mapping_Result = result_2.mappings() + like_dicts: List[RowMapping] = list(mapping_result) # implicit fetchall() + # ... 
or could have done: like_dicts = result_2.mappings().fetchall() + """ from typing import Any, List, Optional, Tuple, Union diff --git a/cardinal_pythonlib/sqlalchemy/merge_db.py b/cardinal_pythonlib/sqlalchemy/merge_db.py index 4c5ef52..d576c5a 100644 --- a/cardinal_pythonlib/sqlalchemy/merge_db.py +++ b/cardinal_pythonlib/sqlalchemy/merge_db.py @@ -61,13 +61,12 @@ """ +from functools import total_ordering import logging -from typing import Any, Callable, Dict, List, Tuple, Type +from typing import Any, Callable, Dict, List, Set, Tuple, Type from sqlalchemy.engine.base import Engine from sqlalchemy.orm import lazyload, load_only - -# noinspection PyProtectedMember from sqlalchemy.orm.session import make_transient, Session, sessionmaker from sqlalchemy.schema import sort_tables from sqlalchemy.sql.schema import MetaData, Table @@ -100,6 +99,7 @@ # ============================================================================= +@total_ordering class TableDependency(object): """ Stores a table dependency for use in functions such as @@ -153,6 +153,21 @@ def __repr__(self) -> str: f"depends on {self.parent_tablename!r})" ) + def __lt__(self, other: "TableDependency") -> bool: + """ + Define a sort order. + """ + return (self.child_tablename, self.parent_tablename) < ( + other.child_tablename, + other.parent_tablename, + ) + + def __eq__(self, other: "TableDependency") -> bool: + return ( + self.child_tablename == other.child_tablename + and self.parent_tablename == other.parent_tablename + ) + def set_metadata(self, metadata: MetaData) -> None: """ Sets the metadata for the parent and child tables. @@ -204,10 +219,55 @@ def sqla_tuple(self) -> Tuple[Table, Table]: return self.parent_table, self.child_table +def _get_dependencies_for_table( + table: Table, even_use_alter: bool = False +) -> Set[Tuple[Table, Table]]: + """ + Returns dependencies for a single table. + + Args: + table: + A SQLAlchemy Table object. 
+ even_use_alter: + Even include relationships with ``use_alter`` set. See + https://docs.sqlalchemy.org/en/latest/core/constraints.html#sqlalchemy.schema.ForeignKeyConstraint.params.use_alter + + Returns: + A set of tuples of Tables: (parent_that_this_table_dependent_on, + this_table_child). + + See :func:`sqlalchemy.sql.ddl.sort_tables_and_constraints` for method. + """ + dependencies: Set[Tuple[Table, Table]] = set() + # Add via (a) foreign_key_constraints, and (b) _extra_dependencies. This is + # an SQLAlchemy internal; see its sort_tables_and_constraints function as + # above. + # log.debug( + # f"_get_dependencies_for_table: {table.name=}; " + # f"{len(table.foreign_key_constraints)=}" + # ) + for fkc in table.foreign_key_constraints: + # log.debug(f"- {fkc.use_alter=}; {fkc.referred_table.name=}") + if fkc.use_alter is True and not even_use_alter: + continue + dependent_on = fkc.referred_table + if dependent_on is not table: + dependencies.add((dependent_on, table)) + if hasattr(table, "_extra_dependencies"): + # noinspection PyProtectedMember + dependencies.update( + (parent, table) for parent in table._extra_dependencies + ) + return dependencies + + def get_all_dependencies( metadata: MetaData, extra_dependencies: List[TableDependency] = None, - sort: bool = True, + skip_dependencies: List[TableDependency] = None, + sort: bool = False, + even_use_alter: bool = False, + debug: bool = False, ) -> List[TableDependency]: """ Describes how the tables found in the metadata depend on each other. @@ -215,41 +275,55 @@ def get_all_dependencies( on A.) Args: - metadata: the metadata to inspect - extra_dependencies: additional table dependencies to specify manually - sort: sort into alphabetical order of (parent, child) table names? + metadata: + The metadata to inspect. + extra_dependencies: + Additional table dependencies to specify manually. + skip_dependencies: + Additional table dependencies to IGNORE. 
+ sort: + Sort into alphabetical order of (parent, child) table names? + even_use_alter: + Even include relationships with ``use_alter`` set. See SQLAlchemy + documentation. + debug: + Show debugging information. Returns: a list of :class:`TableDependency` objects - - See :func:`sort_tables_and_constraints` for method. """ - extra_dependencies = ( - extra_dependencies or [] - ) # type: List[TableDependency] + # First deal with user-specified dependencies. + extra_dependencies: List[TableDependency] = extra_dependencies or [] for td in extra_dependencies: td.set_metadata_if_none(metadata) - dependencies = set([td.sqla_tuple() for td in extra_dependencies]) - - tables = list(metadata.tables.values()) # type: List[Table] + dependencies: Set[Tuple[Table, Table]] = set( + [td.sqla_tuple() for td in extra_dependencies] + ) + if debug: + readable = [str(td) for td in extra_dependencies] + log.debug(f"get_all_dependencies: user specified: {readable!r}") + # Add dependencies from tables. + tables: List[Table] = list(metadata.tables.values()) for table in tables: - for fkc in table.foreign_key_constraints: - if fkc.use_alter is True: - # http://docs.sqlalchemy.org/en/latest/core/constraints.html#sqlalchemy.schema.ForeignKeyConstraint.params.use_alter # noqa: E501 - continue - - dependent_on = fkc.referred_table - if dependent_on is not table: - dependencies.add((dependent_on, table)) - - if hasattr(table, "_extra_dependencies"): - # noinspection PyProtectedMember - dependencies.update( - (parent, table) for parent in table._extra_dependencies + tdep = _get_dependencies_for_table( + table, even_use_alter=even_use_alter + ) + if debug: + parents = [tt[0].name for tt in tdep] + log.debug( + f"get_all_dependencies: for table {table.name!r}, " + f"adding dependencies: {parents}" ) + dependencies.update(tdep) - dependencies = [ + # Remove explicitly specified dependencies to skip. 
+ skip_dependencies: List[TableDependency] = skip_dependencies or [] + for sd in skip_dependencies: + dependencies.remove(sd.sqla_tuple()) + + # Convert from set to list + dependencies: List[TableDependency] = [ TableDependency(parent_table=parent, child_table=child) for parent, child in dependencies ] @@ -282,11 +356,11 @@ def __init__( children: its children (things that depend on it) parents: its parents (things that it depends on) """ - self.table = table - self.children = children or [] # type: List[Table] - self.parents = parents or [] # type: List[Table] - self.circular = False - self.circular_chain = [] # type: List[Table] + self.table: Table = table + self.children: List[Table] = children or [] + self.parents: List[Table] = parents or [] + self.circular: bool = False + self.circular_chain: List[Table] = [] @property def is_child(self) -> bool: @@ -340,7 +414,7 @@ def set_circular(self, circular: bool, chain: List[Table] = None) -> None: participating in the circular chain """ self.circular = circular - self.circular_chain = chain or [] # type: List[Table] + self.circular_chain = chain or [] @property def circular_description(self) -> str: @@ -367,7 +441,11 @@ def description(self) -> str: return desc def __str__(self) -> str: - return f"{self.tablename}:{self.description}" + ptxt = ", ".join(sorted(p.name for p in self.parents)) + circ = ( + f"; CIRCULAR({self.circular_description})" if self.circular else "" + ) + return f"{self.tablename}(depends on [{ptxt}]{circ})" def __repr__(self) -> str: return ( @@ -379,6 +457,9 @@ def __repr__(self) -> str: def classify_tables_by_dependency_type( metadata: MetaData, extra_dependencies: List[TableDependency] = None, + skip_dependencies: List[TableDependency] = None, + all_dependencies: List[TableDependency] = None, + even_use_alter: bool = False, sort: bool = True, ) -> List[TableDependencyClassification]: """ @@ -386,18 +467,36 @@ def classify_tables_by_dependency_type( and returns a list of objects describing 
their dependencies. Args: - metadata: the :class:`MetaData` to inspect - extra_dependencies: additional dependencies - sort: sort the results by table name? + metadata: + the :class:`MetaData` to inspect + extra_dependencies: + Additional dependencies. (Not used if you specify + all_dependencies.) + skip_dependencies: + Additional table dependencies to IGNORE. (Not used if you specify + all_dependencies.) + all_dependencies: + If you have precalculated all dependencies, you can pass that in + here, to save redoing the work. + even_use_alter: + Even include relationships with ``use_alter`` set. See SQLAlchemy + documentation. (Not used if you specify all_dependencies.) + sort: + sort the results by table name? Returns: list of :class:`TableDependencyClassification` objects, one for each table """ - tables = list(metadata.tables.values()) # type: List[Table] - all_deps = get_all_dependencies(metadata, extra_dependencies) - tdcmap = {} # type: Dict[Table, TableDependencyClassification] + tables: List[Table] = list(metadata.tables.values()) + all_deps = all_dependencies or get_all_dependencies( + metadata=metadata, + extra_dependencies=extra_dependencies, + skip_dependencies=skip_dependencies, + even_use_alter=even_use_alter, + ) + tdcmap: Dict[Table, TableDependencyClassification] = {} for table in tables: parents = [ td.parent_table for td in all_deps if td.child_table == table @@ -411,25 +510,37 @@ def classify_tables_by_dependency_type( # Check for circularity def parents_contain( - start: Table, probe: Table + start: Table, probe: Table, seen: Set[Table] = None ) -> Tuple[bool, List[Table]]: + seen = seen or set() tdc_ = tdcmap[start] if probe in tdc_.parents: return True, [start, probe] for parent in tdc_.parents: - contains_, chain_ = parents_contain(start=parent, probe=probe) + if parent in seen: + continue # avoid infinite recursion + seen.add(parent) + contains_, chain_ = parents_contain( + start=parent, probe=probe, seen=seen + ) if contains_: return True, 
[start] + chain_ return False, [] def children_contain( - start: Table, probe: Table + start: Table, probe: Table, seen: Set[Table] = None ) -> Tuple[bool, List[Table]]: + seen = seen or set() tdc_ = tdcmap[start] if probe in tdc_.children: return True, [start, probe] for child in tdc_.children: - contains_, chain_ = children_contain(start=child, probe=probe) + if child in seen: + continue # avoid infinite recursion + seen.add(child) + contains_, chain_ = children_contain( + start=child, probe=probe, seen=seen + ) if contains_: return True, [start] + chain_ return False, [] @@ -548,8 +659,51 @@ def __init__( self.src_engine = src_engine self.dst_engine = dst_engine self.src_table_names = src_table_names - self.missing_src_columns = missing_src_columns or [] # type: List[str] - self.info = info or {} # type: Dict[str, Any] + self.missing_src_columns: List[str] = missing_src_columns or [] + self.info: Dict[str, Any] = info or {} + + +# ============================================================================= +# suggest_table_order (for merge_db) +# ============================================================================= + + +def suggest_table_order( + classified_tables: List[TableDependencyClassification], +) -> List[Table]: + """ + Suggest an order to process tables in, according to precalculated + dependencies. + + Args: + classified_tables: + The tables, with dependency information. + + Returns: + A list of the tables, sorted into a sensible order. + """ + # We can't handle a circular situation: + assert not any( + tdc.circular for tdc in classified_tables + ), "Can't handle circular references between tables" + # Keeping track. 
With a quasi-arbitrary starting order: + to_do: Set[TableDependencyClassification] = set(classified_tables) + tables_done: Set[Table] = set() + final_order: List[TableDependencyClassification] = [] + + # Now, iteratively: + while to_do: + suitable = [ + tdc for tdc in to_do if all(p in tables_done for p in tdc.parents) + ] + if not suitable: + raise ValueError("suggest_table_order: Unable to solve") + suitable.sort(key=lambda ct: ct.table.name) + final_order.extend(suitable) + to_do -= set(suitable) + tables_done.update(tdc.table for tdc in suitable) + + return [tdc.table for tdc in final_order] # ============================================================================= @@ -568,6 +722,7 @@ def merge_db( only_tables: List[TableIdentity] = None, tables_to_keep_pks_for: List[TableIdentity] = None, extra_table_dependencies: List[TableDependency] = None, + skip_table_dependencies: List[TableDependency] = None, dummy_run: bool = False, info_only: bool = False, report_every: int = 1000, @@ -577,6 +732,12 @@ def merge_db( commit_at_end: bool = True, prevent_eager_load: bool = True, trcon_info: Dict[str, Any] = None, + even_use_alter_relationships: bool = False, + debug_table_structure: bool = False, + debug_table_dependencies: bool = False, + debug_copy_sqla_object: bool = False, + debug_rewrite_relationships: bool = False, + use_sqlalchemy_order: bool = True, ) -> None: """ Copies an entire database as far as it is described by ``metadata`` and @@ -652,7 +813,11 @@ def my_translate_fn(trcon: TranslationContext) -> None: :class:`TableIdentity`) extra_table_dependencies: - optional list of :class:`TableDependency` objects (q.v.) + optional list of :class:`TableDependency` objects (q.v.) to include + + skip_table_dependencies: + optional list of :class:`TableDependency` objects (q.v.) 
to IGNORE; + unusual dummy_run: don't alter the destination database @@ -682,6 +847,26 @@ def my_translate_fn(trcon: TranslationContext) -> None: trcon_info: additional dictionary passed to ``TranslationContext.info`` (see :class:`.TranslationContext`) + + even_use_alter_relationships: + Even include relationships with ``use_alter`` set. See SQLAlchemy + documentation. + + debug_table_structure: + Debug table structure? Can be long-winded. + + debug_table_dependencies: + Debug calculating table dependencies? + + debug_copy_sqla_object: + Debug copying objects? + + debug_rewrite_relationships: + Debug rewriting ORM relationships? + + use_sqlalchemy_order: + If true, use the table order suggested by SQLAlchemy. If false, + calculate our own. """ log.info("merge_db(): starting") @@ -694,20 +879,21 @@ def my_translate_fn(trcon: TranslationContext) -> None: return # Finalize parameters - skip_tables = skip_tables or [] # type: List[TableIdentity] - only_tables = only_tables or [] # type: List[TableIdentity] - tables_to_keep_pks_for = ( - tables_to_keep_pks_for or [] - ) # type: List[TableIdentity] - extra_table_dependencies = ( + skip_tables: List[TableIdentity] = skip_tables or [] + only_tables: List[TableIdentity] = only_tables or [] + tables_to_keep_pks_for: List[TableIdentity] = tables_to_keep_pks_for or [] + extra_table_dependencies: List[TableDependency] = ( extra_table_dependencies or [] - ) # type: List[TableDependency] - trcon_info = trcon_info or {} # type: Dict[str, Any] + ) + skip_table_dependencies: List[TableDependency] = ( + skip_table_dependencies or [] + ) + trcon_info: Dict[str, Any] = trcon_info or {} # We need both Core and ORM for the source. 
# noinspection PyUnresolvedReferences - metadata = base_class.metadata # type: MetaData - src_session = sessionmaker(bind=src_engine, future=True)() # type: Session + metadata: MetaData = base_class.metadata + src_session: Session = sessionmaker(bind=src_engine, future=True)() dst_engine = get_engine_from_session(dst_session) tablename_to_ormclass = get_orm_classes_by_table_name_from_base(base_class) @@ -717,13 +903,15 @@ def my_translate_fn(trcon: TranslationContext) -> None: ti.set_metadata_if_none(metadata) for td in extra_table_dependencies: td.set_metadata_if_none(metadata) + for td in skip_table_dependencies: + td.set_metadata_if_none(metadata) # Get all lists of tables as their names skip_table_names = [ti.tablename for ti in skip_tables] only_table_names = [ti.tablename for ti in only_tables] - tables_to_keep_pks_for = [ + tables_to_keep_pks_for: List[str] = [ ti.tablename for ti in tables_to_keep_pks_for - ] # type: List[str] + ] # ... now all are of type List[str] # Safety check: this is an imperfect check for source == destination, but @@ -754,37 +942,47 @@ def my_translate_fn(trcon: TranslationContext) -> None: table_num = 0 overall_record_num = 0 - tables = list(metadata.tables.values()) # type: List[Table] - # Very helpfully, MetaData.sorted_tables produces tables in order of - # relationship dependency ("each table is preceded by all tables which - # it references"); - # http://docs.sqlalchemy.org/en/latest/core/metadata.html - # HOWEVER, it only works if you specify ForeignKey relationships - # explicitly. 
- # We can also add in user-specified dependencies, and therefore can do the - # sorting in one step with sqlalchemy.schema.sort_tables: - ordered_tables = sort_tables( - tables, - extra_dependencies=[ - td.sqla_tuple() for td in extra_table_dependencies - ], + all_dependencies = get_all_dependencies( + metadata=metadata, + extra_dependencies=extra_table_dependencies, + skip_dependencies=skip_table_dependencies, + debug=debug_table_dependencies, + even_use_alter=even_use_alter_relationships, ) - # Note that the ordering is NOT NECESSARILY CONSISTENT, though (in that - # the order of stuff it doesn't care about varies across runs). - all_dependencies = get_all_dependencies(metadata, extra_table_dependencies) dep_classifications = classify_tables_by_dependency_type( - metadata, extra_table_dependencies + metadata, + all_dependencies=all_dependencies, + even_use_alter=even_use_alter_relationships, ) circular = [tdc for tdc in dep_classifications if tdc.circular] assert not circular, f"Circular dependencies! {circular!r}" + all_dependencies.sort() # cosmetic log.debug( "All table dependencies: " - + "; ".join(str(td) for td in all_dependencies) - ) - log.debug( - "Table dependency classifications: " + "; ".join(str(c) for c in dep_classifications) ) + tables: List[Table] = list(metadata.tables.values()) + if use_sqlalchemy_order: + # Very helpfully, MetaData.sorted_tables produces tables in order of + # relationship dependency ("each table is preceded by all tables which + # it references"); + # http://docs.sqlalchemy.org/en/latest/core/metadata.html + # HOWEVER, it only works if you specify ForeignKey relationships + # explicitly. 
+ # We can also add in user-specified dependencies, and therefore can do + # the sorting in one step with sqlalchemy.schema.sort_tables: + log.debug("Using SQLAlchemy's suggested table order") + ordered_tables = sort_tables( + tables, + extra_dependencies=[ + td.sqla_tuple() for td in extra_table_dependencies + ], + ) + # Note that the ordering is NOT NECESSARILY CONSISTENT, though (in that + # the order of stuff it doesn't care about varies across runs). + else: + log.debug("Calculating table order without SQLAlchemy") + ordered_tables = suggest_table_order(dep_classifications) log.info( "Processing tables in the order: " + repr([table.name for table in ordered_tables]) @@ -829,15 +1027,13 @@ def translate(oldobj_: object, newobj_: object) -> object: tablename = table.name if tablename in skip_table_names: - log.info(f"... skipping table {tablename!r} (as per skip_tables)") + log.info(f"Skipping table {tablename!r} (as per skip_tables)") continue if only_table_names and tablename not in only_table_names: - log.info(f"... ignoring table {tablename!r} (as per only_tables)") + log.info(f"Ignoring table {tablename!r} (as per only_tables)") continue if allow_missing_src_tables and tablename not in src_tables: - log.info( - f"... 
ignoring table {tablename!r} (not in source database)" - ) + log.info(f"Ignoring table {tablename!r} (not in source database)") continue table_num += 1 table_record_num = 0 @@ -859,18 +1055,21 @@ def translate(oldobj_: object, newobj_: object) -> object: tdc = [tdc for tdc in dep_classifications if tdc.table == table][0] log.info(f"Processing table {tablename!r} via ORM class {orm_class!r}") - log.debug(f"PK attributes: {pk_attrs!r}") - log.debug(f"Table: {table!r}") - log.debug( - f"Dependencies: parents = {tdc.parent_names!r}; " - f"children = {tdc.child_names!r}" - ) + if debug_table_structure: + log.debug(f"PK attributes: {pk_attrs!r}") + log.debug(f"Table: {table!r}") + if debug_table_dependencies: + log.debug( + f"Dependencies: parents = {tdc.parent_names!r}; " + f"children = {tdc.child_names!r}" + ) if info_only: log.debug("info_only; skipping table contents") continue def wipe_primary_key(inst: object) -> None: + # Defined here because it uses pk_attrs for attrname in pk_attrs: setattr(inst, attrname, None) @@ -920,7 +1119,9 @@ def wipe_primary_key(inst: object) -> None: # maintain a copy of the old object, make a copy using # copy_sqla_object, and re-assign relationships accordingly. - for instance in query.all(): + instances = list(query.all()) + log.info(f"... 
processing {len(instances)} records") + for instance in instances: # log.debug(f"Source instance: {instance!r}") table_record_num += 1 overall_record_num += 1 @@ -966,14 +1167,14 @@ def wipe_primary_key(inst: object) -> None: omit_pk=wipe_pk, omit_fk=True, omit_attrs=missing_attrs, - debug=False, + debug=debug_copy_sqla_object, ) rewrite_relationships( oldobj, newobj, objmap, - debug=False, + debug=debug_rewrite_relationships, skip_table_names=skip_table_names, ) @@ -986,7 +1187,12 @@ def wipe_primary_key(inst: object) -> None: # new PK will be created when session is flushed if tdc.is_parent: - objmap[oldobj] = newobj # for its children's benefit + try: + objmap[oldobj] = newobj # for its children's benefit + except KeyError: + raise KeyError( + f"Missing attribute {oldobj=} in {objmap=}" + ) if flush_per_record: flush() diff --git a/cardinal_pythonlib/sqlalchemy/orm_inspect.py b/cardinal_pythonlib/sqlalchemy/orm_inspect.py index 5fff4dc..38980c7 100644 --- a/cardinal_pythonlib/sqlalchemy/orm_inspect.py +++ b/cardinal_pythonlib/sqlalchemy/orm_inspect.py @@ -26,6 +26,7 @@ """ +import logging from typing import ( Dict, Generator, @@ -37,7 +38,6 @@ Union, ) -# noinspection PyProtectedMember from sqlalchemy import inspect from sqlalchemy.orm.base import class_mapper from sqlalchemy.orm.mapper import Mapper @@ -51,13 +51,12 @@ from cardinal_pythonlib.classes import gen_all_subclasses from cardinal_pythonlib.enumlike import OrderedNamespace from cardinal_pythonlib.dicts import reversedict -from cardinal_pythonlib.logs import get_brace_style_log_with_null_handler if TYPE_CHECKING: from sqlalchemy.orm.state import InstanceState from sqlalchemy.sql.schema import Table -log = get_brace_style_log_with_null_handler(__name__) +log = logging.getLogger(__name__) # ============================================================================= @@ -253,7 +252,7 @@ def walk_orm_tree( continue seen.add(obj) if debug: - log.debug("walk: yielding {!r}", obj) + log.debug(f"walk: yielding 
{obj!r}") yield obj insp = inspect(obj) # type: InstanceState for ( @@ -272,10 +271,10 @@ def walk_orm_tree( continue # Process relationship if debug: - log.debug("walk: following relationship {}", relationship) + log.debug(f"walk: following relationship {relationship}") related = getattr(obj, attrname) if debug and related: - log.debug("walk: queueing {!r}", related) + log.debug(f"walk: queueing {related!r}") if relationship.uselist: stack.extend(related) elif related is not None: @@ -331,7 +330,7 @@ def copy_sqla_object( prohibited |= fk_keys prohibited |= set(omit_attrs) if debug: - log.debug("copy_sqla_object: skipping: {}", prohibited) + log.debug(f"copy_sqla_object: skipping: {prohibited}") for k in [ p.key for p in mapper.iterate_properties if p.key not in prohibited ]: @@ -339,12 +338,12 @@ def copy_sqla_object( value = getattr(obj, k) if debug: log.debug( - "copy_sqla_object: processing attribute {} = {}", k, value + f"copy_sqla_object: processing attribute {k} = {value}" ) setattr(newobj, k, value) except AttributeError: if debug: - log.debug("copy_sqla_object: failed attribute {}", k) + log.debug(f"copy_sqla_object: failed attribute {k}") pass return newobj @@ -412,24 +411,29 @@ def rewrite_relationships( # insp.mapper.relationships is of type # sqlalchemy.utils._collections.ImmutableProperties, which is basically # a sort of AttrDict. 
- for ( - attrname_rel - ) in ( + attrname_rel_list = list( insp.mapper.relationships.items() - ): # type: Tuple[str, RelationshipProperty] + ) # type: List[Tuple[str, RelationshipProperty]] + if debug: + log.debug( + f"rewrite_relationships: relationships are {attrname_rel_list}" + ) + for attrname_rel in attrname_rel_list: attrname = attrname_rel[0] rel_prop = attrname_rel[1] if rel_prop.viewonly: if debug: - log.debug("Skipping viewonly relationship") + log.debug( + "rewrite_relationships: Skipping viewonly relationship" + ) continue # don't attempt to write viewonly relationships related_class = rel_prop.mapper.class_ related_table_name = related_class.__tablename__ # type: str if related_table_name in skip_table_names: if debug: log.debug( - "Skipping relationship for related table {!r}", - related_table_name, + f"rewrite_relationships: Skipping relationship for " + f"related table {related_table_name!r}" ) continue # The relationship is an abstract object (so getting the @@ -439,17 +443,30 @@ def rewrite_relationships( # rel_key = rel.key # type: str # ... but also available from the mapper as attrname, above related_old = getattr(oldobj, attrname) - if rel_prop.uselist: - related_new = [objmap[r] for r in related_old] - elif related_old is not None: - related_new = objmap[related_old] - else: - related_new = None + try: + if rel_prop.uselist: + related_new = [objmap[r] for r in related_old] + elif related_old is not None: + related_new = objmap[related_old] + else: + related_new = None + except KeyError as e: + # Often long messages; caps makes it slightly easier to read. + log.critical( + f"rewrite_relationships: WHILE PROCESSING {oldobj = !r}, " + f"THE ATTRIBUTE ACCESSED AS oldobj.{attrname}, " + f"NAMELY {related_old = !r}, " + f"IS MISSING AS A KEY FROM {objmap = !r}. " + f"ERROR WAS: KeyError: {e}. 
" + f"POSSIBLE REASON: If you called via merge_db(), perhaps this " + f"is not properly handled in the 'translate_fn' function; or " + f"perhaps the tables are being dealt with in the wrong order." + ) + raise if debug: log.debug( - "rewrite_relationships: relationship {} -> {}", - attrname, - related_new, + f"rewrite_relationships: relationship " + f"{attrname} -> {related_new}" ) setattr(newobj, attrname, related_new) @@ -485,14 +502,21 @@ def deepcopy_sqla_objects( ``args``/``kwargs``, since we are copying a tree of arbitrary objects.) Args: - startobjs: SQLAlchemy ORM objects to copy - session: destination SQLAlchemy :class:`Session` into which to insert - the copies - flush: flush the session when we've finished? - debug: be verbose? - debug_walk: be extra verbose when walking the ORM tree? - debug_rewrite_rel: be extra verbose when rewriting relationships? - objmap: starting object map from source-session to destination-session + startobjs: + SQLAlchemy ORM objects to copy + session: + destination SQLAlchemy :class:`Session` into which to insert the + copies + flush: + flush the session when we've finished? + debug: + be verbose? + debug_walk: + be extra verbose when walking the ORM tree? + debug_rewrite_rel: + be extra verbose when rewriting relationships? + objmap: + starting object map from source-session to destination-session objects (see :func:`rewrite_relationships` for more detail); usually ``None`` to begin with. 
""" @@ -508,7 +532,7 @@ def deepcopy_sqla_objects( for startobj in startobjs: for oldobj in walk_orm_tree(startobj, seen=seen, debug=debug_walk): if debug: - log.debug("deepcopy_sqla_objects: copying {}", oldobj) + log.debug(f"deepcopy_sqla_objects: copying {oldobj}") newobj = copy_sqla_object(oldobj, omit_pk=True, omit_fk=True) # Don't insert the new object into the session here; it may trigger # an autoflush as the relationships are queried, and the new @@ -525,7 +549,7 @@ def deepcopy_sqla_objects( log.debug("deepcopy_sqla_objects: pass 2: set relationships") for oldobj, newobj in objmap.items(): if debug: - log.debug("deepcopy_sqla_objects: newobj: {}", newobj) + log.debug(f"deepcopy_sqla_objects: newobj: {newobj}") rewrite_relationships(oldobj, newobj, objmap, debug=debug_rewrite_rel) # Now we can do session insert. diff --git a/cardinal_pythonlib/sqlalchemy/schema.py b/cardinal_pythonlib/sqlalchemy/schema.py index b243386..f584119 100644 --- a/cardinal_pythonlib/sqlalchemy/schema.py +++ b/cardinal_pythonlib/sqlalchemy/schema.py @@ -61,20 +61,24 @@ Index, Table, ) -from sqlalchemy.sql import sqltypes, text from sqlalchemy.sql.ddl import DDLElement +from sqlalchemy.sql.expression import text from sqlalchemy.sql.sqltypes import ( BigInteger, Boolean, Date, DateTime, - Double, + Enum, Float, Integer, + LargeBinary, Numeric, SmallInteger, + String, Text, TypeEngine, + Unicode, + UnicodeText, ) from sqlalchemy.sql.visitors import Visitable @@ -90,11 +94,29 @@ log = get_brace_style_log_with_null_handler(__name__) +try: + from sqlalchemy.sql.sqltypes import Double +except ImportError: + # This code present to allow testing with older SQLAlchemy 1.4. 
+ log.warning( + "Can't import sqlalchemy.sql.sqltypes.Double " + "(are you using SQLAlchemy prior to 2.0?)" + ) + Double = None + # ============================================================================= # Constants # ============================================================================= +# To avoid importing _Binary directly: +if len(LargeBinary.__bases__) != 1: + raise NotImplementedError( + "Unexpectedly, SQLAlchemy's LargeBinary class has more than one base " + "class" + ) +BinaryBaseClass = LargeBinary.__bases__[0] + VisitableType = Type[Visitable] # for SQLAlchemy 2.0 MIN_TEXT_LENGTH_FOR_FREETEXT_INDEX = 1000 @@ -110,7 +132,7 @@ "BOOLEAN": Boolean, "DATE": Date, "TIMESTAMP_NTZ": DateTime, - "DOUBLE": Double, + "DOUBLE": Double if Double is not None else Float, "FLOAT": Float, "INT": Integer, "DECIMAL": Numeric, @@ -1100,29 +1122,29 @@ def convert_sqla_type_for_dialect( # ------------------------------------------------------------------------- # Text # ------------------------------------------------------------------------- - if isinstance(coltype, sqltypes.Enum): - return sqltypes.String(length=coltype.length) - if isinstance(coltype, sqltypes.UnicodeText): + if isinstance(coltype, Enum): + return String(length=coltype.length) + if isinstance(coltype, UnicodeText): # Unbounded Unicode text. # Includes derived classes such as mssql.base.NTEXT. - return sqltypes.UnicodeText() - if isinstance(coltype, sqltypes.Text): + return UnicodeText() + if isinstance(coltype, Text): # Unbounded text, more generally. (UnicodeText inherits from Text.) # Includes sqltypes.TEXT. - return sqltypes.Text() + return Text() # Everything inheriting from String has a length property, but can be None. # There are types that can be unlimited in SQL Server, e.g. VARCHAR(MAX) # and NVARCHAR(MAX), that MySQL needs a length for. (Failure to convert # gives e.g.: 'NVARCHAR requires a length on dialect mysql'.) 
- if isinstance(coltype, sqltypes.Unicode): + if isinstance(coltype, Unicode): # Includes NVARCHAR(MAX) in SQL -> NVARCHAR() in SQLAlchemy. if (coltype.length is None and to_mysql) or expand_for_scrubbing: - return sqltypes.UnicodeText() + return UnicodeText() # The most general case; will pick up any other string types. - if isinstance(coltype, sqltypes.String): + if isinstance(coltype, String): # Includes VARCHAR(MAX) in SQL -> VARCHAR() in SQLAlchemy if (coltype.length is None and to_mysql) or expand_for_scrubbing: - return sqltypes.Text() + return Text() if strip_collation: return remove_collation(coltype) return coltype @@ -1168,10 +1190,9 @@ def is_sqlatype_binary(coltype: Union[TypeEngine, VisitableType]) -> bool: Is the SQLAlchemy column type a binary type? """ # Several binary types inherit internally from _Binary, making that the - # easiest to check. + # easiest to check. We obtain BinaryBaseClass (= _Binary) as above. coltype = coltype_as_typeengine(coltype) - # noinspection PyProtectedMember - return isinstance(coltype, sqltypes._Binary) + return isinstance(coltype, BinaryBaseClass) def is_sqlatype_date(coltype: Union[TypeEngine, VisitableType]) -> bool: @@ -1179,9 +1200,7 @@ def is_sqlatype_date(coltype: Union[TypeEngine, VisitableType]) -> bool: Is the SQLAlchemy column type a date type? """ coltype = coltype_as_typeengine(coltype) - return isinstance(coltype, sqltypes.DateTime) or isinstance( - coltype, sqltypes.Date - ) + return isinstance(coltype, DateTime) or isinstance(coltype, Date) def is_sqlatype_integer(coltype: Union[TypeEngine, VisitableType]) -> bool: @@ -1189,7 +1208,7 @@ def is_sqlatype_integer(coltype: Union[TypeEngine, VisitableType]) -> bool: Is the SQLAlchemy column type an integer type? 
""" coltype = coltype_as_typeengine(coltype) - return isinstance(coltype, sqltypes.Integer) + return isinstance(coltype, Integer) def is_sqlatype_numeric(coltype: Union[TypeEngine, VisitableType]) -> bool: @@ -1200,7 +1219,7 @@ def is_sqlatype_numeric(coltype: Union[TypeEngine, VisitableType]) -> bool: Note that integers don't count as Numeric! """ coltype = coltype_as_typeengine(coltype) - return isinstance(coltype, sqltypes.Numeric) # includes Float, Decimal + return isinstance(coltype, Numeric) # includes Float, Decimal def is_sqlatype_string(coltype: Union[TypeEngine, VisitableType]) -> bool: @@ -1208,7 +1227,7 @@ def is_sqlatype_string(coltype: Union[TypeEngine, VisitableType]) -> bool: Is the SQLAlchemy column type a string type? """ coltype = coltype_as_typeengine(coltype) - return isinstance(coltype, sqltypes.String) + return isinstance(coltype, String) def is_sqlatype_text_of_length_at_least( @@ -1220,7 +1239,7 @@ def is_sqlatype_text_of_length_at_least( length? """ coltype = coltype_as_typeengine(coltype) - if not isinstance(coltype, sqltypes.String): + if not isinstance(coltype, String): return False # not a string/text type at all if coltype.length is None: return True # string of unlimited length @@ -1260,9 +1279,9 @@ def does_sqlatype_require_index_len( https://dev.mysql.com/doc/refman/5.7/en/create-index.html.) """ coltype = coltype_as_typeengine(coltype) - if isinstance(coltype, sqltypes.Text): + if isinstance(coltype, Text): return True - if isinstance(coltype, sqltypes.LargeBinary): + if isinstance(coltype, LargeBinary): return True return False diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 029d7f0..95b4621 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -864,3 +864,10 @@ Quick links: ``pdftotext`` was unavailable. Also remove antique ``pyth`` support. And shift from unmaintained ``pdfminer`` to maintained ``pdfminer.six``. Also removed unused code around importing ``docx`` and ``docx2txt``. 
+
+- Add some backward compatibility with SQLAlchemy 1.4+ for testing.
+
+- Improvements to ``merge_db``, including the option to ignore SQLAlchemy's
+  default table dependency order and calculate another.
+
+- Improve the ability of the Alembic support code to take a database URL.