diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 3e9d2137..ec27ef86 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -9,9 +9,9 @@ on: env: # This should be the default but we'll be explicit PRE_COMMIT_HOME: ~/.caches/pre-commit - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" jobs: - the_job: + clean-check: runs-on: ubuntu-latest steps: - name: Checkout Code diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 137557a7..389e0734 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,50 +8,31 @@ on: - main env: # This should be the default but we'll be explicit - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" jobs: - the_job: + unit-tests: runs-on: ubuntu-latest - services: - postgres: - image: postgres:15 - env: - POSTGRES_PASSWORD: password - ports: - - 5432:5432 steps: - name: Checkout Code uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - name: Bootstrap poetry + - name: Start PostgreSQL shell: bash run: | - python -m ensurepip - python -m pip install --upgrade pip - python -m pip install poetry + sudo systemctl start postgresql.service + - name: Install poetry + shell: bash + run: | + sudo apt install --yes python3-poetry - name: Configure poetry shell: bash run: | python -m poetry config virtualenvs.in-project true - # - name: Cache Poetry dependencies - # uses: actions/cache@v3 - # id: poetry-cache - # with: - # path: .venv - # key: venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }} - name: Install dependencies shell: bash if: steps.poetry-cache.outputs.cache-hit != 'true' run: | python -m poetry install --all-extras - - name: Create src database - shell: bash - run: | - PGPASSWORD=password psql --host=localhost --username=postgres --set="ON_ERROR_STOP=1" --file=tests/examples/src.dump - - name: Run Unit Tests + - name: Run tests shell: bash run: | - REQUIRES_DB=1 poetry run python -m unittest discover --verbose tests + poetry run python -m unittest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7eba811b..085ce544 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,21 +3,22 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v6.0.0 hooks: - id: trailing-whitespace + exclude: docs/(source|build/html)/_static/ - id: end-of-file-fixer exclude: docs/source/_static/ - id: check-yaml - id: check-added-large-files - repo: https://github.com/markdownlint/markdownlint # Note the "v" - rev: v0.11.0 + rev: v0.12.0 hooks: - id: markdownlint args: [--style=mdl_style.rb] - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.8.0.4 + rev: v0.11.0.1 hooks: - id: shellcheck - repo: local @@ -39,8 +40,7 @@ repos: language: system types: ['python'] exclude: (?x)( - tests/examples| - tests/workspace + tests/examples ) - id: isort name: isort @@ -49,7 +49,6 @@ repos: types: ['python'] exclude: (?x)( tests/examples| - tests/workspace| examples ) - id: pylint @@ -76,7 +75,6 @@ repos: language: system exclude: (?x)( tests/examples| - tests/workspace| examples ) types: ['python'] diff --git a/.pylintrc b/.pylintrc index d97276b9..cb276e25 100644 --- a/.pylintrc +++ b/.pylintrc @@ -24,8 +24,7 @@ ignore=CVS # Add files or directories matching the regex patterns to the ignore-list. The # regex matches against paths.
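The `Run tests` step above no longer exports `REQUIRES_DB=1` or seeds a source database, so any remaining database-dependent tests need their own gate. A hypothetical sketch, not code from this repo; the class name and the reuse of the old `REQUIRES_DB` variable are illustrative assumptions:

```python
import os
import unittest


@unittest.skipUnless(os.environ.get("REQUIRES_DB") == "1", "needs a live PostgreSQL")
class DatabaseBackedTest(unittest.TestCase):
    """Hypothetical example: only runs when REQUIRES_DB=1 is exported."""

    def test_round_trip(self) -> None:
        self.assertEqual(1 + 1, 2)  # placeholder for a real database check
```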
-ignore-paths=tests/examples, - tests/workspace +ignore-paths=tests/examples # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. @@ -53,7 +52,7 @@ persistent=yes # Min Python version to use for version dependend checks. Will default to the # version used to run pylint. -py-version=3.9 +py-version=3.10 # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 91942b1b..29cdf780 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -9,7 +9,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.9" + python: "3.10" # You can also specify other tool versions: # nodejs: "19" # rust: "1.64" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 259ebe8e..234f8fb8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -47,16 +47,6 @@ Executing unit tests is straightforward: python -m unittest discover --verbose tests/ ``` -for tests that are currently maintained. - -## Running functional tests - -These tests do not currently work, and will be replaced by unit tests. - -Functional tests require PostgreSQL to be installed. - - *WARNING: Some MacOS systems [do not recognise the 'en_US.utf8' locale](https://apple.stackexchange.com/questions/206495/load-a-locale-from-usr-local-share-locale-in-os-x). As a workaround, replace `en_US.utf8` with `en_US.UTF-8` on every `*.dump` file.* - ## Building documentation locally ```bash diff --git a/datafaker/base.py b/datafaker/base.py index 17f471b5..495ff628 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -1,268 +1,26 @@ """Base table generator classes.""" +import gzip +import os +import random from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass -import functools -import math -import numpy as np -import os +from io import TextIOWrapper from pathlib import Path -import random from typing import Any import yaml -import gzip from sqlalchemy import Connection, insert from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.schema import Table +from sqlalchemy.schema import MetaData, Table from datafaker.utils import ( + MAKE_VOCAB_PROGRESS_REPORT_EVERY, logger, stream_yaml, - MAKE_VOCAB_PROGRESS_REPORT_EVERY, table_row_count, ) -@functools.cache -def zipf_weights(size): - total = sum(map(lambda n: 1/n, range(1, size + 1))) - return [ - 1 / (n * total) - for n in range(1, size + 1) - ] - -def merge_with_constants(xs: list, constants_at: dict[int, any]): - """ - Merge a list of items with other items that must be placed at certain indices. - :param constants_at: A map of indices to objects that must be placed at - those indices. - :param xs: Items that fill in the gaps left by ``constants_at``. - :return: ``xs`` with ``constants_at`` inserted at the appropriate - points. If there are not enough elements in ``xs`` to fill in the gaps - in ``constants_at``, the elements of ``constants_at`` after the gap - are dropped. 
- """ - outi = 0 - xi = 0 - constant_count = len(constants_at) - while constant_count != 0: - if outi in constants_at: - yield constants_at[outi] - constant_count -= 1 - else: - if xi == len(xs): - return - yield xs[xi] - xi += 1 - outi += 1 - for x in xs[xi:]: - yield x - - -class NothingToGenerateException(Exception): - def __init__(self, message): - super().__init__(message) - - -class DistributionGenerator: - root3 = math.sqrt(3) - - def __init__(self): - self.np_gen = np.random.default_rng() - - def uniform(self, low, high) -> float: - return random.uniform(float(low), float(high)) - - def uniform_ms(self, mean, sd) -> float: - m = float(mean) - h = self.root3 * float(sd) - return random.uniform(m - h, m + h) - - def normal(self, mean, sd) -> float: - return random.normalvariate(float(mean), float(sd)) - - def lognormal(self, logmean, logsd) -> float: - return random.lognormvariate(float(logmean), float(logsd)) - - def choice(self, a): - c = random.choice(a) - return c["value"] if type(c) is dict and "value" in c else c - - def zipf_choice(self, a, n=None): - if n is None: - n = len(a) - c = random.choices(a, weights=zipf_weights(n))[0] - return c["value"] if type(c) is dict and "value" in c else c - - def weighted_choice(self, a: list[dict[str, any]]) -> list[any]: - """ - Choice weighted by the count in the original dataset. - :param a: a list of dicts, each with a ``value`` key - holding the value to be returned and a ``count`` key holding the - number of that value found in the original dataset - """ - vs = [] - counts = [] - for vc in a: - count = vc.get("count", 0) - if count: - counts.append(count) - vs.append(vc.get("value", None)) - c = random.choices(vs, weights=counts)[0] - return c - - def constant(self, value): - return value - - def multivariate_normal_np(self, cov): - rank = int(cov["rank"]) - if rank == 0: - return np.empty(shape=(0,)) - mean = [ - float(cov[f"m{i}"]) - for i in range(rank) - ] - covs = [ - [ - float(cov[f"c{i}_{j}"] if i <= j else cov[f"c{j}_{i}"]) - for i in range(rank) - ] - for j in range(rank) - ] - return self.np_gen.multivariate_normal(mean, covs) - - def _select_group(self, alts: list[dict[str, any]]): - """ - Choose one of the ``alts`` weighted by their ``"count"`` elements. - """ - total = 0 - for alt in alts: - if alt["count"] < 0: - logger.warning("Alternative count is %d, but should not be negative", alt["count"]) - else: - total += alt["count"] - if total == 0: - raise NothingToGenerateException("No counts in any alternative") - choice = random.randrange(total) - for alt in alts: - choice -= alt["count"] - if choice < 0: - return alt - raise Exception("Internal error: ran out of choices in _select_group") - - def _find_constants(self, result: dict[str, any]): - """ - Find all keys ``kN``, returning a dictionary of ``N: kNN``. - - This can be passed into ``merge_with_constants`` as the - ``constants_at`` argument. - """ - out: dict[int, any] = {} - for k, v in result.items(): - if k.startswith("k") and k[1:].isnumeric(): - out[int(k[1:])] = v - return out - - PERMITTED_SUBGENS = { - "multivariate_lognormal", - "multivariate_normal", - "grouped_multivariate_lognormal", - "grouped_multivariate_normal", - "constant", - "weighted_choice", - "with_constants_at", - } - - def multivariate_normal(self, cov): - """ - Produce a list of values pulled from a multivariate distribution. - - :param cov: A dict with various keys: ``rank`` is the number of - output values, ``m0``, ``m1``, ... are the means of the - distributions (``rank`` of them). 
``c0_0``, ``c0_1``, ``c1_1``, ... - are the covariates, ``cN_M`` is the covariate of the ``N``th and - ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. - :return: list of ``rank`` floating point values - """ - return self.multivariate_normal_np(cov).tolist() - - def multivariate_lognormal(self, cov): - """ - Produce a list of values pulled from a multivariate distribution. - - :param cov: A dict with various keys: ``rank`` is the number of - output values, ``m0``, ``m1``, ... are the means of the - distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ... - are the covariates, ``cN_M`` is the covariate of the ``N``th and - ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. These - are all the means and covariants of the logs of the data. - :return: list of ``rank`` floating point values - """ - return np.exp(self.multivariate_normal_np(cov)).tolist() - - def grouped_multivariate_normal(self, covs): - cov = self._select_group(covs) - logger.debug("Multivariate normal group selected: %s", cov) - constants = self._find_constants(cov) - nums = self.multivariate_normal(cov) - return list(merge_with_constants(nums, constants)) - - def grouped_multivariate_lognormal(self, covs): - cov = self._select_group(covs) - logger.debug("Multivariate lognormal group selected: %s", cov) - constants = self._find_constants(cov) - nums = np.exp(self.multivariate_normal_np(cov)).tolist() - return list(merge_with_constants(nums, constants)) - - def _check_generator_name(self, name: str) -> None: - if name not in self.PERMITTED_SUBGENS: - raise Exception("%s is not a permitted generator", name) - - def alternatives(self, alternative_configs: list[dict[str, any]], counts: list[int] | None): - """ - A generator that picks between other generators. - - :param alternative_configs: List of alternative generators. - Each alternative has the following keys: "count" -- a weight for - how often to use this alternative; "name" -- which generator - for this partition, for example "composite"; "params" -- the - parameters for this alternative. - :return: list of values - """ - if counts is not None: - while True: - count = self._select_group(counts) - alt = alternative_configs[count["index"]] - name = alt["name"] - self._check_generator_name(name) - try: - return getattr(self, name)(**alt["params"]) - except NothingToGenerateException: - # Prevent this alternative from being chosen again - count["count"] = 0 - alt = self._select_group(alternative_configs) - name = alt["name"] - self._check_generator_name(name) - return getattr(self, name)(**alt["params"]) - - def with_constants_at(self, constants_at: list[int], subgen: str, params: dict[str, any]): - if subgen not in self.PERMITTED_SUBGENS: - logger.error( - "subgenerator %s is not a valid name. Valid names are %s.", - subgen, - self.PERMITTED_SUBGENS, - ) - subout = getattr(self, subgen)(**params) - logger.debug("Merging constants %s", constants_at) - return list(merge_with_constants(subout, constants_at)) - - def truncated_string(self, subgen_fn, params, length): - """ Calls ``subgen_fn(**params)`` and truncates the results to ``length``. 
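The `with_constants_at` path above leans on `merge_with_constants` (deleted from `datafaker/base.py` earlier in this diff) to splice fixed values, such as the `k0`, `k1`, ... constants found by `_find_constants`, back in at fixed output positions. A self-contained copy with a usage check:

```python
def merge_with_constants(xs, constants_at):
    """Yield xs with constants_at values inserted at their indices."""
    outi = 0
    xi = 0
    constant_count = len(constants_at)
    while constant_count != 0:
        if outi in constants_at:
            yield constants_at[outi]
            constant_count -= 1
        else:
            if xi == len(xs):
                return
            yield xs[xi]
            xi += 1
        outi += 1
    yield from xs[xi:]


assert list(merge_with_constants([1.5, 2.5], {0: "A", 3: "B"})) == ["A", 1.5, 2.5, "B"]
```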
""" - result = subgen_fn(**params) - if result is None: - return None - return result[:length] - class TableGenerator(ABC): """Abstract base class for table generator classes.""" @@ -270,7 +28,7 @@ class TableGenerator(ABC): num_rows_per_pass: int = 1 @abstractmethod - def __call__(self, dst_db_conn: Connection) -> dict[str, Any]: + def __call__(self, dst_db_conn: Connection, metadata: MetaData) -> dict[str, Any]: """Return, as a dictionary, a new row for the table that we are generating. The only argument, `dst_db_conn`, should be a database connection to the @@ -288,7 +46,9 @@ class FileUploader: table: Table - def _load_existing_file(self, connection: Connection, file_size: int, opener: Callable[[], Any]) -> None: + def _load_existing_file( + self, connection: Connection, file_size: int, opener: Callable[[], Any] + ) -> None: count = 0 with opener() as fh: rows = stream_yaml(fh) @@ -305,20 +65,32 @@ def _load_existing_file(self, connection: Connection, file_size: int, opener: Ca 100 * fh.tell() / file_size, ) - def load(self, connection: Connection, base_path: Path=Path(".")) -> None: + def load(self, connection: Connection, base_path: Path = Path(".")) -> None: """Load the data from file.""" yaml_file = base_path / Path(self.table.fullname + ".yaml") if yaml_file.exists(): - opener = lambda: open(yaml_file, mode="r", encoding="utf-8") + + def opener() -> TextIOWrapper: + return open(yaml_file, mode="r", encoding="utf-8") + else: yaml_file = base_path / Path(self.table.fullname + ".yaml.gz") if yaml_file.exists(): - opener = lambda: gzip.open(yaml_file, mode="rt") + + def opener() -> TextIOWrapper: + return gzip.open(yaml_file, mode="rt") + else: logger.warning("File %s not found. Skipping...", yaml_file) return if 0 < table_row_count(self.table, connection): - logger.warning("Table %s already contains data (consider running 'datafaker remove-vocab'), skipping...", self.table.name) + logger.warning( + ( + "Table %s already contains data" + " (consider running 'datafaker remove-vocab'), skipping..." + ), + self.table.name, + ) return try: file_size = os.path.getsize(yaml_file) @@ -331,8 +103,20 @@ def load(self, connection: Connection, base_path: Path=Path(".")) -> None: "Error inserting rows into table %s: %s", self.table.fullname, e ) + class ColumnPresence: - def sampled(self, patterns): + """Object for generators to use for missingness completely at random.""" + + def sampled(self, patterns: list[dict[str, Any]]) -> set[str]: + """ + Select a random pattern and output the non-null columns. + + :param patterns: List of outputs from missingness SQL queries. + Columns in each output: ``row_count`` is the number of rows + with this missingness pattern, then for each column + ```` there is a boolean called ``missingness__is_null``. + :return: All the names of the columns no make non-null. 
+ """ total = 0 for pattern in patterns: total += pattern.get("row_count", 0) diff --git a/datafaker/create.py b/datafaker/create.py index d84eadb6..5cddeb2f 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -1,17 +1,14 @@ """Functions and classes to create and populate the target database.""" -from collections import Counter import pathlib -import random +from collections import Counter +from types import ModuleType from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple from sqlalchemy import Connection, insert, inspect from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session -from sqlalchemy.schema import ( - CreateSchema, - MetaData, - Table, -) +from sqlalchemy.schema import CreateSchema, MetaData, Table + from datafaker.base import FileUploader, TableGenerator from datafaker.settings import get_settings from datafaker.utils import ( @@ -56,15 +53,15 @@ def create_db_vocab( metadata: MetaData, meta_dict: dict[str, Any], config: Mapping, - base_path: pathlib.Path | None=pathlib.Path(".") -) -> int: + base_path: pathlib.Path = pathlib.Path("."), +) -> list[str]: """ Load vocabulary tables from files. - - arguments: - metadata: The schema of the database - meta_dict: The simple description of the schema from --orm-file - config: The configuration from --config-file + + :param metadata: The schema of the database + :param meta_dict: The simple description of the schema from --orm-file + :param config: The configuration from --config-file + :return: List of table names loaded. """ settings = get_settings() dst_dsn: str = settings.dst_dsn or "" @@ -85,7 +82,7 @@ def create_db_vocab( uploader = FileUploader(table=vocab_table) with Session(dst_engine) as session: session.begin() - uploader.load(session.connection(), base_path = base_path) + uploader.load(session.connection(), base_path=base_path) session.commit() tables_loaded.append(vocab_table_name) except IntegrityError: @@ -101,9 +98,9 @@ def create_db_vocab( def create_db_data( sorted_tables: Sequence[Table], - table_generator_dict: Mapping[str, TableGenerator], - story_generator_list: Sequence[Mapping[str, Any]], + df_module: ModuleType, num_passes: int, + metadata: MetaData, ) -> RowCounts: """Connect to a database and populate it with data.""" settings = get_settings() @@ -112,25 +109,37 @@ def create_db_data( return create_db_data_into( sorted_tables, - table_generator_dict, - story_generator_list, + df_module, num_passes, dst_dsn, settings.dst_schema, + metadata, ) +# pylint: disable=too-many-arguments too-many-positional-arguments def create_db_data_into( sorted_tables: Sequence[Table], - table_generator_dict: Mapping[str, TableGenerator], - story_generator_list: Sequence[Mapping[str, Any]], + df_module: ModuleType, num_passes: int, db_dsn: str, schema_name: str | None, + metadata: MetaData, ) -> RowCounts: - dst_engine = get_sync_engine( - create_db_engine(db_dsn, schema_name=schema_name) - ) + """ + Populate the database. + + :param sorted_tables: The table names to populate, sorted so that foreign + keys' targets are populated before the foreign keys themselves. + :param table_generator_dict: A mapping of table names to the generators + used to make data for them. + :param story_generator_list: A list of story generators to be run after the + table generators on each pass. + :param num_passes: Number of passes to perform. + :param db_dsn: Connection string for the destination database. + :param schema_name: Destination schema name. 
+ """ + dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) row_counts: Counter[str] = Counter() with dst_engine.connect() as dst_conn: @@ -138,23 +147,31 @@ def create_db_data_into( row_counts += populate( dst_conn, sorted_tables, - table_generator_dict, - story_generator_list, + df_module.table_generator_dict, + df_module.story_generator_list, + metadata, ) return row_counts +# pylint: disable=too-many-instance-attributes class StoryIterator: - def __init__(self, + """Iterates through all the rows produced by all the stories.""" + + def __init__( + self, stories: Iterable[tuple[str, Story]], table_dict: Mapping[str, Table], table_generator_dict: Mapping[str, TableGenerator], dst_conn: Connection, ): + """Initialise a Story Iterator.""" self._stories: Iterator[tuple[str, Story]] = iter(stories) self._table_dict: Mapping[str, Table] = table_dict self._table_generator_dict: Mapping[str, TableGenerator] = table_generator_dict self._dst_conn: Connection = dst_conn + self._table_name: str | None + self._final_values: dict[str, Any] | None = None try: name, self._story = next(self._stories) logger.info("Generating data for story '%s'", name) @@ -164,32 +181,38 @@ def __init__(self, def is_ended(self) -> bool: """ - Do we have another row to process? + Check if we have another row to process. + If so, insert() can be called. """ return self._table_name is None - def has_table(self, table_name: str): - """ - Do we have a row for table table_name? - """ + def has_table(self, table_name: str) -> bool: + """Check if we have a row for table ``table_name``.""" return table_name == self._table_name - def table_name(self) -> str: + def table_name(self) -> str | None: """ - The name of the current table (or None if no more stories to process) + Get the name of the current table. + + :return: The table name, or None if there are no more stories + to process. """ return self._table_name - def insert(self) -> None: + def insert(self, metadata: MetaData) -> None: """ - Perform the insert. Call this after __init__ or next, and after checking - that is_ended returns False. + Put the row in the table. + + Call this after __init__ or next, and after checking that is_ended + returns False. """ + if self._table_name is None: + raise StopIteration("StoryIterator.insert after is_ended") table = self._table_dict[self._table_name] if table.name in self._table_generator_dict: table_generator = self._table_generator_dict[table.name] - default_values = table_generator(self._dst_conn, random.random) + default_values = table_generator(self._dst_conn, metadata) else: default_values = {} insert_values = {**default_values, **self._provided_values} @@ -210,17 +233,16 @@ def insert(self) -> None: cursor.close() def next(self) -> None: - """ - Advance to the next table row. 
- """ + """Advance to the next row.""" while True: try: if self._final_values is None: self._table_name, self._provided_values = next(self._story) return - else: - self._table_name, self._provided_values = self._story.send(self._final_values) - return + self._table_name, self._provided_values = self._story.send( + self._final_values + ) + return except StopIteration: try: name, self._story = next(self._stories) @@ -236,6 +258,7 @@ def populate( tables: Sequence[Table], table_generator_dict: Mapping[str, TableGenerator], story_generator_list: Sequence[Mapping[str, Any]], + metadata: MetaData, ) -> RowCounts: """Populate a database schema with synthetic data.""" row_counts: Counter[str] = Counter() @@ -260,7 +283,7 @@ def populate( for table in tables: # Do we have a story row to enter into this table? if story_iterator.has_table(table.name): - story_iterator.insert() + story_iterator.insert(metadata) row_counts[table.name] = row_counts.get(table.name, 0) + 1 story_iterator.next() if table.name not in table_generator_dict: @@ -274,7 +297,7 @@ def populate( try: with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): - stmt = insert(table).values(table_generator(dst_conn, random.random)) + stmt = insert(table).values(table_generator(dst_conn, metadata)) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 dst_conn.commit() @@ -284,8 +307,12 @@ def populate( # Insert any remaining stories while not story_iterator.is_ended(): - story_iterator.insert() + story_iterator.insert(metadata) t = story_iterator.table_name() + if t is None: + raise AssertionError( + "Internal error: story iterator returns None but not is_ended" + ) row_counts[t] = row_counts.get(t, 0) + 1 story_iterator.next() diff --git a/datafaker/dump.py b/datafaker/dump.py index c4d2b24f..2307ba41 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,29 +1,31 @@ +"""Data dumping functions.""" import csv import io +from typing import TYPE_CHECKING + import sqlalchemy from sqlalchemy.schema import MetaData -from datafaker.settings import get_settings -from datafaker.utils import ( - create_db_engine, - get_sync_engine, - logger, -) +from datafaker.utils import create_db_engine, get_sync_engine, logger + +if TYPE_CHECKING: + from _csv import Writer -def _make_csv_writer(file): +def _make_csv_writer(file: io.TextIOBase) -> "Writer": + """Make the standard CSV file writer.""" return csv.writer(file, quoting=csv.QUOTE_MINIMAL) def dump_db_tables( - metadata: MetaData, - dsn: str, - schema: str | None, - table_name: str, - file: io.TextIOBase + metadata: MetaData, + dsn: str, + schema: str | None, + table_name: str, + file: io.TextIOBase, ) -> None: - """ Output the table as CSV. """ - if table_name not in metadata.tables: + """Output the table as CSV.""" + if table_name not in metadata.tables: logger.error("%s is not a table described in the ORM file", table_name) return table = metadata.tables[table_name] @@ -33,4 +35,4 @@ def dump_db_tables( with engine.connect() as connection: result = connection.execute(sqlalchemy.select(table)) for row in result: - csv_out.writerow(row._tuple()) + csv_out.writerow(row) diff --git a/datafaker/generators.py b/datafaker/generators.py deleted file mode 100644 index e5deb91c..00000000 --- a/datafaker/generators.py +++ /dev/null @@ -1,1729 +0,0 @@ -""" -Generator factories for making generators for single columns. 
-""" - -from abc import ABC, abstractmethod -from collections.abc import Mapping -from dataclasses import dataclass -import decimal -from functools import lru_cache -from itertools import chain, combinations -import math -import mimesis -import mimesis.locales -import re -import sqlalchemy -from sqlalchemy import Column, Engine, text, Connection, RowMapping, Sequence -from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time -from typing import Callable, Iterable, TypeVar - -from datafaker.base import DistributionGenerator -from datafaker.utils import logger - -# How many distinct values can we have before we consider a -# choice distribution to be infeasible? -MAXIMUM_CHOICES = 500 - -dist_gen = DistributionGenerator() -generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) - -class Generator(ABC): - """ - Random data generator. - - A generator is specific to a particular column in a particular table in - a particluar database. - - A generator knows how to fetch its summary data from the database, how to calculate - its fit (if apropriate) and which function actually does the generation. - - It also knows these summary statistics for the column it was instantiated on, - and therefore knows how to generate fake data for that column. - """ - @abstractmethod - def function_name(self) -> str: - """ The name of the generator function to put into df.py. """ - - def name(self) -> str: - """ - The name of the generator. - - Usually the same as the function name, but can be different to distinguish - between generators that have the same function but different queries. - """ - return self.function_name() - - @abstractmethod - def nominal_kwargs(self) -> dict[str, str]: - """ - The kwargs the generator wants to be called with. - The values will tend to be references to something in the src-stats.yaml - file. - For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will - provide the value stored in src-stats.yaml as - SRC_STATS["auto__patient"]["results"][0]["age_mean"] as the "avg_age" argument - to the generator function. - """ - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """ - SQL clauses to add to a SELECT ... FROM {table} query. - - Will add to SRC_STATS["auto__{table}"] - For example { - "count": { - "clause": "COUNT(*)", - "comment": "number of rows in table {table}" - }, "avg_thiscolumn": { - "clause": "AVG(thiscolumn)", - "comment": "Average value of thiscolumn in table {table}" - }} - will make the clause become: - "SELECT COUNT(*) AS count, AVG(thiscolumn) AS avg_thiscolumn FROM thistable" - and this will populate SRC_STATS["auto__thistable"]["results"][0]["count"] and - SRC_STATS["auto__thistable"]["results"][0]["avg_thiscolumn"] in the src-stats.yaml file. - """ - return {} - - def custom_queries(self) -> dict[str, dict[str, str]]: - """ - SQL queries to add to SRC_STATS. - - Should be used for queries that do not follow the SELECT ... FROM table format - using aggregate queries, because these should use select_aggregate_clauses. - - For example {"myquery": { - "query": "SELECT one, too AS two FROM mytable WHERE too > 1", - "comment": "big enough one and two from table mytable" - }} - will populate SRC_STATS["myquery"]["results"][0]["one"] and SRC_STATS["myquery"]["results"][0]["two"] - in the src-stats.yaml file. 
- - Keys should be chosen to minimize the chances of clashing with other queries, - for example "auto__{table}__{column}__{queryname}" - """ - return {} - - @abstractmethod - def actual_kwargs(self) -> dict[str, any]: - """ - The kwargs (summary statistics) this generator is instantiated with. - """ - - @abstractmethod - def generate_data(self, count) -> list[any]: - """ - Generate 'count' random data points for this column. - """ - - def fit(self, default=None) -> float | None: - """ - Return a value representing how well the distribution fits the real source data. - - 0.0 means "perfectly". - Returns default if no fitness has been defined. - """ - return default - - -class PredefinedGenerator(Generator): - """ - Generator built from an existing config.yaml. - """ - SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") - AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") - SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*') - - def _get_src_stats_mentioned(self, val) -> set[str]: - if not val: - return set() - if type(val) is str: - ss = self.SRC_STAT_NAME_RE.match(val) - if ss: - ss_name = ss.group(1) - logger.debug("Found SRC_STATS reference %s", ss_name) - return set([ss_name]) - else: - logger.debug("Value %s does not seem to be a SRC_STATS reference", val) - return set() - if type(val) is list: - return set.union(*(self._get_src_stats_mentioned(v) for v in val)) - if type(val) is dict: - return set.union(*(self._get_src_stats_mentioned(v) for v in val.values())) - return set() - - def __init__(self, table_name: str, generator_object: Mapping[str, any], config: Mapping[str, any]): - """ - Initialise a generator from a config.yaml. - :param config: The entire configuration. - :param generator_object: The part of the configuration at tables.*.row_generators - """ - logger.debug("Creating a PredefinedGenerator %s from table %s", generator_object["name"], table_name) - self._table_name = table_name - self._name: str = generator_object["name"] - self._kwn: dict[str, str] = generator_object.get("kwargs", {}) - self._src_stats_mentioned = self._get_src_stats_mentioned(self._kwn) - # Need to deal with this somehow (or remove it from the schema) - self._argn: list[str] = generator_object.get("args", []) - self._select_aggregate_clauses = {} - self._custom_queries = {} - for sstat in config.get("src-stats", []): - name: str = sstat["name"] - dpq = sstat.get("dp-query", None) - query = sstat.get("query", dpq) #... should we really be combining query and dp-query? 
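The two regexes above reverse-engineer a stored aggregate query back into its clauses. A quick demonstration using the same patterns:

```python
import re

SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)")
AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *")

m = SELECT_AGGREGATE_RE.match("SELECT COUNT(*) AS count, AVG(age) AS avg_age FROM patient")
assert m.group(2) == "patient"  # the table name from the FROM clause
clauses = [AS_CLAUSE_RE.match(c) for c in m.group(1).split(",")]
assert [(c.group(1), c.group(2)) for c in clauses] == [
    ("COUNT(*)", "count"),
    ("AVG(age)", "avg_age"),
]
```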
- comments = sstat.get("comments", []) - if name in self._src_stats_mentioned: - logger.debug("Found a src-stats entry for %s", name) - # This query is one that this generator is interested in - sam = None if query is None else self.SELECT_AGGREGATE_RE.match(query) - # sam.group(2) is the table name from the FROM clause of the query - if sam and name == f"auto__{sam.group(2)}": - # name is auto__{table_name}, so it's a select_aggregate, so we split up its clauses - sacs = [ - self.AS_CLAUSE_RE.match(clause) - for clause in sam.group(1).split(',') - ] - # Work out what select_aggregate_clauses this represents - for sac in sacs: - if sac is not None: - comment = comments.pop() if comments else None - self._select_aggregate_clauses[sac.group(2)] = { - "clause": sac.group(1), - "comment": comment, - } - else: - # some other name, so must be a custom query - logger.debug("Custom query %s is '%s'", name, query) - self._custom_queries[name] = { - "query": query, - "comment": comments[0] if comments else None, - } - - def function_name(self) -> str: - return self._name - - def nominal_kwargs(self) -> dict[str, str]: - return self._kwn - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - return self._select_aggregate_clauses - - def custom_queries(self) -> dict[str, dict[str, str]]: - return self._custom_queries - - def actual_kwargs(self) -> dict[str, any]: - # Run the queries from nominal_kwargs - #... - logger.error("PredefinedGenerator.actual_kwargs not implemented yet") - return {} - - def generate_data(self, count) -> list[any]: - # Call the function if we can. This could be tricky... - #... - logger.error("PredefinedGenerator.generate_data not implemented yet") - return [] - - -class GeneratorFactory(ABC): - """ - A factory for making generators appropriate for a database column. - """ - @abstractmethod - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: - """ - Returns all the generators that might be appropriate for this column. - """ - - -class Buckets: - """ - Finds the real distribution of continuous data so that we can measure - the fit of generators against it. - """ - def __init__(self, engine: Engine, table_name: str, column_name: str, mean:float, stddev: float, count: int): - with engine.connect() as connection: - raw_buckets = connection.execute(text( - "SELECT COUNT({column}) AS f, FLOOR(({column} - {x})/{w}) AS b FROM {table} GROUP BY b".format( - column=column_name, table=table_name, x=mean - 2 * stddev, w = stddev / 2 - ) - )) - self.buckets = [0] * 10 - for rb in raw_buckets: - if rb.b is not None: - bucket = min(9, max(0, int(rb.b) + 1)) - self.buckets[bucket] += rb.f / count - self.mean = mean - self.stddev = stddev - - @classmethod - def make_buckets(_cls, engine: Engine, table_name: str, column_name: str): - """ - Construct a Buckets object. - - Calculates the mean and standard deviation of the values in the column - specified and makes ten buckets, centered on the mean and each half - a standard deviation wide (except for the end two that extend to - infinity). Each bucket will be set to the count of the number of values - in the column within that bucket. 
- """ - with engine.connect() as connection: - result = connection.execute( - text("SELECT AVG({column}) AS mean, STDDEV({column}) AS stddev, COUNT({column}) AS count FROM {table}".format( - table=table_name, - column=column_name, - )) - ).first() - if result is None or result.stddev is None or result.count < 2: - return None - try: - buckets = Buckets( - engine, - table_name, - column_name, - result.mean, - result.stddev, - result.count, - ) - except sqlalchemy.exc.DatabaseError as exc: - logger.debug("Failed to instantiate Buckets object: %s", exc) - return None - return buckets - - def fit_from_counts(self, bucket_counts: list[float]) -> float: - """ - Figure out the fit from bucket counts from the generator distribution. - """ - return fit_from_buckets(self.buckets, bucket_counts) - - def fit_from_values(self, values: list[float]) -> float: - """ - Figure out the fit from samples from the generator distribution. - """ - buckets = [0] * 10 - x = self.mean - 2 * self.stddev - w = self.stddev / 2 - for v in values: - b = min(9, max(0, int((v - x)/w))) - buckets[b] += 1 - return self.fit_from_counts(buckets) - - -class MultiGeneratorFactory(GeneratorFactory): - """ A composite factory. """ - def __init__(self, factories: list[GeneratorFactory]): - super().__init__() - self.factories = factories - - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: - return [ - generator - for factory in self.factories - for generator in factory.get_generators(columns, engine) - ] - - -class MimesisGeneratorBase(Generator): - def __init__( - self, - function_name: str, - ): - """ - Generator from Mimesis. - - :param function_name: is relative to 'generic', for example 'person.name'. - """ - super().__init__() - f = generic - for part in function_name.split("."): - if not hasattr(f, part): - raise Exception(f"Mimesis does not have a function {function_name}: {part} not found") - f = getattr(f, part) - if not callable(f): - raise Exception(f"Mimesis object {function_name} is not a callable, so cannot be used as a generator") - self._name = "generic." + function_name - self._generator_function = f - def function_name(self): - return self._name - def generate_data(self, count): - return [ - self._generator_function() - for _ in range(count) - ] - - -class MimesisGenerator(MimesisGeneratorBase): - def __init__( - self, - function_name: str, - value_fn: Callable[[any], float] | None=None, - buckets: Buckets | None=None, - ): - """ - Generator from Mimesis. - - :param function_name: is relative to 'generic', for example 'person.name'. - :param value_fn: Function to convert generator output to floats, if needed. The values - thus produced are compared against the buckets to estimate the fit. - :param buckets: The distribution of string lengths in the real data. If this is None - then the fit method will return None. 
- """ - super().__init__(function_name) - if buckets is None: - self._fit = None - return - samples = self.generate_data(400) - if value_fn: - samples = [ - value_fn(s) - for s in samples - ] - self._fit = buckets.fit_from_values(samples) - def function_name(self): - return self._name - def nominal_kwargs(self): - return {} - def actual_kwargs(self): - return {} - def fit(self, default=None): - return default if self._fit is None else self._fit - - -class MimesisGeneratorTruncated(MimesisGenerator): - def __init__( - self, - function_name: str, - length: int, - value_fn: Callable[[any], float] | None=None, - buckets: Buckets | None=None, - ): - self._length = length - super().__init__(function_name, value_fn, buckets) - def function_name(self): - return "dist_gen.truncated_string" - def name(self): - return f"{self._name} [truncated to {self._length}]" - def nominal_kwargs(self): - return { - "subgen_fn": self._name, - "params": {}, - "length": self._length, - } - def actual_kwargs(self): - return { - "subgen_fn": self._name, - "params": {}, - "length": self._length, - } - def generate_data(self, count): - return [ - self._generator_function()[:self._length] - for _ in range(count) - ] - - -class MimesisDateTimeGenerator(MimesisGeneratorBase): - def __init__(self, column: Column, function_name: str, min_year: str, max_year: str, start: int, end: int): - """ - :param column: The column to generate into - :param function_name: The name of the mimesis function - :param min_year: SQL expression extracting the minimum year - :param min_year: SQL expression extracting the maximum year - :param start: The actual first year found - :param end: The actual last year found - """ - super().__init__(function_name) - self._column = column - self._max_year = max_year - self._min_year = min_year - self._start = start - self._end = end - - @classmethod - def make_singleton(_cls, column: Column, engine: Engine, function_name: str): - extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" - max_year = f"MAX({extract_year})" - min_year = f"MIN({extract_year})" - with engine.connect() as connection: - result = connection.execute( - text(f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}") - ).first() - if result is None or result.start is None or result.end is None: - return [] - return [MimesisDateTimeGenerator( - column, - function_name, - min_year, - max_year, - int(result.start), - int(result.end), - )] - def nominal_kwargs(self): - return { - "start": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__start"]', - "end": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__end"]', - } - def actual_kwargs(self): - return { - "start": self._start, - "end": self._end, - } - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - return { - f"{self._column.name}__start": { - "clause": self._min_year, - "comment": f"Earliest year found for column {self._column.name} in table {self._column.table.name}", - }, - f"{self._column.name}__end": { - "clause": self._max_year, - "comment": f"Latest year found for column {self._column.name} in table {self._column.table.name}", - }, - } - def generate_data(self, count): - return [ - self._generator_function(start=self._start, end=self._end) - for _ in range(count) - ] - - -def get_column_type(column: Column): - try: - return column.type.as_generic() - except NotImplementedError: - return column.type - - -class MimesisStringGeneratorFactory(GeneratorFactory): - """ - All 
Mimesis generators that return strings. - """ - GENERATOR_NAMES = [ - "address.calling_code", - "address.city", - "address.continent", - "address.country", - "address.country_code", - "address.postal_code", - "address.province", - "address.street_number", - "address.street_name", - "address.street_suffix", - "person.blood_type", - "person.email", - "person.first_name", - "person.last_name", - "person.full_name", - "person.gender", - "person.language", - "person.nationality", - "person.occupation", - "person.password", - "person.title", - "person.university", - "person.username", - "person.worldview", - "text.answer", - "text.color", - "text.level", - "text.quote", - "text.sentence", - "text.text", - "text.word", - ] - def get_generators(self, columns: list[Column], engine: Engine): - if len(columns) != 1: - return [] - column = columns[0] - column_type = get_column_type(column) - if not isinstance(column_type, String): - return [] - try: - buckets = Buckets.make_buckets( - engine, - column.table.name, - f"LENGTH({column.name})", - ) - fitness_fn = len - except Exception as exc: - # Some column types that appear to be strings (such as enums) - # cannot have their lengths measured. In this case we cannot - # detect fitness using lengths. - buckets = None - fitness_fn = None - length = column_type.length - if length: - return list(map( - lambda gen: MimesisGeneratorTruncated(gen, length, fitness_fn, buckets), - self.GENERATOR_NAMES, - )) - return list(map( - lambda gen: MimesisGenerator(gen, fitness_fn, buckets), - self.GENERATOR_NAMES, - )) - - -class MimesisFloatGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return floating point numbers. - """ - def get_generators(self, columns: list[Column], engine: Engine): - if len(columns) != 1: - return [] - column = columns[0] - if not isinstance(get_column_type(column), Numeric): - return [] - return list(map(MimesisGenerator, [ - "person.height", - ])) - - -class MimesisDateGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return dates. - """ - def get_generators(self, columns: list[Column], engine: Engine): - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Date): - return [] - return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.date") - - -class MimesisDateTimeGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return datetimes. - """ - def get_generators(self, columns: list[Column], engine: Engine): - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, DateTime): - return [] - return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.datetime") - - -class MimesisTimeGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return times. - """ - def get_generators(self, columns: list[Column], engine: Engine): - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Time): - return [] - return [MimesisGenerator("datetime.time")] - - -class MimesisIntegerGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return integers. 
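The dotted-name lookup that `MimesisGeneratorBase` performs can be reproduced directly; this mirrors the deleted code's own `generic` setup:

```python
import mimesis
import mimesis.locales

generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB)
f = generic
for part in "person.first_name".split("."):
    f = getattr(f, part)  # walk "person", then "first_name"
print(f())  # a random first name, e.g. "Alice"
```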
- """ - def get_generators(self, columns: list[Column], engine: Engine): - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Numeric) and not isinstance(ct, Integer): - return [] - return [MimesisGenerator("person.weight")] - - -def fit_from_buckets(xs: list[float], ys: list[float]): - sum_diff_squared = sum(map(lambda t, a: (t - a)*(t - a), xs, ys)) - count = len(ys) - return sum_diff_squared / (count * count) - - -class ContinuousDistributionGenerator(Generator): - def __init__(self, table_name: str, column_name: str, buckets: Buckets): - super().__init__() - self.table_name = table_name - self.column_name = column_name - self.buckets = buckets - def nominal_kwargs(self): - return { - "mean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["mean__{self.column_name}"]', - "sd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["stddev__{self.column_name}"]', - } - def actual_kwargs(self): - if self.buckets is None: - return {} - return { - "mean": self.buckets.mean, - "sd": self.buckets.stddev, - } - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - clauses = super().select_aggregate_clauses() - return { - **clauses, - f"mean__{self.column_name}": { - "clause": f"AVG({self.column_name})", - "comment": f"Mean of {self.column_name} from table {self.table_name}", - }, - f"stddev__{self.column_name}": { - "clause": f"STDDEV({self.column_name})", - "comment": f"Standard deviation of {self.column_name} from table {self.table_name}", - }, - } - def fit(self, default=None): - if self.buckets is None: - return default - return self.buckets.fit_from_counts(self.expected_buckets) - - -class GaussianGenerator(ContinuousDistributionGenerator): - expected_buckets = [0.0227, 0.0441, 0.0918, 0.1499, 0.1915, 0.1915, 0.1499, 0.0918, 0.0441, 0.0227] - def function_name(self): - return "dist_gen.normal" - def generate_data(self, count): - return [ - dist_gen.normal(self.buckets.mean, self.buckets.stddev) - for _ in range(count) - ] - - -class UniformGenerator(ContinuousDistributionGenerator): - expected_buckets = [0, 0.06698, 0.14434, 0.14434, 0.14434, 0.14434, 0.14434, 0.14434, 0.06698, 0] - def function_name(self): - return "dist_gen.uniform_ms" - def generate_data(self, count): - return [ - dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) - for _ in range(count) - ] - - -class ContinuousDistributionGeneratorFactory(GeneratorFactory): - """ - All generators that want an average and standard deviation. 
- """ - def _get_generators_from_buckets( - self, - _engine: Engine, - table_name: str, - column_name: str, - buckets: Buckets, - ) -> list[Generator]: - return [ - GaussianGenerator(table_name, column_name, buckets), - UniformGenerator(table_name, column_name, buckets), - ] - - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Numeric) and not isinstance(ct, Integer): - return [] - column_name = column.name - table_name = column.table.name - buckets = Buckets.make_buckets(engine, table_name, column_name) - if buckets is None: - return [] - return self._get_generators_from_buckets(engine, table_name, column_name, buckets) - - -class LogNormalGenerator(Generator): - #TODO: figure out the real buckets here (this was from a random sample in R) - expected_buckets = [0, 0, 0, 0.28627, 0.40607, 0.14937, 0.06735, 0.03492, 0.01918, 0.03684] - def __init__(self, table_name: str, column_name: str, buckets: Buckets, logmean: float, logstddev: float): - super().__init__() - self.table_name = table_name - self.column_name = column_name - self.buckets = buckets - self.logmean = logmean - self.logstddev = logstddev - def function_name(self): - return "dist_gen.lognormal" - def generate_data(self, count): - return [ - dist_gen.lognormal(self.logmean, self.logstddev) - for _ in range(count) - ] - def nominal_kwargs(self): - return { - "logmean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logmean__{self.column_name}"]', - "logsd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logstddev__{self.column_name}"]', - } - def actual_kwargs(self): - return { - "logmean": self.logmean, - "logsd": self.logstddev, - } - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - clauses = super().select_aggregate_clauses() - return { - **clauses, - f"logmean__{self.column_name}": { - "clause": f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name}) ELSE NULL END)", - "comment": f"Mean of logs of {self.column_name} from table {self.table_name}", - }, - f"logstddev__{self.column_name}": { - "clause": f"STDDEV(CASE WHEN 0<{self.column_name} THEN LN({self.column_name}) ELSE NULL END)", - "comment": f"Standard deviation of logs of {self.column_name} from table {self.table_name}", - }, - } - def fit(self, default=None): - if self.buckets is None: - return default - return self.buckets.fit_from_counts(self.expected_buckets) - - -class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorFactory): - """ - All generators that want an average and standard deviation of log data. 
- """ - def _get_generators_from_buckets( - self, - engine: Engine, - table_name: str, - column_name: str, - buckets: Buckets, - ) -> list[Generator]: - with engine.connect() as connection: - result = connection.execute( - text("SELECT AVG(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logmean, STDDEV(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logstddev FROM {table}".format( - table=table_name, - column=column_name, - )) - ).first() - if result is None or result.logstddev is None: - return [] - return [ - LogNormalGenerator( - table_name, - column_name, - buckets, - float(result.logmean), - float(result.logstddev), - ) - ] - - -def zipf_distribution(total, bins): - basic_dist = list(map(lambda n: 1/n, range(1, bins + 1))) - bd_remaining = sum(basic_dist) - for b in basic_dist: - # yield b/bd_remaining of the `total` remaining - if bd_remaining == 0: - yield 0 - else: - x = math.floor(0.5 + total * b / bd_remaining) - bd_remaining -= x * bd_remaining / total - total -= x - yield x - - -class ChoiceGenerator(Generator): - STORE_COUNTS = False - def __init__( - self, - table_name, - column_name, - values, - counts, - sample_count = None, - suppress_count = 0, - ): - super().__init__() - self.table_name = table_name - self.column_name = column_name - self.values = values - estimated_counts = self.get_estimated_counts(counts) - self._fit = fit_from_buckets(counts, estimated_counts) - - extra_results = "" - extra_expo = "" - extra_comment = "" - if self.STORE_COUNTS: - extra_results = f", COUNT({column_name}) AS count" - extra_expo = ", count" - extra_comment = " and their counts" - if suppress_count == 0: - if sample_count is None: - self._query = f"SELECT {column_name} AS value{extra_results} FROM {table_name} WHERE {column_name} IS NOT NULL GROUP BY value ORDER BY COUNT({column_name}) DESC" - self._comment = f"All the values{extra_comment} that appear in column {column_name} of table {table_name}" - self._annotation = None - else: - self._query = f"SELECT {column_name} AS value{extra_results} FROM (SELECT {column_name} FROM {table_name} WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" - self._comment = f"The values{extra_comment} that appear in column {column_name} of a random sample of {sample_count} rows of table {table_name}" - self._annotation = "sampled" - else: - if sample_count is None: - self._query = f"SELECT value{extra_expo} FROM (SELECT {column_name} AS value, COUNT({column_name}) AS count FROM {table_name} WHERE {column_name} IS NOT NULL GROUP BY value ORDER BY count DESC) AS _inner WHERE {suppress_count} < count" - self._comment = f"All the values{extra_comment} that appear in column {column_name} of table {table_name} more than {suppress_count} times" - self._annotation = "suppressed" - else: - self._query = f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM (SELECT {column_name} AS value FROM {table_name} WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value ORDER BY count DESC) AS _inner WHERE {suppress_count} < count" - self._comment = f"The values{extra_comment} that appear more than {suppress_count} times in column {column_name}, out of a random sample of {sample_count} rows of table {table_name}" - self._annotation = "sampled and suppressed" - - @abstractmethod - def get_estimated_counts(counts): - """ - The counts that we would expect if this distribution was the correct one. 
- """ - def nominal_kwargs(self): - return { - "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', - } - def name(self): - n = super().name() - if self._annotation is None: - return n - return f"{n} [{self._annotation}]" - def actual_kwargs(self): - return { - "a": self.values, - } - def custom_queries(self) -> dict[str, dict[str, str]]: - qs = super().custom_queries() - return { - **qs, - f"auto__{self.table_name}__{self.column_name}": { - "query": self._query, - "comment": self._comment, - } - } - def fit(self, default=None): - return default if self._fit is None else self._fit - -class ZipfChoiceGenerator(ChoiceGenerator): - def get_estimated_counts(self, counts): - return list(zipf_distribution(sum(counts), len(counts))) - def function_name(self): - return "dist_gen.zipf_choice" - def generate_data(self, count): - return [ - dist_gen.zipf_choice(self.values, len(self.values)) - for _ in range(count) - ] - - -def uniform_distribution(total, bins): - p = total // bins - n = total % bins - for _ in range(0, n): - yield p + 1 - for _ in range(n, bins): - yield p - - -class UniformChoiceGenerator(ChoiceGenerator): - def get_estimated_counts(self, counts): - return list(uniform_distribution(sum(counts), len(counts))) - def function_name(self): - return "dist_gen.choice" - def generate_data(self, count): - return [ - dist_gen.choice(self.values) - for _ in range(count) - ] - - -class WeightedChoiceGenerator(ChoiceGenerator): - STORE_COUNTS = True - def get_estimated_counts(self, counts): - return counts - def function_name(self): - return "dist_gen.weighted_choice" - def generate_data(self, count): - return [ - dist_gen.weighted_choice(self.values) - for _ in range(count) - ] - - -class ChoiceGeneratorFactory(GeneratorFactory): - """ - All generators that want an average and standard deviation. 
- """ - SAMPLE_COUNT = MAXIMUM_CHOICES - SUPPRESS_COUNT = 7 - def get_generators(self, columns: list[Column], engine: Engine): - if len(columns) != 1: - return [] - column = columns[0] - column_name = column.name - table_name = column.table.name - generators = [] - with engine.connect() as connection: - results = connection.execute( - text("SELECT {column} AS v, COUNT({column}) AS f FROM {table} GROUP BY v ORDER BY f DESC LIMIT {limit}".format( - table=table_name, - column=column_name, - limit=MAXIMUM_CHOICES+1, - )) - ) - if results is not None and results.rowcount <= MAXIMUM_CHOICES: - values = [] # The values found - counts = [] # The number or each value - cvs: list[dict[str, any]] = [] # list of dicts with keys "v" and "count" - for result in results: - c = result.f - if c != 0: - counts.append(c) - v = result.v - if type(v) is decimal.Decimal: - v = float(v) - values.append(v) - cvs.append({"value": v, "count": c}) - if counts: - generators += [ - ZipfChoiceGenerator(table_name, column_name, values, counts), - UniformChoiceGenerator(table_name, column_name, values, counts), - WeightedChoiceGenerator(table_name, column_name, cvs, counts), - ] - results = connection.execute( - text("SELECT v, COUNT(v) AS f FROM (SELECT {column} as v FROM {table} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY v ORDER BY f DESC".format( - table=table_name, - column=column_name, - sample_count=self.SAMPLE_COUNT, - )) - ) - if results is not None: - values = [] # All values found - counts = [] # The number or each value - cvs: list[dict[str, any]] = [] # list of dicts with keys "v" and "count" - values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times - counts_not_suppressed = [] # The number for each value not suppressed - cvs_not_suppressed: list[dict[str, any]] = [] # list of dicts with keys "v" and "count" - for result in results: - c = result.f - if c != 0: - counts.append(c) - v = result.v - if type(v) is decimal.Decimal: - v = float(v) - values.append(v) - cvs.append({"value": v, "count": c}) - if self.SUPPRESS_COUNT < c: - counts_not_suppressed.append(c) - v = result.v - if type(v) is decimal.Decimal: - v = float(v) - values_not_suppressed.append(v) - cvs_not_suppressed.append({"value": v, "count": c}) - if counts: - generators += [ - ZipfChoiceGenerator(table_name, column_name, values, counts, sample_count=self.SAMPLE_COUNT), - UniformChoiceGenerator(table_name, column_name, values, counts, sample_count=self.SAMPLE_COUNT), - WeightedChoiceGenerator(table_name, column_name, cvs, counts, sample_count=self.SAMPLE_COUNT), - ] - if counts_not_suppressed: - generators += [ - ZipfChoiceGenerator( - table_name, - column_name, - values_not_suppressed, - counts_not_suppressed, - sample_count=self.SAMPLE_COUNT, - suppress_count=self.SUPPRESS_COUNT, - ), - UniformChoiceGenerator( - table_name, - column_name, - values_not_suppressed, - counts_not_suppressed, - sample_count=self.SAMPLE_COUNT, - suppress_count=self.SUPPRESS_COUNT, - ), - WeightedChoiceGenerator( - table_name=table_name, - column_name=column_name, - values=cvs_not_suppressed, - counts=counts, - sample_count=self.SAMPLE_COUNT, - suppress_count=self.SUPPRESS_COUNT, - ), - ] - return generators - - -class ConstantGenerator(Generator): - def __init__(self, value): - super().__init__() - self.value = value - self.repr = repr(value) - def function_name(self) -> str: - return "dist_gen.constant" - def nominal_kwargs(self) -> dict[str, str]: - return {"value": self.repr} - def actual_kwargs(self) -> dict[str, any]: - 
return {"value": self.value} - def generate_data(self, count) -> list[any]: - return [self.value for _ in range(count)] - - -class ConstantGeneratorFactory(GeneratorFactory): - """ - Just the null generator - """ - def get_generators(self, columns: list[Column], engine: Engine): - if len(columns) != 1: - return [] - column = columns[0] - if column.nullable: - return [ConstantGenerator(None)] - c_type = get_column_type(column) - if isinstance(c_type, String): - return [ConstantGenerator("")] - if isinstance(c_type, Numeric): - return [ConstantGenerator(0.0)] - if isinstance(c_type, Integer): - return [ConstantGenerator(0)] - return [] - - -class MultivariateNormalGenerator(Generator): - def __init__( - self, - table_name: list[str], - column_names: list[str], - query: str, - covariates: dict[str, float], - function_name: str, - ): - self._table = table_name - self._columns = column_names - self._query = query - self._covariates = covariates - self._function_name = function_name - - def function_name(self): - return "dist_gen." + self._function_name - - def nominal_kwargs(self): - return { - "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', - } - - def custom_queries(self): - cols = ", ".join(self._columns) - return { - f"auto__cov__{self._table}": { - "comment": f"Means and covariate matrix for the columns {cols}, so that we can produce the relatedness between these in the fake data.", - "query": self._query, - } - } - - def actual_kwargs(self) -> dict[str, any]: - """ - The kwargs (summary statistics) this generator is instantiated with. - """ - return { "cov": self._covariates } - - def generate_data(self, count) -> list[any]: - """ - Generate 'count' random data points for this column. - """ - return [ - getattr(dist_gen, self._function_name)(self._covariates) - for _ in range(count) - ] - - def fit(self, default=None) -> float | None: - return default - - -class MultivariateNormalGeneratorFactory(GeneratorFactory): - def function_name(self) -> str: - return "multivariate_normal" - - def query_predicate(self, column: Column) -> str: - return column.name + " IS NOT NULL" - - def query_var(self, column: str) -> str: - return column - - def query( - self, - table: str, - columns: list[Column], - predicates: list[str]=[], - group_by_clause: str="", - constant_clauses: str="", - constants: str="", - suppress_count: int=1, - sample_count: int | None=None, - ) -> str: - """ - Gets a query for the basics for multivariate normal/lognormal parameters. - :param table: The name of the table to be queried. - :param columns: The columns in the multivariate distribution. - :param and_where: Additional where clause. If not ``""`` should begin with ``" AND "``. - :param group_by_clause: Any GROUP BY clause (starting with " GROUP BY " if not ""). - :param constant_clauses: Extra output columns in the outer SELECT clause, such - as ", _q.column_one AS k1, _q.column_two AS k2". Note the initial comma. - :param constants: Extra output columns in the inner SELECT clause. Used to - deliver columns to the outer select, such as ", column_one, column_two". - Note the initial comma. - :param suppress_count: a group smaller than this will be suppressed. - :param sample_count: this many samples will be taken from each partition. 
- """ - preds = [self.query_predicate(col) for col in columns] + predicates - where = " WHERE " + " AND ".join(preds) if preds else "" - avgs = "".join( - f", AVG({self.query_var(col.name)}) AS m{i}" - for i, col in enumerate(columns) - ) - multiples = "".join( - f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" - for iy, coly in enumerate(columns) - for ix, colx in enumerate(columns[:iy+1]) - ) - means = "".join( - f", _q.m{i}" for i in range(len(columns)) - ) - covs = "".join( - f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" - for iy in range(len(columns)) - for ix in range(iy+1) - ) - if sample_count is None: - subquery = table + where - else: - subquery = f"(SELECT * FROM {table}{where} ORDER BY RANDOM() LIMIT {sample_count}) AS _sampled" - # if there are any numeric columns we need at least two rows to make any (co)variances at all - suppress_clause = f" WHERE {suppress_count} < _q.count" if columns else "" - return ( - f"SELECT {len(columns)} AS rank{constant_clauses}, _q.count AS count{means}{covs}" - f" FROM (SELECT COUNT(*) AS count{multiples}{avgs}{constants}" - f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" - ) - - def get_generators(self, columns: list[Column], engine: Engine): - # For the case of one column we'll use GaussianGenerator - if len(columns) < 2: - return [] - # All columns must be numeric - for c in columns: - ct = get_column_type(c) - if not isinstance(ct, Numeric) and not isinstance(ct, Integer): - return [] - column_names = [c.name for c in columns] - table = columns[0].table.name - query = self.query(table, columns) - with engine.connect() as connection: - try: - covariates = connection.execute(text( - query - )).mappings().first() - except Exception as e: - logger.debug("SQL query %s failed with error %s", query, e) - return [] - if not covariates or covariates["c0_0"] is None: - return [] - return [MultivariateNormalGenerator( - table, - column_names, - query, - covariates, - self.function_name(), - )] - - -class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): - def function_name(self) -> str: - return "multivariate_lognormal" - - def query_predicate(self, column: Column) -> str: - return f"COALESCE(0 < {column.name}, FALSE)" - - def query_var(self, column: str) -> str: - return f"LN({column})" - - -def text_list(items: list[str]) -> str: - """ - Concatenate the items with commas and one "and". - """ - if not hasattr(items, "__getitem__"): - items = list(items) - if len(items) == 0: - return "" - if len(items) == 1: - return items[0] - return ", ".join(items[:-1]) + " and " + items[-1] - - -@dataclass -class RowPartition: - query: str - # list of numeric columns - included_numeric: list[Column] - # map of indices to column names that are being grouped by. - # The indices are indices of where they need to be inserted into - # the generator outputs. 
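The index bookkeeping behind `constant_outputs` is easier to see in miniature. This is a hypothetical re-implementation of the `with_constants_at` idea referenced below, purely to illustrate the splicing; it is not datafaker's own provider code:

```python
# Splice fixed outputs (e.g. the Nones of a null partition) back into a
# generator's outputs at the recorded indices.
def with_constants_at(constants_at: dict[int, object], generated: list) -> list:
    out, gen = [], iter(generated)
    for i in range(len(constants_at) + len(generated)):
        out.append(constants_at[i] if i in constants_at else next(gen))
    return out

print(with_constants_at({0: None, 2: None}, [1.7, 60.0]))  # [None, 1.7, None, 60.0]
```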
- included_choice: dict[int, str] - # map of column names to clause that defines the partition - # such as "mycolumn IS NULL" - excluded_columns: dict[str, str] - # map of constant outputs that need to be inserted into the - # list of included column values (so once the generator has - # been run and the included_choice values have been - # added): {index: value} - constant_outputs: dict[int, any] - # The actual covariates from the source database - covariates: dict[str, float] - - def comment(self) -> str: - caveat = "" - if self.included_choice: - caveat = f" (for each possible value of {text_list(self.included_choice.values())})" - if not self.included_numeric: - return f"Number of rows for which {text_list(self.excluded_columns.values())}{caveat}" - if not self.excluded_columns: - where = "" - else: - where = f" where {text_list(self.excluded_columns.values())}" - if len(self.included_numeric) == 1: - return f"Mean and variance for column {self.included_numeric[0].name}{where}." - return ( - "Means and covariate matrix for the columns " - f"{text_list(col.name for col in self.included_numeric)}{where}{caveat} so that we can" - " produce the relatedness between these in the fake data." - ) - - -class NullPartitionedNormalGenerator(Generator): - """ - A generator of mixed numeric and non-numeric data. - - Generates data that matches the source data in - missingness, choice of non-numeric data and numeric - data. - - For the numeric data to be generated, samples of rows for each - combination of non-numeric values and missingness. If any such - combination has only one line in the source data (or sample of - the source data if sampling), it will not be generated as a - covariate matrix cannot be generated from one source row - (although if the data is all non-numeric values and nulls, single - rows are used because no covariate matrix is required for this). 
- """ - def __init__( - self, - query_name: str, - partitions: dict[int, RowPartition], - function_name: str="grouped_multivariate_lognormal", - name_suffix: str | None=None, - partition_count_query: str | None=None, - partition_counts: Sequence[RowMapping] | None=None, - partition_count_comment: str | None=None, - ): - self._query_name = query_name - self._partitions = partitions - self._function_name = function_name - self._partition_count_query = partition_count_query - self._partition_counts = [dict(pc) for pc in partition_counts] - self._partition_count_comment = partition_count_comment - if name_suffix: - self._name = f"null-partitioned {function_name} [{name_suffix}]" - else: - self._name = f"null-partitioned {function_name}" - - def name(self): - return self._name - - def function_name(self): - return "dist_gen.alternatives" - - def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition): - count = f'sum(r["count"] for r in SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' - if not partition.included_numeric and not partition.included_choice: - return { - "count": count, - "name": '"constant"', - "params": {"value": [None] * len(partition.constant_outputs)}, - } - covariates = { - "covs": f'SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"]' - } - if not partition.constant_outputs: - return { - "count": count, - "name": f'"{self._function_name}"', - "params": covariates, - } - return { - "count": count, - "name": '"with_constants_at"', - "params": { - "constants_at": partition.constant_outputs, - "subgen": f'"{self._function_name}"', - "params": covariates, - } - } - - def _count_query_name(self): - if self._partition_count_query: - return f"auto__cov__{self._query_name}__counts" - return None - - def nominal_kwargs(self): - return { - "alternative_configs": [ - self._nominal_kwargs_with_combinations(index, self._partitions[index]) - for index in range(len(self._partitions)) - ], - "counts": f'SRC_STATS["{self._count_query_name()}"]["results"]', - } - - def custom_queries(self): - partitions = { - f"auto__cov__{self._query_name}__alt_{index}": { - "comment": partition.comment(), - "query": partition.query, - } - for index, partition in self._partitions.items() - } - if not self._partition_count_query: - return partitions - return { - self._count_query_name(): { - "comment": self._partition_count_comment, - "query": self._partition_count_query, - }, - **partitions, - } - - def _actual_kwargs_with_combinations(self, partition: RowPartition): - count = sum(row["count"] for row in partition.covariates) - if not partition.included_numeric and not partition.included_choice: - return { - "count": count, - "name": "constant", - "params": {"value": [None] * len(partition.excluded_columns)}, - } - if not partition.excluded_columns: - return { - "count": count, - "name": self._function_name, - "params": { - "covs": partition.covariates, - } - } - return { - "count": count, - "name": "with_constants_at", - "params": { - "constants_at": partition.constant_outputs, - "subgen": self._function_name, - "params": { - "covs": partition.covariates, - }, - } - } - - def actual_kwargs(self) -> dict[str, any]: - """ - The kwargs (summary statistics) this generator is instantiated with. 
- """ - return { - "alternative_configs": [ - self._actual_kwargs_with_combinations(self._partitions[index]) - for index in range(len(self._partitions)) - ], - "counts": self._partition_counts, - } - - def generate_data(self, count) -> list[any]: - """ - Generate 'count' random data points for this column. - """ - kwargs = self.actual_kwargs() - return [ - dist_gen.alternatives(**kwargs) - for _ in range(count) - ] - - def fit(self, default=None) -> float | None: - return default - - -def is_numeric(col: Column) -> bool: - ct = get_column_type(col) - return ( - isinstance(ct, Numeric) or isinstance(ct, Integer) - ) and not col.foreign_keys - -T = TypeVar('T') - -def powerset(input: Iterable[T]) -> Iterable[Iterable[T]]: - """Returns a list of all sublists of""" - return chain.from_iterable(combinations(input, n) for n in range(len(input) + 1)) - - -@dataclass -class NullableColumn: - """ - A reference to a nullable column whose nullability is part of a partitioning. - """ - column: Column - # The bit (power of two) of the number of the partition in the partition sizes list - bitmask: int - - -class NullPatternPartition: - """ - The definition of a partition (in other words, what makes it not another partition) - """ - def __init__( - self, - columns: Iterable[Column], - partition_nonnulls: Iterable[NullableColumn] - ): - self.index = sum(nc.bitmask for nc in partition_nonnulls) - nonnull_columns = { nc.column.name for nc in partition_nonnulls } - self.included_numeric: list[Column] = [] - self.included_choice: dict[int, str] = {} - self.group_by_clause = "" - self.constant_clauses = "" - self.constants = "" - self.excluded: dict[str, str] = {} - self.predicates: list[str] = [] - self.nones: dict[int, None] = {} - for col_index, column in enumerate(columns): - col_name = column.name - if col_name in nonnull_columns or not column.nullable: - if is_numeric(column): - self.included_numeric.append(column) - else: - index = len(self.included_numeric) + len(self.included_choice) - self.included_choice[index] = col_name - if self.group_by_clause: - self.group_by_clause += ", " + col_name - else: - self.group_by_clause = " GROUP BY " + col_name - self.constant_clauses += f", _q.{col_name} AS k{index}" - self.constants += ", " + col_name - else: - self.excluded[col_name] = f"{col_name} IS NULL" - self.predicates.append(f"{col_name} IS NULL") - self.nones[col_index] = None - - -class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): - SAMPLE_COUNT = MAXIMUM_CHOICES - SUPPRESS_COUNT = 7 - - def function_name(self) -> str: - return "grouped_multivariate_normal" - - def query_predicate(self, column: Column) -> str: - """ - Returns a SQL expression that is true when ``column`` is available for analysis. - """ - if is_numeric(column): - # x <> x + 1 ensures that x is not infinity or NaN - return f"COALESCE({column.name} <> {column.name} + 1, FALSE)" - return f"{column.name} IS NOT NULL" - - def query_var(self, column: str) -> str: - return column - - def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]: - """ - Gets a list of nullable columns together with bitmasks. - """ - out: list[NullableColumn] = [] - for col in columns: - if col.nullable: - out.append(NullableColumn( - column=col, - bitmask=2 ** len(out), - )) - return out - - def get_partition_count_query(self, ncs: list[NullableColumn], table: str, where: str | None=None) -> str: - """ - Returns a SQL expression returning columns ``count`` and ``index``. 
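A toy illustration of that bitmask: with two nullable columns the table splits into four null-pattern partitions, and a row's partition index is the sum of the bits of its non-null columns, mirroring the CASE expressions this method generates. The column names here are invented:

```python
# Each nullable column owns a power-of-two bit; 2 nullable columns -> 4 partitions.
rows = [
    {"height": 1.7, "weight": None},   # only the height bit set
    {"height": None, "weight": None},  # no bits set
    {"height": 1.6, "weight": 60.0},   # both bits set
]
bitmask = {"height": 1, "weight": 2}

def partition_index(row: dict) -> int:
    return sum(bit for col, bit in bitmask.items() if row[col] is not None)

print([partition_index(r) for r in rows])  # [1, 0, 3]
```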
- - Each row returned represents one of the null pattern partitions. - ``index`` is the bitmask of all those nullable columns that are not null for - this partition, and ``count`` is the total number of rows in this partition. - """ - index_exp = " + ".join( - f"CASE WHEN {self.query_predicate(nc.column)} THEN {nc.bitmask} ELSE 0 END" - for nc in ncs - ) - if where is None: - return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' - return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - - def get_generators(self, columns: list[Column], engine: Engine): - if len(columns) < 2: - return [] - nullable_columns = self.get_nullable_columns(columns) - if not nullable_columns: - return [] - table = columns[0].table.name - query_name = f"{table}__{columns[0].name}" - # Partitions for minimal suppression and no sampling - row_partitions_maximal: dict[int, RowPartition] = {} - # Partitions for minimal suppression but sampling - row_partitions_sampled: dict[int, RowPartition] = {} - # Partitions for normal suppression and severe sampling - row_partitions_ss: dict[int, RowPartition] = {} - for partition_nonnulls in powerset(nullable_columns): - partition_def = NullPatternPartition(columns, partition_nonnulls) - query_all = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants = partition_def.constants, - constant_clauses=partition_def.constant_clauses, - ) - row_partitions_maximal[partition_def.index] = RowPartition( - query_all, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - {}, - ) - query_sampled = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants = partition_def.constants, - constant_clauses=partition_def.constant_clauses, - sample_count=self.SAMPLE_COUNT, - ) - row_partitions_sampled[partition_def.index] = RowPartition( - query_sampled, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - {}, - ) - query_ss = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants = partition_def.constants, - constant_clauses=partition_def.constant_clauses, - suppress_count=self.SUPPRESS_COUNT, - sample_count=self.SAMPLE_COUNT, - ) - row_partitions_ss[partition_def.index] = RowPartition( - query_ss, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - {}, - ) - gens = [] - try: - with engine.connect() as connection: - partition_query_max = self.get_partition_count_query(nullable_columns, table) - partition_count_max_results = connection.execute( - text(partition_query_max) - ).mappings().fetchall() - count_comment = f"Number of rows for each combination of the columns { {nc.column.name for nc in nullable_columns} } of the table {table} being null" - if self._execute_partition_queries(connection, row_partitions_maximal): - gens.append(NullPartitionedNormalGenerator( - query_name, - row_partitions_maximal, - self.function_name(), - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - )) - 
if self._execute_partition_queries(connection, row_partitions_sampled): - gens.append(NullPartitionedNormalGenerator( - query_name, - row_partitions_sampled, - self.function_name(), - name_suffix="sampled", - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - )) - partition_query_ss = self.get_partition_count_query( - nullable_columns, - table, - where=f"WHERE {self.SUPPRESS_COUNT} < count" - ) - partition_count_ss_results = connection.execute( - text(partition_query_ss) - ).mappings().fetchall() - if self._execute_partition_queries(connection, row_partitions_ss): - gens.append(NullPartitionedNormalGenerator( - query_name, - row_partitions_ss, - self.function_name(), - name_suffix="sampled and suppressed", - partition_count_query=partition_query_ss, - partition_counts=partition_count_ss_results, - partition_count_comment=count_comment, - )) - except sqlalchemy.exc.DatabaseError as exc: - logger.debug("SQL query failed with error %s [%s]", exc, exc.statement) - return [] - return gens - - def _execute_partition_queries( - self, - connection: Connection, - partitions: dict[int, RowPartition], - ): - """ - Execute the query in each partition, filling in the covariates. - :return: True if all the partitions work, False if any of them fail. - """ - found_nonzero = False - for rp in partitions.values(): - rp.covariates = connection.execute(text( - rp.query - )).mappings().fetchall() - if not rp.covariates or rp.covariates[0]["count"] is None: - rp.covariates = [{"count": 0}] - else: - found_nonzero = True - return found_nonzero - - -class NullPartitionedLogNormalGeneratorFactory(NullPartitionedNormalGeneratorFactory): - def function_name(self) -> str: - return "grouped_multivariate_lognormal" - - def query_predicate(self, column: Column) -> str: - if is_numeric(column): - # x <> x + 1 ensures that x is not infinity or NaN - return f"COALESCE({column.name} <> {column.name} + 1 AND 0 < {column.name}, FALSE)" - return f"{column.name} IS NOT NULL" - - def query_var(self, column: str) -> str: - return f"LN({column})" - - -@lru_cache(1) -def everything_factory(): - return MultiGeneratorFactory([ - MimesisStringGeneratorFactory(), - MimesisIntegerGeneratorFactory(), - MimesisFloatGeneratorFactory(), - MimesisDateGeneratorFactory(), - MimesisDateTimeGeneratorFactory(), - MimesisTimeGeneratorFactory(), - ContinuousDistributionGeneratorFactory(), - ContinuousLogDistributionGeneratorFactory(), - ChoiceGeneratorFactory(), - ConstantGeneratorFactory(), - MultivariateNormalGeneratorFactory(), - MultivariateLogNormalGeneratorFactory(), - NullPartitionedNormalGeneratorFactory(), - NullPartitionedLogNormalGeneratorFactory(), - ]) diff --git a/datafaker/generators/__init__.py b/datafaker/generators/__init__.py new file mode 100644 index 00000000..650c8ba3 --- /dev/null +++ b/datafaker/generators/__init__.py @@ -0,0 +1,53 @@ +"""Generators write generator function definitions and queries into config.yaml.""" + +from functools import lru_cache + +from datafaker.generators.base import ( + ConstantGeneratorFactory, + GeneratorFactory, + MultiGeneratorFactory, +) +from datafaker.generators.choice import ChoiceGeneratorFactory +from datafaker.generators.continuous import ( + ContinuousDistributionGeneratorFactory, + ContinuousLogDistributionGeneratorFactory, + MultivariateLogNormalGeneratorFactory, + MultivariateNormalGeneratorFactory, +) +from datafaker.generators.mimesis import ( + MimesisDateGeneratorFactory, + 
MimesisDateTimeGeneratorFactory,
+    MimesisFloatGeneratorFactory,
+    MimesisIntegerGeneratorFactory,
+    MimesisStringGeneratorFactory,
+    MimesisTimeGeneratorFactory,
+)
+from datafaker.generators.partitioned import (
+    NullPartitionedLogNormalGeneratorFactory,
+    NullPartitionedNormalGeneratorFactory,
+)
+
+
+# Using a cache instead of just initializing an object to avoid
+# startup time being spent when it isn't needed.
+@lru_cache(1)
+def everything_factory() -> GeneratorFactory:
+    """Get a factory that encapsulates all the other factories."""
+    return MultiGeneratorFactory(
+        [
+            MimesisStringGeneratorFactory(),
+            MimesisIntegerGeneratorFactory(),
+            MimesisFloatGeneratorFactory(),
+            MimesisDateGeneratorFactory(),
+            MimesisDateTimeGeneratorFactory(),
+            MimesisTimeGeneratorFactory(),
+            ContinuousDistributionGeneratorFactory(),
+            ContinuousLogDistributionGeneratorFactory(),
+            ChoiceGeneratorFactory(),
+            ConstantGeneratorFactory(),
+            MultivariateNormalGeneratorFactory(),
+            MultivariateLogNormalGeneratorFactory(),
+            NullPartitionedNormalGeneratorFactory(),
+            NullPartitionedLogNormalGeneratorFactory(),
+        ]
+    )
diff --git a/datafaker/generators/base.py b/datafaker/generators/base.py
new file mode 100644
index 00000000..aba91b60
--- /dev/null
+++ b/datafaker/generators/base.py
@@ -0,0 +1,422 @@
+"""Basic Generators and factories."""
+
+import re
+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+from typing import Any, Sequence, Union
+
+import mimesis
+import mimesis.locales
+import sqlalchemy
+from sqlalchemy import Column, Engine, text
+from sqlalchemy.types import Integer, Numeric, String, TypeEngine
+from typing_extensions import Self
+
+from datafaker.providers import DistributionProvider
+from datafaker.utils import logger
+
+NumericType = Union[int, float]
+
+
+dist_gen = DistributionProvider()
+generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB)
+
+
+class GeneratorError(Exception):
+    """Error thrown from Datafaker Generators."""
+
+
+class Generator(ABC):
+    """
+    Random data generator.
+
+    A generator is specific to a particular column in a particular table in
+    a particular database.
+
+    A generator knows how to fetch its summary data from the database, how to calculate
+    its fit (if appropriate) and which function actually does the generation.
+
+    It also knows these summary statistics for the column it was instantiated on,
+    and therefore knows how to generate fake data for that column.
+    """
+
+    @abstractmethod
+    def function_name(self) -> str:
+        """Get the name of the generator function to put into df.py."""
+
+    def name(self) -> str:
+        """
+        Get the name of the generator.
+
+        Usually the same as the function name, but can be different to distinguish
+        between generators that have the same function but different queries.
+        """
+        return self.function_name()
+
+    @abstractmethod
+    def nominal_kwargs(self) -> dict[str, str]:
+        """
+        Get the kwargs the generator wants to be called with.
+
+        The values will tend to be references to something in the src-stats.yaml
+        file.
+        For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will
+        provide the value stored in src-stats.yaml as
+        SRC_STATS["auto__patient"]["results"][0]["age_mean"] as the "avg_age" argument
+        to the generator function.
+        """
+
+    def select_aggregate_clauses(self) -> dict[str, dict[str, str]]:
+        """
+        Get the SQL clauses to add to a SELECT ... FROM {table} query.
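For orientation, a sketch of how such a nominal kwarg resolves once src-stats.yaml has been loaded; the use of `eval` here is illustrative, an assumption about the consuming code rather than a quote of it:

```python
# Sketch (mechanism assumed): resolving a nominal kwarg string against loaded stats.
SRC_STATS = {"auto__patient": {"results": [{"age_mean": 42.3}]}}
nominal = 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'
value = eval(nominal, {"SRC_STATS": SRC_STATS})  # illustration only
assert value == 42.3  # would be passed to the generator as avg_age=42.3
```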
+ + Will add to SRC_STATS["auto__{table}"] + For example { + "count": { + "clause": "COUNT(*)", + "comment": "number of rows in table {table}" + }, "avg_thiscolumn": { + "clause": "AVG(thiscolumn)", + "comment": "Average value of thiscolumn in table {table}" + }} + will make the clause become: + "SELECT COUNT(*) AS count, AVG(thiscolumn) AS avg_thiscolumn FROM thistable" + and this will populate SRC_STATS["auto__thistable"]["results"][0]["count"] and + SRC_STATS["auto__thistable"]["results"][0]["avg_thiscolumn"] in the src-stats.yaml file. + """ + return {} + + def custom_queries(self) -> dict[str, dict[str, str]]: + """ + Get the SQL queries to add to SRC_STATS. + + Should be used for queries that do not follow the SELECT ... FROM table format + using aggregate queries, because these should use select_aggregate_clauses. + + For example {"myquery": { + "query": "SELECT one, too AS two FROM mytable WHERE too > 1", + "comment": "big enough one and two from table mytable" + }} + will populate SRC_STATS["myquery"]["results"][0]["one"] + and SRC_STATS["myquery"]["results"][0]["two"] + in the src-stats.yaml file. + + Keys should be chosen to minimize the chances of clashing with other queries, + for example "auto__{table}__{column}__{queryname}" + """ + return {} + + @abstractmethod + def actual_kwargs(self) -> dict[str, Any]: + """ + Get the kwargs (summary statistics) this generator is instantiated with. + + This must match `nominal_kwargs` in structure. + """ + + @abstractmethod + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + + def fit(self, default: float = -1) -> float: + """ + Return a value representing how well the distribution fits the real source data. + + 0.0 means "perfectly". + Returns default if no fitness has been defined. + """ + return default + + +class PredefinedGenerator(Generator): + """Generator built from an existing config.yaml.""" + + SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") + AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") + SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*') + + def _get_src_stats_mentioned(self, val: Any) -> set[str]: + if not val: + return set() + if isinstance(val, str): + ss = self.SRC_STAT_NAME_RE.match(val) + if ss: + ss_name = ss.group(1) + logger.debug("Found SRC_STATS reference %s", ss_name) + return set([ss_name]) + logger.debug("Value %s does not seem to be a SRC_STATS reference", val) + return set() + if isinstance(val, list): + return set.union(*(self._get_src_stats_mentioned(v) for v in val)) + if isinstance(val, dict): + return set.union(*(self._get_src_stats_mentioned(v) for v in val.values())) + return set() + + def __init__( + self, + table_name: str, + generator_object: Mapping[str, Any], + config: Mapping[str, Any], + ): + """ + Initialise a generator from a config.yaml. + + :param config: The entire configuration. 
+ :param generator_object: The part of the configuration at tables.*.row_generators + """ + logger.debug( + "Creating a PredefinedGenerator %s from table %s", + generator_object["name"], + table_name, + ) + self._table_name = table_name + self._name: str = generator_object["name"] + self._kwn: dict[str, str] = generator_object.get("kwargs", {}) + self._src_stats_mentioned = self._get_src_stats_mentioned(self._kwn) + # Need to deal with this somehow (or remove it from the schema) + self._argn: list[str] = generator_object.get("args", []) + self._select_aggregate_clauses: dict[str, dict[str, str | Any]] = {} + self._custom_queries = {} + for sstat in config.get("src-stats", []): + name: str = sstat["name"] + dpq = sstat.get("dp-query", None) + query = sstat.get( + "query", dpq + ) # ... should we really be combining query and dp-query? + comments = sstat.get("comments", []) + if name in self._src_stats_mentioned: + logger.debug("Found a src-stats entry for %s", name) + # This query is one that this generator is interested in + sam = None if query is None else self.SELECT_AGGREGATE_RE.match(query) + # sam.group(2) is the table name from the FROM clause of the query + if sam and name == f"auto__{sam.group(2)}": + # name is auto__{table_name}, so it's a select_aggregate, + # so we split up its clauses + sacs = [ + self.AS_CLAUSE_RE.match(clause) + for clause in sam.group(1).split(",") + ] + # Work out what select_aggregate_clauses this represents + for sac in sacs: + if sac is not None: + comment = comments.pop() if comments else None + self._select_aggregate_clauses[sac.group(2)] = { + "clause": sac.group(1), + "comment": comment, + } + else: + # some other name, so must be a custom query + logger.debug("Custom query %s is '%s'", name, query) + self._custom_queries[name] = { + "query": query, + "comment": comments[0] if comments else None, + } + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return self._name + + def nominal_kwargs(self) -> dict[str, str]: + """Get the arguments to be entered into ``config.yaml``.""" + return self._kwn + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" + return self._select_aggregate_clauses + + def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" + return self._custom_queries + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + # Run the queries from nominal_kwargs + # ... + logger.error("PredefinedGenerator.actual_kwargs not implemented yet") + return {} + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + # Call the function if we can. This could be tricky... + # ... 
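The round trip that `PredefinedGenerator` performs on aggregate queries can be checked in isolation. This sketch reuses the two regex patterns from the class above on an invented query:

```python
# Split a "SELECT ... FROM table" aggregate back into its per-column clauses.
import re

SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)")
AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *")

query = "SELECT COUNT(*) AS count, AVG(age) AS mean__age FROM patient"
m = SELECT_AGGREGATE_RE.match(query)
assert m and m.group(2) == "patient"  # group(2) is the table name
clauses = [AS_CLAUSE_RE.match(c) for c in m.group(1).split(",")]
print({c.group(2): c.group(1) for c in clauses if c})
# {'count': 'COUNT(*)', 'mean__age': 'AVG(age)'}
```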
+        logger.error("PredefinedGenerator.generate_data not implemented yet")
+        return []
+
+
+class GeneratorFactory(ABC):
+    """A factory for making generators appropriate for a database column."""
+
+    @abstractmethod
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+
+
+def fit_from_buckets(xs: Sequence[NumericType], ys: Sequence[NumericType]) -> float:
+    """Calculate the fit by comparing a pair of lists of buckets."""
+    sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys))
+    count = len(ys)
+    return sum_diff_squared / (count * count)
+
+
+class Buckets:
+    """
+    Measured buckets for a real distribution.
+
+    Finds the real distribution of continuous data so that we can measure
+    the fit of generators against it.
+    """
+
+    # pylint: disable=too-many-arguments too-many-positional-arguments
+    def __init__(
+        self,
+        engine: Engine,
+        table_name: str,
+        column_name: str,
+        mean: float,
+        stddev: float,
+        count: int,
+    ):
+        """Initialise a Buckets object."""
+        with engine.connect() as connection:
+            raw_buckets = connection.execute(
+                text(
+                    f"SELECT COUNT({column_name}) AS f,"
+                    f" FLOOR(({column_name} - {mean - 2 * stddev})/{stddev / 2}) AS b"
+                    f" FROM {table_name} GROUP BY b"
+                )
+            )
+            # the buckets hold proportions, so they must be consumed while the
+            # connection is still open
+            self.buckets: list[float] = [0.0] * 10
+            for rb in raw_buckets:
+                if rb.b is not None:
+                    bucket = min(9, max(0, int(rb.b) + 1))
+                    self.buckets[bucket] += rb.f / count
+        self.mean = mean
+        self.stddev = stddev
+
+    @classmethod
+    def make_buckets(
+        cls, engine: Engine, table_name: str, column_name: str
+    ) -> Self | None:
+        """
+        Construct a Buckets object.
+
+        Calculates the mean and standard deviation of the values in the column
+        specified and makes ten buckets, centered on the mean and each half
+        a standard deviation wide (except for the end two that extend to
+        infinity). Each bucket is set to the proportion of the column's values
+        that fall within that bucket.
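The bucket arithmetic (also used by `fit_from_values`) in miniature, with invented mean and standard deviation: ten half-sigma-wide buckets centered on the mean, with the end buckets absorbing the tails:

```python
# Sketch: value -> bucket index, mirroring the Buckets bucketisation.
mean, stddev = 10.0, 2.0
left_edge = mean - 2 * stddev  # left edge of bucket 1
width = stddev / 2

def bucket(v: float) -> int:
    return min(9, max(0, int((v - left_edge) / width)))

assert bucket(mean) == 4                           # the mean lands in the middle
assert bucket(-100.0) == 0 and bucket(100.0) == 9  # tails are clamped
```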
+ """ + with engine.connect() as connection: + result = connection.execute( + text( + f"SELECT AVG({column_name}) AS mean," + f" STDDEV({column_name}) AS stddev," + f" COUNT({column_name}) AS count FROM {table_name}" + ) + ).first() + if result is None or result.stddev is None or getattr(result, "count") < 2: + return None + try: + buckets = cls( + engine, + table_name, + column_name, + result.mean, + result.stddev, + getattr(result, "count"), + ) + except sqlalchemy.exc.DatabaseError as exc: + logger.debug("Failed to instantiate Buckets object: %s", exc) + return None + return buckets + + def fit_from_counts(self, bucket_counts: Sequence[float]) -> float: + """Figure out the fit from bucket counts from the generator distribution.""" + return fit_from_buckets(self.buckets, bucket_counts) + + def fit_from_values(self, values: list[float]) -> float: + """Figure out the fit from samples from the generator distribution.""" + buckets = [0] * 10 + x = self.mean - 2 * self.stddev + w = self.stddev / 2 + for v in values: + b = min(9, max(0, int((v - x) / w))) + buckets[b] += 1 + return self.fit_from_counts(buckets) + + +class MultiGeneratorFactory(GeneratorFactory): + """A composite factory.""" + + def __init__(self, factories: list[GeneratorFactory]): + """Initialise a MultiGeneratorFactory.""" + super().__init__() + self.factories = factories + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + return [ + generator + for factory in self.factories + for generator in factory.get_generators(columns, engine) + ] + + +def get_column_type(column: Column) -> TypeEngine: + """Get the type of the column, generic if possible.""" + try: + return column.type.as_generic() + except NotImplementedError: + return column.type + + +class ConstantGenerator(Generator): + """Generator that always produces the same value.""" + + def __init__(self, value: Any) -> None: + """Initialise the ConstantGenerator.""" + super().__init__() + self.value = value + self.repr = repr(value) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.constant" + + def nominal_kwargs(self) -> dict[str, str]: + """Get the arguments to be entered into ``config.yaml``.""" + return {"value": self.repr} + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return {"value": self.value} + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [self.value for _ in range(count)] + + +class ConstantGeneratorFactory(GeneratorFactory): + """Just the null generator.""" + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate for these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + if column.nullable: + return [ConstantGenerator(None)] + c_type = get_column_type(column) + if isinstance(c_type, String): + return [ConstantGenerator("")] + if isinstance(c_type, Numeric): + return [ConstantGenerator(0.0)] + if isinstance(c_type, Integer): + return [ConstantGenerator(0)] + return [] diff --git a/datafaker/generators/choice.py b/datafaker/generators/choice.py new file mode 100644 index 00000000..aa1d838d --- /dev/null +++ b/datafaker/generators/choice.py @@ -0,0 +1,413 @@ +"""Generator factories for making generators for choices of values.""" + +import 
decimal +import math +import typing +from abc import abstractmethod +from typing import Any, Sequence, Union + +from sqlalchemy import Column, CursorResult, Engine, text + +from datafaker.generators.base import ( + Generator, + GeneratorFactory, + dist_gen, + fit_from_buckets, +) + +NumericType = Union[int, float] + +# How many distinct values can we have before we consider a +# choice distribution to be infeasible? +MAXIMUM_CHOICES = 500 + + +def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: + """ + Get a zipf distribution for a certain number of items. + + :param total: The total number of items to be distributed. + :param bins: The total number of bins to distribute the items into. + :return: A generator of the number of items in each bin, from the + largest to the smallest. + """ + basic_dist = list(map(lambda n: 1 / n, range(1, bins + 1))) + bd_remaining = sum(basic_dist) + for b in basic_dist: + # yield b/bd_remaining of the `total` remaining + if bd_remaining == 0: + yield 0 + else: + x = math.floor(0.5 + total * b / bd_remaining) + bd_remaining -= x * bd_remaining / total + total -= x + yield x + + +class ChoiceGenerator(Generator): + """Base generator for all generators producing choices of items.""" + + STORE_COUNTS = False + + # pylint: disable=too-many-arguments too-many-positional-arguments + def __init__( + self, + table_name: str, + column_name: str, + values: list[Any], + counts: list[int], + sample_count: int | None = None, + suppress_count: int = 0, + ) -> None: + """Initialise a ChoiceGenerator.""" + super().__init__() + self.table_name = table_name + self.column_name = column_name + self.values = values + estimated_counts = self.get_estimated_counts(counts) + self._fit = fit_from_buckets(counts, estimated_counts) + + extra_results = "" + extra_expo = "" + extra_comment = "" + if self.STORE_COUNTS: + extra_results = f", COUNT({column_name}) AS count" + extra_expo = ", count" + extra_comment = " and their counts" + if suppress_count == 0: + if sample_count is None: + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL GROUP BY value" + f" ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name}" + ) + self._annotation = None + else: + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM" + f" (SELECT {column_name} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL" + f" ORDER BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"The values{extra_comment} that appear in column {column_name}" + f" of a random sample of {sample_count} rows of table {table_name}" + ) + self._annotation = "sampled" + else: + if sample_count is None: + self._query = ( + f"SELECT value{extra_expo} FROM" + f" (SELECT {column_name} AS value, COUNT({column_name}) AS count" + f" FROM {table_name} WHERE {column_name} IS NOT NULL" + f" GROUP BY value ORDER BY count DESC) AS _inner" + f" WHERE {suppress_count} < count" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name} more than {suppress_count} times" + ) + self._annotation = "suppressed" + else: + self._query = ( + f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM" + f" (SELECT {column_name} AS value FROM {table_name}" + f" WHERE {column_name} IS NOT NULL ORDER 
BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY count DESC)" + f" AS _inner WHERE {suppress_count} < count" + ) + self._comment = ( + f"The values{extra_comment} that appear more than {suppress_count} times" + f" in column {column_name}, out of a random sample of {sample_count} rows" + f" of table {table_name}" + ) + self._annotation = "sampled and suppressed" + + @abstractmethod + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', + } + + def name(self) -> str: + """Get the name of the generator.""" + n = super().name() + if self._annotation is None: + return n + return f"{n} [{self._annotation}]" + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "a": self.values, + } + + def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" + qs = super().custom_queries() + return { + **qs, + f"auto__{self.table_name}__{self.column_name}": { + "query": self._query, + "comment": self._comment, + }, + } + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default if self._fit is None else self._fit + + +class ZipfChoiceGenerator(ChoiceGenerator): + """Generator producing items in a Zipf distribution.""" + + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + return list(zipf_distribution(sum(counts), len(counts))) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.zipf_choice" + + def generate_data(self, count: int) -> list[float]: + """Generate ``count`` random data points for this column.""" + return [ + dist_gen.zipf_choice_direct(self.values, len(self.values)) + for _ in range(count) + ] + + +def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: + """ + Construct a distribution putting ``total`` items uniformly into ``bins`` bins. + + If they don't fit exactly evenly, the earlier bins will have one more + item than the later bins so the total is as required. 
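Assuming the package is importable, the two distributions are easy to eyeball; the expected outputs noted below follow from the algorithms as written:

```python
from datafaker.generators.choice import uniform_distribution, zipf_distribution

print(list(zipf_distribution(100, 4)))    # should print [48, 24, 16, 12]: ~1/n
print(list(uniform_distribution(10, 4)))  # should print [3, 3, 2, 2]: remainder first
```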
+ """ + p = total // bins + n = total % bins + for _ in range(0, n): + yield p + 1 + for _ in range(n, bins): + yield p + + +class UniformChoiceGenerator(ChoiceGenerator): + """A generator producing values, each roughly as frequently as each other.""" + + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + return list(uniform_distribution(sum(counts), len(counts))) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.choice" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [dist_gen.choice_direct(self.values) for _ in range(count)] + + +class WeightedChoiceGenerator(ChoiceGenerator): + """Choice generator that matches the source data's frequency.""" + + STORE_COUNTS = True + + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + return counts + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.weighted_choice" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [dist_gen.weighted_choice(self.values) for _ in range(count)] + + +class ValueGatherer: + """ + Gathers values from a query of values and counts. + + The query must return columns ``v`` for a value and ``f`` for the + count of how many of those values there are. + These values will be gathered into a number of properties: + ``values``: the list of ``v`` values, ``counts``: the list of ``f`` counts + in the same order as ``v``, ``cvs``: list of dicts with keys ``value`` and + ``count`` giving these values and counts. ``counts_not_suppressed``, + ``values_not_suppressed`` and ``cvs_not_suppressed`` are the + equivalents with the counts less than or equal to ``suppress_count`` + removed. + + :param suppress_count: value with a count of this or fewer will be excluded + from the suppressed values. 
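`ValueGatherer` is annotated as taking a `CursorResult` but only relies on rows exposing `.v` and `.f`, so a usage sketch can stand in fake rows (test-style, not from the patch):

```python
from collections import namedtuple

from datafaker.generators.choice import ValueGatherer

Row = namedtuple("Row", ["v", "f"])
vg = ValueGatherer([Row("a", 12), Row("b", 9), Row("c", 3)], suppress_count=7)
print(vg.values)              # ['a', 'b', 'c']
print(vg.cvs_not_suppressed)  # [{'value': 'a', 'count': 12}, {'value': 'b', 'count': 9}]
```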
+    """
+
+    def __init__(self, results: CursorResult, suppress_count: int = 0) -> None:
+        """Initialise a ValueGatherer."""
+        values = []  # All values found
+        counts = []  # The number of each value
+        cvs: list[dict[str, Any]] = []  # list of dicts with keys "value" and "count"
+        values_not_suppressed = []  # All values found more than suppress_count times
+        counts_not_suppressed = []  # The number for each value not suppressed
+        cvs_not_suppressed: list[
+            dict[str, Any]
+        ] = []  # list of dicts with keys "value" and "count"
+        for result in results:
+            c = result.f
+            if c != 0:
+                counts.append(c)
+                v = result.v
+                if isinstance(v, decimal.Decimal):
+                    v = float(v)
+                values.append(v)
+                cvs.append({"value": v, "count": c})
+                if suppress_count < c:
+                    counts_not_suppressed.append(c)
+                    v = result.v
+                    if isinstance(v, decimal.Decimal):
+                        v = float(v)
+                    values_not_suppressed.append(v)
+                    cvs_not_suppressed.append({"value": v, "count": c})
+        self.values = values
+        self.counts = counts
+        self.cvs = cvs
+        self.values_not_suppressed = values_not_suppressed
+        self.counts_not_suppressed = counts_not_suppressed
+        self.cvs_not_suppressed = cvs_not_suppressed
+
+
+class ChoiceGeneratorFactory(GeneratorFactory):
+    """Factory for the generators that choose among a column's distinct values."""
+
+    SAMPLE_COUNT = MAXIMUM_CHOICES
+    SUPPRESS_COUNT = 7
+
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+        if len(columns) != 1:
+            return []
+        column = columns[0]
+        column_name = column.name
+        table_name = column.table.name
+        generators = []
+        with engine.connect() as connection:
+            results = connection.execute(
+                text(
+                    f'SELECT "{column_name}" AS v, COUNT("{column_name}")'
+                    f' AS f FROM "{table_name}" GROUP BY v'
+                    f" ORDER BY f DESC LIMIT {MAXIMUM_CHOICES + 1}"
+                )
+            )
+            if results is not None and results.rowcount <= MAXIMUM_CHOICES:
+                vg = ValueGatherer(results, self.SUPPRESS_COUNT)
+                if vg.counts:
+                    generators += [
+                        ZipfChoiceGenerator(
+                            table_name, column_name, vg.values, vg.counts
+                        ),
+                        UniformChoiceGenerator(
+                            table_name, column_name, vg.values, vg.counts
+                        ),
+                        WeightedChoiceGenerator(
+                            table_name, column_name, vg.cvs, vg.counts
+                        ),
+                    ]
+                if vg.counts_not_suppressed:
+                    generators += [
+                        ZipfChoiceGenerator(
+                            table_name,
+                            column_name,
+                            vg.values_not_suppressed,
+                            vg.counts_not_suppressed,
+                            suppress_count=self.SUPPRESS_COUNT,
+                        ),
+                        UniformChoiceGenerator(
+                            table_name,
+                            column_name,
+                            vg.values_not_suppressed,
+                            vg.counts_not_suppressed,
+                            suppress_count=self.SUPPRESS_COUNT,
+                        ),
+                        WeightedChoiceGenerator(
+                            table_name=table_name,
+                            column_name=column_name,
+                            values=vg.cvs_not_suppressed,
+                            counts=vg.counts_not_suppressed,
+                            suppress_count=self.SUPPRESS_COUNT,
+                        ),
+                    ]
+            sampled_results = connection.execute(
+                text(
+                    f"SELECT v, COUNT(v) AS f FROM"
+                    f' (SELECT "{column_name}" as v FROM "{table_name}"'
+                    f" ORDER BY RANDOM() LIMIT {self.SAMPLE_COUNT})"
+                    f" AS _inner GROUP BY v ORDER BY f DESC"
+                )
+            )
+            if sampled_results is not None:
+                vg = ValueGatherer(sampled_results, self.SUPPRESS_COUNT)
+                if vg.counts:
+                    generators += [
+                        ZipfChoiceGenerator(
+                            table_name,
+                            column_name,
+                            vg.values,
+                            vg.counts,
+                            sample_count=self.SAMPLE_COUNT,
+                        ),
+                        UniformChoiceGenerator(
+                            table_name,
+                            column_name,
+                            vg.values,
+                            vg.counts,
+                            sample_count=self.SAMPLE_COUNT,
+                        ),
+                        WeightedChoiceGenerator(
+                            table_name,
+                            column_name,
+                            vg.cvs,
+                            vg.counts,
+                            sample_count=self.SAMPLE_COUNT,
+                        ),
+                    ]
+                if vg.counts_not_suppressed:
+                    generators
+= [ + ZipfChoiceGenerator( + table_name, + column_name, + vg.values_not_suppressed, + vg.counts_not_suppressed, + sample_count=self.SAMPLE_COUNT, + suppress_count=self.SUPPRESS_COUNT, + ), + UniformChoiceGenerator( + table_name, + column_name, + vg.values_not_suppressed, + vg.counts_not_suppressed, + sample_count=self.SAMPLE_COUNT, + suppress_count=self.SUPPRESS_COUNT, + ), + WeightedChoiceGenerator( + table_name=table_name, + column_name=column_name, + values=vg.cvs_not_suppressed, + counts=vg.counts_not_suppressed, + sample_count=self.SAMPLE_COUNT, + suppress_count=self.SUPPRESS_COUNT, + ), + ] + return generators diff --git a/datafaker/generators/continuous.py b/datafaker/generators/continuous.py new file mode 100644 index 00000000..fc50c7fe --- /dev/null +++ b/datafaker/generators/continuous.py @@ -0,0 +1,509 @@ +"""Generator factories for making generators of continuous distributions.""" + +import itertools +from collections.abc import Iterable, Sequence +from typing import Any + +from sqlalchemy import Column, Engine, RowMapping, text +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.types import Integer, Numeric + +from datafaker.generators.base import ( + Buckets, + Generator, + GeneratorFactory, + NumericType, + dist_gen, + get_column_type, +) +from datafaker.utils import Empty, logger + + +class ContinuousDistributionGenerator(Generator): + """Base class for generators producing continuous distributions.""" + + expected_buckets: Sequence[NumericType] = [] + + def __init__(self, table_name: str, column_name: str, buckets: Buckets): + """Initialise a ContinuousDistributionGenerator.""" + super().__init__() + self.table_name = table_name + self.column_name = column_name + self.buckets = buckets + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "mean": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["mean__{self.column_name}"]' + ), + "sd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["stddev__{self.column_name}"]' + ), + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + if self.buckets is None: + return {} + return { + "mean": self.buckets.mean, + "sd": self.buckets.stddev, + } + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" + clauses = super().select_aggregate_clauses() + return { + **clauses, + f"mean__{self.column_name}": { + "clause": f"AVG({self.column_name})", + "comment": f"Mean of {self.column_name} from table {self.table_name}", + }, + f"stddev__{self.column_name}": { + "clause": f"STDDEV({self.column_name})", + "comment": f"Standard deviation of {self.column_name} from table {self.table_name}", + }, + } + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + if self.buckets is None: + return default + return self.buckets.fit_from_counts(self.expected_buckets) + + +class GaussianGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a Gaussian (normal) distribution.""" + + expected_buckets = [ + 0.0227, + 0.0441, + 0.0918, + 0.1499, + 0.1915, + 0.1915, + 0.1499, + 0.0918, + 0.0441, + 0.0227, + ] + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.normal" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this 
column.""" + return [ + dist_gen.normal(self.buckets.mean, self.buckets.stddev) + for _ in range(count) + ] + + +class UniformGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a uniform distribution.""" + + expected_buckets = [ + 0, + 0.06698, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.06698, + 0, + ] + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.uniform_ms" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [ + dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) + for _ in range(count) + ] + + +class ContinuousDistributionGeneratorFactory(GeneratorFactory): + """All generators that want an average and standard deviation.""" + + def _get_generators_from_buckets( + self, + _engine: Engine, + table_name: str, + column_name: str, + buckets: Buckets, + ) -> Sequence[Generator]: + return [ + GaussianGenerator(table_name, column_name, buckets), + UniformGenerator(table_name, column_name, buckets), + ] + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, Numeric) and not isinstance(ct, Integer): + return [] + column_name = column.name + table_name = column.table.name + buckets = Buckets.make_buckets(engine, table_name, column_name) + if buckets is None: + return [] + return self._get_generators_from_buckets( + engine, table_name, column_name, buckets + ) + + +class LogNormalGenerator(Generator): + """Generator producing numbers in a log-normal distribution.""" + + # R: + # > xs<-seq(-2,2,0.5)*sqrt((exp(1)-1)*exp(1))+exp(0.5) + # > ys <- plnorm(xs) + # > c(ys, 1) - c(0,ys) + # [1] 0.00000000 0.00000000 0.00000000 0.28589471 0.40556775 0.15086088 + # [7] 0.06716451 0.03428958 0.01924848 0.03697409 + expected_buckets = [ + 0, + 0, + 0, + 0.28589471, + 0.40556775, + 0.15086088, + 0.06716451, + 0.03428958, + 0.01924848, + 0.03697409, + ] + + # pylint: disable=too-many-arguments too-many-positional-arguments + def __init__( + self, + table_name: str, + column_name: str, + buckets: Buckets, + logmean: float, + logstddev: float, + ): + """Initialise a LogNormalGenerator.""" + super().__init__() + self.table_name = table_name + self.column_name = column_name + self.buckets = buckets + self.logmean = logmean + self.logstddev = logstddev + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.lognormal" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "logmean": ( + f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logmean__{self.column_name}"]' + ), + "logsd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logstddev__{self.column_name}"]' + ), + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "logmean": self.logmean, + "logsd": self.logstddev, + } + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the 
generators need to call.""" + clauses = super().select_aggregate_clauses() + return { + **clauses, + f"logmean__{self.column_name}": { + "clause": ( + f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name})" + " ELSE NULL END)" + ), + "comment": f"Mean of logs of {self.column_name} from table {self.table_name}", + }, + f"logstddev__{self.column_name}": { + "clause": ( + f"STDDEV(CASE WHEN 0<{self.column_name}" + f" THEN LN({self.column_name}) ELSE NULL END)" + ), + "comment": ( + f"Standard deviation of logs of {self.column_name}" + f" from table {self.table_name}" + ), + }, + } + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + if self.buckets is None: + return default + return self.buckets.fit_from_counts(self.expected_buckets) + + +class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorFactory): + """All generators that want an average and standard deviation of log data.""" + + def _get_generators_from_buckets( + self, + engine: Engine, + table_name: str, + column_name: str, + buckets: Buckets, + ) -> Sequence[Generator]: + with engine.connect() as connection: + result = connection.execute( + text( + f"SELECT AVG(CASE WHEN 0<{column_name} THEN LN({column_name})" + " ELSE NULL END) AS logmean," + f" STDDEV(CASE WHEN 0<{column_name} THEN LN({column_name}) ELSE NULL END)" + f" AS logstddev FROM {table_name}" + ) + ).first() + if result is None or result.logstddev is None: + return [] + return [ + LogNormalGenerator( + table_name, + column_name, + buckets, + float(result.logmean), + float(result.logstddev), + ) + ] + + +class MultivariateNormalGenerator(Generator): + """Generator of multiple values drawn from a multivariate normal distribution.""" + + # pylint: disable=too-many-arguments too-many-positional-arguments + def __init__( + self, + table_name: str, + column_names: list[str], + query: str, + covariates: RowMapping, + function_name: str, + ) -> None: + """Initialise a MultivariateNormalGenerator.""" + self._table = table_name + self._columns = column_names + self._query = query + self._covariates = covariates + self._function_name = function_name + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen." + self._function_name + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', + } + + def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" + cols = ", ".join(self._columns) + return { + f"auto__cov__{self._table}": { + "comment": ( + f"Means and covariate matrix for the columns {cols}," + " so that we can produce the relatedness between these in the fake data." 
+                ),
+                "query": self._query,
+            }
+        }
+
+    def actual_kwargs(self) -> dict[str, Any]:
+        """Get the kwargs (summary statistics) this generator was instantiated with."""
+        return {"cov": self._covariates}
+
+    def generate_data(self, count: int) -> list[Any]:
+        """Generate ``count`` random data points for this column."""
+        return [
+            getattr(dist_gen, self._function_name)(self._covariates)
+            for _ in range(count)
+        ]
+
+    def fit(self, default: float = -1) -> float:
+        """Get this generator's fit against the real data."""
+        return default
+
+
+class MultivariateNormalGeneratorFactory(GeneratorFactory):
+    """Normal distribution generator factory."""
+
+    def function_name(self) -> str:
+        """Get the name of the generator function to call."""
+        return "multivariate_normal"
+
+    def query_predicate(self, column: Column) -> str:
+        """Get the SQL expression for whether this column should be queried."""
+        return column.name + " IS NOT NULL"
+
+    def query_var(self, column: str) -> str:
+        """Get the SQL expression of the value to query for this column."""
+        return column
+
+    # pylint: disable=too-many-arguments too-many-positional-arguments
+    def query(
+        self,
+        table: str,
+        columns: Sequence[Column],
+        predicates: Iterable[str] = Empty.iterable(),
+        group_by_clause: str = "",
+        constant_clauses: str = "",
+        constants: str = "",
+        suppress_count: int = 1,
+        sample_count: int | None = None,
+    ) -> str:
+        """
+        Get a query for the basic statistics for multivariate normal/lognormal parameters.
+
+        :param table: The name of the table to be queried.
+        :param columns: The columns in the multivariate distribution.
+        :param predicates: Additional WHERE-clause predicates, ANDed together
+            with the per-column predicates.
+        :param group_by_clause: Any GROUP BY clause (starting with " GROUP BY " if not "").
+        :param constant_clauses: Extra output columns in the outer SELECT clause, such
+            as ", _q.column_one AS k1, _q.column_two AS k2". Note the initial comma.
+        :param constants: Extra output columns in the inner SELECT clause. Used to
+            deliver columns to the outer select, such as ", column_one, column_two".
+            Note the initial comma.
+        :param suppress_count: a group smaller than this will be suppressed.
+        :param sample_count: this many samples will be taken from each partition.
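+
+        For illustration, for two columns ``x`` and ``y`` of a hypothetical
+        table ``t`` the generated SQL is shaped roughly like::
+
+            SELECT 2 AS rank, _q.count AS count, _q.m0, _q.m1,
+                   (_q.s0_0 - _q.count * _q.m0 * _q.m0)/NULLIF(_q.count - 1, 0) AS c0_0,
+                   ... FROM (SELECT COUNT(*) AS count, SUM(x * x) AS s0_0, ...,
+                   AVG(x) AS m0, AVG(y) AS m1 FROM t
+                   WHERE x IS NOT NULL AND y IS NOT NULL) AS _q
+            WHERE 1 < _q.count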
+ """ + means = "".join(f", _q.m{i}" for i in range(len(columns))) + covs = "".join( + ( + f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})" + f"/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" + ) + for iy in range(len(columns)) + for ix in range(iy + 1) + ) + subquery = self._inner_query(table, columns, predicates, sample_count) + # if there are any numeric columns we need at least + # two rows to make any (co)variances at all + suppress_clause = f" WHERE {suppress_count} < _q.count" if columns else "" + return ( + f"SELECT {len(columns)} AS rank{constant_clauses}, _q.count AS count{means}{covs}" + f" FROM ({self._middle_query(columns, constants, subquery, group_by_clause)})" + f" AS _q{suppress_clause}" + ) + + def _inner_query( + self, + table: str, + columns: Sequence[Column], + predicates: Iterable[str], + sample_count: int | None, + ) -> str: + """Get the rows from the table that we are interested in.""" + preds = itertools.chain( + (self.query_predicate(col) for col in columns), + predicates, + ) + where = " AND ".join(preds) if preds else "" + if where: + where = " WHERE " + where + if sample_count is None: + return table + where + return ( + f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" + f" LIMIT {sample_count}) AS _sampled" + ) + + def _middle_query( + self, + columns: Sequence[Column], + constants: str, + inner_query: str, + group_by_clause: str, + ) -> str: + """Get the basic statistics (and constants) from the inner query.""" + multiples = "".join( + f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" + for iy, coly in enumerate(columns) + for ix, colx in enumerate(columns[: iy + 1]) + ) + avgs = "".join( + f", AVG({self.query_var(col.name)}) AS m{i}" + for i, col in enumerate(columns) + ) + return ( + f"SELECT COUNT(*) AS count{multiples}{avgs}{constants}" + f" FROM {inner_query}{group_by_clause}" + ) + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators for these columns.""" + # For the case of one column we'll use GaussianGenerator + if len(columns) < 2: + return [] + # All columns must be numeric + for c in columns: + ct = get_column_type(c) + if not isinstance(ct, Numeric) and not isinstance(ct, Integer): + return [] + column_names = [c.name for c in columns] + table = columns[0].table.name + query = self.query(table, columns) + with engine.connect() as connection: + try: + covariates = connection.execute(text(query)).mappings().first() + except SQLAlchemyError as e: + logger.debug("SQL query %s failed with error %s", query, e) + return [] + if not covariates or covariates["c0_0"] is None: + return [] + return [ + MultivariateNormalGenerator( + table, + column_names, + query, + covariates, + self.function_name(), + ) + ] + + +class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): + """Multivariate lognormal generator factory.""" + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "multivariate_lognormal" + + def query_predicate(self, column: Column) -> str: + """Get the SQL expression for whether this column should be queried.""" + return f"COALESCE(0 < {column.name}, FALSE)" + + def query_var(self, column: str) -> str: + """Get the expression to query for, for this column.""" + return f"LN({column})" diff --git a/datafaker/generators/mimesis.py b/datafaker/generators/mimesis.py new file mode 100644 index 00000000..a0fa4268 --- /dev/null +++ b/datafaker/generators/mimesis.py @@ -0,0 +1,421 @@ 
+"""Generators using Mimesis.""" + +from typing import Any, Callable, Sequence, Union + +import mimesis +import mimesis.locales +from sqlalchemy import Column, Engine, text +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time + +from datafaker.generators.base import ( + Buckets, + Generator, + GeneratorError, + GeneratorFactory, + get_column_type, +) +from datafaker.providers import DistributionProvider + +NumericType = Union[int, float] + +# How many distinct values can we have before we consider a +# choice distribution to be infeasible? +MAXIMUM_CHOICES = 500 + +dist_gen = DistributionProvider() +generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) + + +class MimesisGeneratorBase(Generator): + """Base class for a generator using Mimesis.""" + + def __init__( + self, + function_name: str, + ): + """ + Initialise a generator that uses Mimesis. + + :param function_name: is relative to 'generic', for example 'person.name'. + """ + super().__init__() + f = generic + for part in function_name.split("."): + if not hasattr(f, part): + raise GeneratorError( + f"Mimesis does not have a function {function_name}: {part} not found" + ) + f = getattr(f, part) + if not callable(f): + raise GeneratorError( + f"Mimesis object {function_name} is not a callable," + " so cannot be used as a generator" + ) + self._name = "generic." + function_name + self._generator_function = f + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return self._name + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [self._generator_function() for _ in range(count)] + + +class MimesisGenerator(MimesisGeneratorBase): + """A generator using Mimesis.""" + + def __init__( + self, + function_name: str, + value_fn: Callable[[Any], float] | None = None, + buckets: Buckets | None = None, + ): + """ + Initialise a generator using Mimesis. + + :param function_name: is relative to 'generic', for example 'person.name'. + :param value_fn: Function to convert generator output to floats, if needed. The values + thus produced are compared against the buckets to estimate the fit. + :param buckets: The distribution of string lengths in the real data. If this is None + then the fit method will return None. 
+ """ + super().__init__(function_name) + if buckets is None: + self._fit = None + return + samples = self.generate_data(400) + if value_fn: + samples = [value_fn(s) for s in samples] + self._fit = buckets.fit_from_values(samples) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return self._name + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return {} + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return {} + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default if self._fit is None else self._fit + + +class MimesisGeneratorTruncated(MimesisGenerator): + """A string generator using Mimesis that must fit within a certain number of characters.""" + + def __init__( + self, + function_name: str, + length: int, + value_fn: Callable[[Any], float] | None = None, + buckets: Buckets | None = None, + ): + """Initialise a MimesisGeneratorTruncated.""" + self._length = length + super().__init__(function_name, value_fn, buckets) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.truncated_string" + + def name(self) -> str: + """Get the name of the generator.""" + return f"{self._name} [truncated to {self._length}]" + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "subgen_fn": self._name, + "params": {}, + "length": self._length, + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "subgen_fn": self._name, + "params": {}, + "length": self._length, + } + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [self._generator_function()[: self._length] for _ in range(count)] + + +class MimesisDateTimeGenerator(MimesisGeneratorBase): + """DateTime generator using Mimesis.""" + + # pylint: disable=too-many-arguments too-many-positional-arguments + def __init__( + self, + column: Column, + function_name: str, + min_year: str, + max_year: str, + start: int, + end: int, + ) -> None: + """ + Initialise a MimesisDateTimeGenerator. 
+
+        :param column: The column to generate into
+        :param function_name: The name of the mimesis function
+        :param min_year: SQL expression extracting the minimum year
+        :param max_year: SQL expression extracting the maximum year
+        :param start: The actual first year found
+        :param end: The actual last year found
+        """
+        super().__init__(function_name)
+        self._column = column
+        self._max_year = max_year
+        self._min_year = min_year
+        self._start = start
+        self._end = end
+
+    @classmethod
+    def make_singleton(
+        cls, column: Column, engine: Engine, function_name: str
+    ) -> Sequence[Generator]:
+        """Make the appropriate generation configuration for this column."""
+        extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)"
+        max_year = f"MAX({extract_year})"
+        min_year = f"MIN({extract_year})"
+        with engine.connect() as connection:
+            result = connection.execute(
+                text(
+                    f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}"
+                )
+            ).first()
+        if result is None or result.start is None or result.end is None:
+            return []
+        return [
+            MimesisDateTimeGenerator(
+                column,
+                function_name,
+                min_year,
+                max_year,
+                int(result.start),
+                int(result.end),
+            )
+        ]
+
+    def nominal_kwargs(self) -> dict[str, Any]:
+        """Get the arguments to be entered into ``config.yaml``."""
+        return {
+            "start": (
+                f'SRC_STATS["auto__{self._column.table.name}"]["results"]'
+                f'[0]["{self._column.name}__start"]'
+            ),
+            "end": (
+                f'SRC_STATS["auto__{self._column.table.name}"]["results"]'
+                f'[0]["{self._column.name}__end"]'
+            ),
+        }
+
+    def actual_kwargs(self) -> dict[str, Any]:
+        """Get the kwargs (summary statistics) this generator was instantiated with."""
+        return {
+            "start": self._start,
+            "end": self._end,
+        }
+
+    def select_aggregate_clauses(self) -> dict[str, dict[str, str]]:
+        """Get the query fragments the generators need to call."""
+        return {
+            f"{self._column.name}__start": {
+                "clause": self._min_year,
+                "comment": (
+                    f"Earliest year found for column {self._column.name}"
+                    f" in table {self._column.table.name}"
+                ),
+            },
+            f"{self._column.name}__end": {
+                "clause": self._max_year,
+                "comment": (
+                    f"Latest year found for column {self._column.name}"
+                    f" in table {self._column.table.name}"
+                ),
+            },
+        }
+
+    def generate_data(self, count: int) -> list[Any]:
+        """Generate ``count`` random data points for this column."""
+        return [
+            self._generator_function(start=self._start, end=self._end)
+            for _ in range(count)
+        ]
+
+
+class MimesisStringGeneratorFactory(GeneratorFactory):
+    """All Mimesis generators that return strings."""
+
+    GENERATOR_NAMES = [
+        "address.calling_code",
+        "address.city",
+        "address.continent",
+        "address.country",
+        "address.country_code",
+        "address.postal_code",
+        "address.province",
+        "address.street_number",
+        "address.street_name",
+        "address.street_suffix",
+        "person.blood_type",
+        "person.email",
+        "person.first_name",
+        "person.last_name",
+        "person.full_name",
+        "person.gender",
+        "person.language",
+        "person.nationality",
+        "person.occupation",
+        "person.password",
+        "person.title",
+        "person.university",
+        "person.username",
+        "person.worldview",
+        "text.answer",
+        "text.color",
+        "text.level",
+        "text.quote",
+        "text.sentence",
+        "text.text",
+        "text.word",
+    ]
+
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+        if len(columns) != 1:
+            return []
+        column = columns[0]
+        column_type = get_column_type(column)
+        if not isinstance(column_type, String):
+            return []
+        try:
+            buckets = Buckets.make_buckets(
+                engine,
+                column.table.name,
+                f"LENGTH({column.name})",
+            )
+            fitness_fn = len
+        except SQLAlchemyError:
+            # Some column types that appear to be strings (such as enums)
+            # cannot have their lengths measured. In this case we cannot
+            # detect fitness using lengths.
+            buckets = None
+            fitness_fn = None
+        length = column_type.length
+        if length:
+            return list(
+                map(
+                    lambda gen: MimesisGeneratorTruncated(
+                        gen, length, fitness_fn, buckets
+                    ),
+                    self.GENERATOR_NAMES,
+                )
+            )
+        return list(
+            map(
+                lambda gen: MimesisGenerator(gen, fitness_fn, buckets),
+                self.GENERATOR_NAMES,
+            )
+        )
+
+
+class MimesisFloatGeneratorFactory(GeneratorFactory):
+    """All Mimesis generators that return floating point numbers."""
+
+    def get_generators(
+        self, columns: list[Column], _engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+        if len(columns) != 1:
+            return []
+        column = columns[0]
+        if not isinstance(get_column_type(column), Numeric):
+            return []
+        return list(
+            map(
+                MimesisGenerator,
+                [
+                    "person.height",
+                ],
+            )
+        )
+
+
+class MimesisDateGeneratorFactory(GeneratorFactory):
+    """All Mimesis generators that return dates."""
+
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+        if len(columns) != 1:
+            return []
+        column = columns[0]
+        ct = get_column_type(column)
+        if not isinstance(ct, Date):
+            return []
+        return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.date")
+
+
+class MimesisDateTimeGeneratorFactory(GeneratorFactory):
+    """All Mimesis generators that return datetimes."""
+
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+        if len(columns) != 1:
+            return []
+        column = columns[0]
+        ct = get_column_type(column)
+        if not isinstance(ct, DateTime):
+            return []
+        return MimesisDateTimeGenerator.make_singleton(
+            column, engine, "datetime.datetime"
+        )
+
+
+class MimesisTimeGeneratorFactory(GeneratorFactory):
+    """All Mimesis generators that return times."""
+
+    def get_generators(
+        self, columns: list[Column], _engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+        if len(columns) != 1:
+            return []
+        column = columns[0]
+        ct = get_column_type(column)
+        if not isinstance(ct, Time):
+            return []
+        return [MimesisGenerator("datetime.time")]
+
+
+class MimesisIntegerGeneratorFactory(GeneratorFactory):
+    """All Mimesis generators that return integers."""
+
+    def get_generators(
+        self, columns: list[Column], _engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+        if len(columns) != 1:
+            return []
+        column = columns[0]
+        ct = get_column_type(column)
+        if not isinstance(ct, Numeric) and not isinstance(ct, Integer):
+            return []
+        return [MimesisGenerator("person.weight")]
diff --git a/datafaker/generators/partitioned.py b/datafaker/generators/partitioned.py
new file mode 100644
index 00000000..f14af736
--- /dev/null
+++ b/datafaker/generators/partitioned.py
@@ -0,0 +1,560 @@
+"""Powerful generators for numbers, choices and related missingness."""
+
+from dataclasses import dataclass
+from itertools import chain, combinations
+from typing import Any, Iterable, Sequence, Union
+
+import sqlalchemy
+from sqlalchemy import Column, Connection, Engine, RowMapping, text
+from sqlalchemy.types import Integer, Numeric
+
+from datafaker.generators.base import Generator, dist_gen, get_column_type
+from datafaker.generators.continuous import MultivariateNormalGeneratorFactory
+from datafaker.utils import T, logger
+
+NumericType = Union[int, float]
+
+# How many distinct values can we have before we consider a
+# choice distribution to be infeasible?
+MAXIMUM_CHOICES = 500
+
+
+def text_list(items: Iterable[str]) -> str:
+    """Concatenate the items with commas and one "and"."""
+    item_i = iter(items)
+    try:
+        so_far = next(item_i)
+    except StopIteration:
+        return ""
+    try:
+        last_item = next(item_i)
+    except StopIteration:
+        return so_far
+    for item in item_i:
+        so_far += ", " + last_item
+        last_item = item
+    return so_far + " and " + last_item
+
+
+@dataclass
+class RowPartition:
+    """A partition where all the rows have the same pattern of NULLs."""
+
+    query: str
+    # list of numeric columns
+    included_numeric: list[Column]
+    # map of indices to column names that are being grouped by.
+    # The indices are indices of where they need to be inserted into
+    # the generator outputs.
+    included_choice: dict[int, str]
+    # map of column names to the clause that defines the partition,
+    # such as "mycolumn IS NULL"
+    excluded_columns: dict[str, str]
+    # map of constant outputs that need to be inserted into the
+    # list of included column values (so once the generator has
+    # been run and the included_choice values have been
+    # added): {index: value}
+    constant_outputs: dict[int, Any]
+    # The actual covariates from the source database
+    covariates: Sequence[RowMapping]
+
+    def comment(self) -> str:
+        """Make an appropriate comment for this partition."""
+        caveat = ""
+        if self.included_choice:
+            caveat = f" (for each possible value of {text_list(self.included_choice.values())})"
+        if not self.included_numeric:
+            return f"Number of rows for which {text_list(self.excluded_columns.values())}{caveat}"
+        if not self.excluded_columns:
+            where = ""
+        else:
+            where = f" where {text_list(self.excluded_columns.values())}"
+        if len(self.included_numeric) == 1:
+            return (
+                f"Mean and variance for column {self.included_numeric[0].name}{where}."
+            )
+        return (
+            "Means and covariate matrix for the columns "
+            f"{text_list(col.name for col in self.included_numeric)}{where}{caveat} so that we can"
+            " produce the relatedness between these in the fake data."
+        )
+
+
+@dataclass
+class NullableColumn:
+    """A reference to a nullable column whose nullability is part of a partitioning."""
+
+    column: Column
+    # The bit (power of two) of the number of the partition in the partition sizes list
+    bitmask: int
+
+
+class PartitionCountQuery:
+    """Query, result and comment for the row counts of the null pattern partitions."""
+
+    def __init__(
+        self,
+        connection: Connection,
+        query: str,
+        table_name: str,
+        nullable_columns: Iterable[NullableColumn],
+    ) -> None:
+        """
+        Initialise the partition count query.
+
+        :param connection: Database connection.
+        :param query: The query getting the row counts of the null pattern partitions.
+        :param table_name: The name of the table being queried.
+        :param nullable_columns: The columns that are being checked for nullness.
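+
+        As an illustration, for nullable non-numeric columns ``a`` (bitmask 1)
+        and ``b`` (bitmask 2) of a hypothetical table ``t``, the count query is
+        shaped like (numeric columns also exclude NaN and infinity)::
+
+            SELECT COUNT(*) AS count,
+                   CASE WHEN a IS NOT NULL THEN 1 ELSE 0 END
+                   + CASE WHEN b IS NOT NULL THEN 2 ELSE 0 END AS "index"
+            FROM t GROUP BY "index"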
+ """ + self.query = query + rows = connection.execute(text(query)).mappings().fetchall() + self.results = [dict(row) for row in rows] + self.comment = ( + "Number of rows for each combination of the columns" + f" { {nc.column.name for nc in nullable_columns} }" + f" of the table {table_name} being null" + ) + + +class NullPartitionedNormalGenerator(Generator): + """ + A generator of mixed numeric and non-numeric data. + + Generates data that matches the source data in + missingness, choice of non-numeric data and numeric + data. + + For the numeric data to be generated, samples of rows for each + combination of non-numeric values and missingness. If any such + combination has only one line in the source data (or sample of + the source data if sampling), it will not be generated as a + covariate matrix cannot be generated from one source row + (although if the data is all non-numeric values and nulls, single + rows are used because no covariate matrix is required for this). + """ + + # pylint: disable=too-many-arguments too-many-positional-arguments + def __init__( + self, + query_name: str, + partitions: dict[int, RowPartition], + function_name: str = "grouped_multivariate_lognormal", + name_suffix: str | None = None, + partition_count_query: PartitionCountQuery | None = None, + ): + """Initialise a NullPartitionedNormalGenerator.""" + self._query_name = query_name + self._partitions = partitions + self._function_name = function_name + self._partition_count_query = partition_count_query + if name_suffix: + self._name = f"null-partitioned {function_name} [{name_suffix}]" + else: + self._name = f"null-partitioned {function_name}" + + def name(self) -> str: + """Get the name of the generator.""" + return self._name + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.alternatives" + + def _nominal_kwargs_with_combinations( + self, index: int, partition: RowPartition + ) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml`` for a single partition.""" + count = ( + 'sum(r["count"] for r in' + f' SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' + ) + if not partition.included_numeric and not partition.included_choice: + return { + "count": count, + "name": '"constant"', + "params": {"value": [None] * len(partition.constant_outputs)}, + } + covariates = { + "covs": f'SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"]' + } + if not partition.constant_outputs: + return { + "count": count, + "name": f'"{self._function_name}"', + "params": covariates, + } + return { + "count": count, + "name": '"with_constants_at"', + "params": { + "constants_at": partition.constant_outputs, + "subgen": f'"{self._function_name}"', + "params": covariates, + }, + } + + def _count_query_name(self) -> str: + return f"auto__cov__{self._query_name}__counts" + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "alternative_configs": [ + self._nominal_kwargs_with_combinations(index, self._partitions[index]) + for index in range(len(self._partitions)) + ], + "counts": f'SRC_STATS["{self._count_query_name()}"]["results"]', + } + + def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" + partitions = { + f"auto__cov__{self._query_name}__alt_{index}": { + "comment": partition.comment(), + "query": partition.query, + } + for index, partition in self._partitions.items() + } + if not 
+            return partitions
+        return {
+            self._count_query_name(): {
+                "comment": self._partition_count_query.comment,
+                "query": self._partition_count_query.query,
+            },
+            **partitions,
+        }
+
+    def _actual_kwargs_with_combinations(
+        self, partition: RowPartition
+    ) -> dict[str, Any]:
+        count = sum(row["count"] for row in partition.covariates)
+        if not partition.included_numeric and not partition.included_choice:
+            return {
+                "count": count,
+                "name": "constant",
+                "params": {"value": [None] * len(partition.excluded_columns)},
+            }
+        covariates = {
+            "covs": partition.covariates,
+        }
+        if not partition.constant_outputs:
+            return {
+                "count": count,
+                "name": self._function_name,
+                "params": covariates,
+            }
+        return {
+            "count": count,
+            "name": "with_constants_at",
+            "params": {
+                "constants_at": partition.constant_outputs,
+                "subgen": self._function_name,
+                "params": covariates,
+            },
+        }
+
+    def actual_kwargs(self) -> dict[str, Any]:
+        """Get the kwargs (summary statistics) this generator was instantiated with."""
+        if self._partition_count_query is None:
+            counts = None
+        else:
+            counts = self._partition_count_query.results
+        return {
+            "alternative_configs": [
+                self._actual_kwargs_with_combinations(self._partitions[index])
+                for index in range(len(self._partitions))
+            ],
+            "counts": counts,
+        }
+
+    def generate_data(self, count: int) -> list[Any]:
+        """Generate ``count`` random data points for this column."""
+        kwargs = self.actual_kwargs()
+        return [dist_gen.alternatives(**kwargs) for _ in range(count)]
+
+    def fit(self, default: float = -1) -> float:
+        """Get this generator's fit against the real data."""
+        return default
+
+
+def is_numeric(col: Column) -> bool:
+    """Test if this column stores a numeric value."""
+    ct = get_column_type(col)
+    return isinstance(ct, (Numeric, Integer)) and not col.foreign_keys
+
+
+def powerset(xs: list[T]) -> Iterable[Iterable[T]]:
+    """Get a list of all sublists of ``xs``."""
+    return chain.from_iterable(combinations(xs, n) for n in range(len(xs) + 1))
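+
+
+# For illustration: powerset(["a", "b"]) yields (), ("a",), ("b",) and
+# ("a", "b") -- one entry per null-pattern partition over two nullable columns.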
"""Produces null partitioned generators, for complex interdependent data.""" + + SAMPLE_COUNT = MAXIMUM_CHOICES + SUPPRESS_COUNT = 7 + EMPTY_RESULT = [ + RowMapping( + parent=sqlalchemy.engine.result.SimpleResultMetaData(["count"]), + processors=None, + key_to_index={"count": 0}, + data=(0,), + ) + ] + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "grouped_multivariate_normal" + + def query_predicate(self, column: Column) -> str: + """Get a SQL expression that is true when ``column`` is available for analysis.""" + if is_numeric(column): + # x <> x + 1 ensures that x is not infinity or NaN + return f"COALESCE({column.name} <> {column.name} + 1, FALSE)" + return f"{column.name} IS NOT NULL" + + def query_var(self, column: str) -> str: + """Return the expression we are querying for in this column.""" + return column + + def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]: + """Get a list of nullable columns together with bitmasks.""" + out: list[NullableColumn] = [] + for col in columns: + if col.nullable: + out.append( + NullableColumn( + column=col, + bitmask=2 ** len(out), + ) + ) + return out + + def get_partition_count_query( + self, ncs: list[NullableColumn], table: str, where: str | None = None + ) -> str: + """ + Get a SQL expression returning columns ``count`` and ``index``. + + Each row returned represents one of the null pattern partitions. + ``index`` is the bitmask of all those nullable columns that are not null for + this partition, and ``count`` is the total number of rows in this partition. + """ + index_exp = " + ".join( + f"CASE WHEN {self.query_predicate(nc.column)} THEN {nc.bitmask} ELSE 0 END" + for nc in ncs + ) + if where is None: + return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' + return ( + 'SELECT count, "index" FROM (SELECT COUNT(*) AS count,' + f' {index_exp} AS "index"' + f' FROM {table} GROUP BY "index") AS _q {where}' + ) + + def _get_row_partition( + self, + table: str, + partition: NullPatternPartition, + suppress_count: int = 1, + sample_count: int | None = None, + ) -> RowPartition: + """Get the RowPartition from a NullPatternPartition.""" + query = self.query( + table=table, + columns=partition.included_numeric, + predicates=partition.predicates, + group_by_clause=partition.group_by_clause, + constants=partition.constants, + constant_clauses=partition.constant_clauses, + suppress_count=suppress_count, + sample_count=sample_count, + ) + return RowPartition( + query, + partition.included_numeric, + partition.included_choice, + partition.excluded, + partition.nones, + [], + ) + + # pylint: disable=too-many-arguments too-many-positional-arguments + def _get_generator( + self, + connection: Connection, + table_name: str, + columns: list[Column], + nullable_columns: list[NullableColumn], + where: str | None = None, + name_suffix: str | None = None, + suppress_count: int = 1, + sample_count: int | None = None, + ) -> NullPartitionedNormalGenerator | None: + query = self.get_partition_count_query(nullable_columns, table_name, where) + partitions: dict[int, RowPartition] = {} + for partition_nonnulls in powerset(nullable_columns): + partition_def = NullPatternPartition(columns, partition_nonnulls) + partitions[partition_def.index] = self._get_row_partition( + table_name, + partition_def, + suppress_count=suppress_count, + sample_count=sample_count, + ) + if not self._execute_partition_queries(connection, partitions): + return None + return 
+            f"{table_name}__{columns[0].name}",
+            partitions,
+            self.function_name(),
+            name_suffix=name_suffix,
+            partition_count_query=PartitionCountQuery(
+                connection,
+                query,
+                table_name,
+                nullable_columns,
+            ),
+        )
+
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get any appropriate generators for these columns."""
+        if len(columns) < 2:
+            return []
+        nullable_columns = self.get_nullable_columns(columns)
+        if not nullable_columns:
+            return []
+        table = columns[0].table.name
+        gens: list[Generator | None] = []
+        try:
+            with engine.connect() as connection:
+                gens.append(
+                    self._get_generator(
+                        connection,
+                        table,
+                        columns,
+                        nullable_columns,
+                    )
+                )
+                gens.append(
+                    self._get_generator(
+                        connection,
+                        table,
+                        columns,
+                        nullable_columns,
+                        name_suffix="sampled",
+                        sample_count=self.SAMPLE_COUNT,
+                    )
+                )
+                gens.append(
+                    self._get_generator(
+                        connection,
+                        table,
+                        columns,
+                        nullable_columns,
+                        where=f"WHERE {self.SUPPRESS_COUNT} < count",
+                        name_suffix="sampled and suppressed",
+                        suppress_count=self.SUPPRESS_COUNT,
+                        sample_count=self.SAMPLE_COUNT,
+                    )
+                )
+                gens.append(
+                    self._get_generator(
+                        connection,
+                        table,
+                        columns,
+                        nullable_columns,
+                        where=f"WHERE {self.SUPPRESS_COUNT} < count",
+                        name_suffix="suppressed",
+                        suppress_count=self.SUPPRESS_COUNT,
+                    )
+                )
+        except sqlalchemy.exc.DatabaseError as exc:
+            logger.debug("SQL query failed with error %s [%s]", exc, exc.statement)
+            return []
+        return [gen for gen in gens if gen]
+
+    def _execute_partition_queries(
+        self,
+        connection: Connection,
+        partitions: dict[int, RowPartition],
+    ) -> bool:
+        """
+        Execute the query in each partition, filling in the covariates.
+
+        :return: True if at least one partition has any rows, False otherwise.
+        """
+        found_nonzero = False
+        for rp in partitions.values():
+            covs = connection.execute(text(rp.query)).mappings().fetchall()
+            if not covs or covs[0]["count"] is None or covs[0]["count"] == 0:
+                rp.covariates = self.EMPTY_RESULT
+            else:
+                rp.covariates = covs
+                found_nonzero = True
+        return found_nonzero
+
+
+class NullPartitionedLogNormalGeneratorFactory(NullPartitionedNormalGeneratorFactory):
+    """
+    A generator factory for numeric and non-numeric columns.
+
+    Any values could be null; the distributions of the nonnull numeric columns
+    depend on each other and on the other non-numeric column values.
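+
+    Only strictly positive, finite values are treated as present for the
+    lognormal fit; zero, negative and non-finite values fall into the
+    "excluded" patterns (see ``query_predicate`` below).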
+ """ + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "grouped_multivariate_lognormal" + + def query_predicate(self, column: Column) -> str: + """Get the SQL expression testing if the value in this column should be used.""" + if is_numeric(column): + # x <> x + 1 ensures that x is not infinity or NaN + return f"COALESCE({column.name} <> {column.name} + 1 AND 0 < {column.name}, FALSE)" + return f"{column.name} IS NOT NULL" + + def query_var(self, column: str) -> str: + """Get the variable or expression we are querying for this column.""" + return f"LN({column})" diff --git a/datafaker/interactive.py b/datafaker/interactive.py deleted file mode 100644 index 7f1b8795..00000000 --- a/datafaker/interactive.py +++ /dev/null @@ -1,1700 +0,0 @@ -from abc import ABC, abstractmethod -import cmd -from collections.abc import Mapping -import csv -from dataclasses import dataclass -from enum import Enum -import functools -import itertools -from pathlib import Path -import re -from typing import Iterable - -import sqlalchemy -from prettytable import PrettyTable -from sqlalchemy import Column, MetaData, Table, text, ForeignKey - -from datafaker.generators import Generator, PredefinedGenerator, everything_factory -from datafaker.utils import ( - create_db_engine, - fk_refers_to_ignored_table, - logger, - primary_private_fks, - table_is_private, -) - -# Monkey patch pyreadline3 v3.5 so that it works with Python 3.13 -# Windows users can install pyreadline3 to get tab completion working. -# See https://github.com/pyreadline3/pyreadline3/issues/37 -try: - import readline - if not hasattr(readline, "backend"): - readline.backend = "readline" -except: - pass - -def or_default(v, d): - """ Returns v if it isn't None, otherwise d. """ - return d if v is None else v - -class TableType(Enum): - GENERATE = "generate" - IGNORE = "ignore" - VOCABULARY = "vocabulary" - PRIVATE = "private" - EMPTY = "empty" - -TYPE_LETTER = { - TableType.GENERATE: "G", - TableType.IGNORE: "I", - TableType.VOCABULARY: "V", - TableType.PRIVATE: "P", - TableType.EMPTY: "e", -} - -TYPE_PROMPT = { - TableType.GENERATE: "(table: {}) ", - TableType.IGNORE: "(table: {} (ignore)) ", - TableType.VOCABULARY: "(table: {} (vocab)) ", - TableType.PRIVATE: "(table: {} (private)) ", - TableType.EMPTY: "(table: {} (empty))", -} - -@dataclass -class TableEntry: - name: str # name of the table - - -class AskSaveCmd(cmd.Cmd): - intro = "Do you want to save this configuration?" - prompt = "(yes/no/cancel) " - file = None - def __init__(self): - super().__init__() - self.result = "" - def do_yes(self, _arg): - self.result = "yes" - return True - def do_no(self, _arg): - self.result = "no" - return True - def do_cancel(self, _arg): - self.result = "cancel" - return True - - -def fk_column_name(fk: ForeignKey): - if fk_refers_to_ignored_table(fk): - return f"{fk.target_fullname} (ignored)" - return fk.target_fullname - - -class DbCmd(ABC, cmd.Cmd): - INFO_NO_MORE_TABLES = "There are no more tables" - ERROR_ALREADY_AT_START = "Error: Already at the start" - ERROR_NO_SUCH_TABLE = "Error: '{0}' is not the name of a table in this database" - ERROR_NO_SUCH_TABLE_OR_COLUMN = "Error: '{0}' is not the name of a table in this database or a column in this table" - ROW_COUNT_MSG = "Total row count: {}" - - @abstractmethod - def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry: - ... 
-
-    def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping):
-        super().__init__()
-        self.config = config
-        self.metadata = metadata
-        self.table_entries: list[TableEntry] = []
-        tables_config: Mapping = config.get("tables", {})
-        if type(tables_config) is not dict:
-            tables_config = {}
-        for name in metadata.tables.keys():
-            table_config = tables_config.get(name, {})
-            if type(table_config) is not dict:
-                table_config = {}
-            entry = self.make_table_entry(name, table_config)
-            if entry is not None:
-                self.table_entries.append(entry)
-        self.table_index = 0
-        self.engine = create_db_engine(src_dsn, schema_name=src_schema)
-    def __enter__(self):
-        return self
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.engine.dispose()
-
-    def print(self, text: str, *args, **kwargs):
-        print(text.format(*args, **kwargs))
-    def print_table(self, headings: list[str], rows: list[list[str]]):
-        output = PrettyTable()
-        output.field_names = headings
-        for row in rows:
-            output.add_row(row)
-        print(output)
-    def print_table_by_columns(self, columns: dict[str, list[str]]):
-        output = PrettyTable()
-        row_count = max([len(col) for col in columns.values()])
-        for field_name, data in columns.items():
-            output.add_column(field_name, data + [None] * (row_count - len(data)))
-        print(output)
-    def print_results(self, result):
-        self.print_table(
-            list(result.keys()),
-            [list(row) for row in result.all()]
-        )
-    def ask_save(self):
-        ask = AskSaveCmd()
-        ask.cmdloop()
-        return ask.result
-
-    def set_table_index(self, index) -> bool:
-        if 0 <= index and index < len(self.table_entries):
-            self.table_index = index
-            self.set_prompt()
-            return True
-        return False
-    def next_table(self, report="No more tables"):
-        if not self.set_table_index(self.table_index + 1):
-            self.print(report)
-            return False
-        return True
-    def table_name(self):
-        return self.table_entries[self.table_index].name
-    def table_metadata(self) -> Table:
-        return self.metadata.tables[self.table_name()]
-    def get_column_names(self) -> list[str]:
-        return [
-            col.name
-            for col in self.table_metadata().columns
-        ]
-    def report_columns(self):
-        self.print_table(["name", "type", "primary", "nullable", "foreign key"], [
-            [name, str(col.type), col.primary_key, col.nullable, ", ".join(
-                [fk_column_name(fk) for fk in col.foreign_keys]
-            )]
-            for name, col in self.table_metadata().columns.items()
-        ])
-    def get_table_config(self, table_name: str) -> dict[str, any]:
-        ts = self.config.get("tables", None)
-        if type(ts) is not dict:
-            return {}
-        t = ts.get(table_name)
-        return t if type(t) is dict else {}
-    def set_table_config(self, table_name: str, config: dict[str, any]):
-        ts = self.config.get("tables", None)
-        if type(ts) is not dict:
-            self.config["tables"] = {table_name: config}
-            return
-        ts[table_name] = config
-    def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, any]]:
-        src_stats = self.config.get("src-stats", [])
-        new_src_stats = []
-        for stat in src_stats:
-            if not stat.get("name", "").startswith(prefix):
-                new_src_stats.append(stat)
-        self.config["src-stats"] = new_src_stats
-        return new_src_stats
-    def get_nonnull_columns(self, table_name: str):
-        metadata_table = self.metadata.tables[table_name]
-        return [
-            str(name)
-            for name, column in metadata_table.columns.items()
-            if column.nullable
-        ]
-    def find_entry_index_by_table_name(self, table_name) -> int | None:
-        return next(
-            (i for i,entry in enumerate(self.table_entries) if entry.name == table_name),
-            None,
-        )
-    def find_entry_by_table_name(self, table_name) -> TableEntry | None:
-        for e in self.table_entries:
-            if e.name == table_name:
-                return e
-        return None
-    def do_counts(self, _arg):
-        "Report the column names with the counts of nulls in them"
-        if len(self.table_entries) <= self.table_index:
-            return
-        table_name = self.table_name()
-        nonnull_columns = self.get_nonnull_columns(table_name)
-        colcounts = [
-            ", COUNT({0}) AS {0}".format(nnc)
-            for nnc in nonnull_columns
-        ]
-        with self.engine.connect() as connection:
-            result = connection.execute(
-                text("SELECT COUNT(*) AS row_count{colcounts} FROM {table}".format(
-                    table=table_name,
-                    colcounts="".join(colcounts),
-                ))
-            ).first()
-        if result is None:
-            self.print("Could not count rows in table {0}", table_name)
-            return
-        row_count = result.row_count
-        self.print(self.ROW_COUNT_MSG, row_count)
-        self.print_table(["Column", "NULL count"], [
-            [name, row_count - count]
-            for name, count in result._mapping.items()
-            if name != "row_count"
-        ])
-
-    def do_select(self, arg):
-        "Run a select query over the database and show the first 50 results"
-        MAX_SELECT_ROWS = 50
-        with self.engine.connect() as connection:
-            try:
-                result = connection.execute(
-                    text("SELECT " + arg)
-                )
-            except sqlalchemy.exc.DatabaseError as exc:
-                self.print("Failed to execute: {}", exc)
-                return
-            row_count = result.rowcount
-            self.print(self.ROW_COUNT_MSG, row_count)
-            if 50 < row_count:
-                self.print("Showing the first {} rows", MAX_SELECT_ROWS)
-            fields = list(result.keys())
-            rows = [
-                row._tuple()
-                for row in result.fetchmany(MAX_SELECT_ROWS)
-            ]
-            self.print_table(fields, rows)
-
-    def do_peek(self, arg: str):
-        """
-        Use 'peek col1 col2 col3' to see a sample of values from columns col1, col2 and col3 in the current table.
-        Use 'peek' to see a sample of the current column(s).
-        Rows that are enitrely null are suppressed.
-        """
-        MAX_PEEK_ROWS = 25
-        if len(self.table_entries) <= self.table_index:
-            return
-        table_name = self.table_name()
-        col_names = arg.split()
-        if not col_names:
-            col_names = self.get_column_names()
-        nonnulls = [cn + " IS NOT NULL" for cn in col_names]
-        with self.engine.connect() as connection:
-            query = "SELECT {cols} FROM {table} {where} {nonnull} ORDER BY RANDOM() LIMIT {max}".format(
-                cols=",".join(col_names),
-                table=table_name,
-                where="WHERE" if nonnulls else "",
-                nonnull=" OR ".join(nonnulls),
-                max=MAX_PEEK_ROWS,
-            )
-            try:
-                result = connection.execute(text(query))
-            except Exception as exc:
-                self.print(f'SQL query "{query}" caused exception {exc}')
-                return
-            rows = [
-                row._tuple()
-                for row in result.fetchmany(MAX_PEEK_ROWS)
-            ]
-            self.print_table(list(result.keys()), rows)
-
-    def complete_peek(self, text: str, _line: str, _begidx: int, _endidx: int):
-        if len(self.table_entries) <= self.table_index:
-            return []
-        return [
-            col
-            for col in self.table_metadata().columns.keys()
-            if col.startswith(text)
-        ]
-
-
-@dataclass
-class TableCmdTableEntry(TableEntry):
-    old_type: TableType
-    new_type: TableType
-
-class TableCmd(DbCmd):
-    intro = "Interactive table configuration (ignore, vocabulary, private, generate or empty). Type ? for help.\n"
-    doc_leader = """Use the commands 'ignore', 'vocabulary',
-'private', 'empty' or 'generate' to set the table's type. Use 'next' or
-'previous' to change table. Use 'tables' and 'columns' for
-information about the database. Use 'data', 'peek', 'select' or
-'count' to see some data contained in the current table. Use 'quit'
-to exit this program."""
-    prompt = "(tableconf) "
-    file = None
-    WARNING_TEXT_VOCAB_TO_NON_VOCAB = "Vocabulary table {0} references non-vocabulary table {1}"
-    WARNING_TEXT_NON_EMPTY_TO_EMPTY = "Empty table {1} referenced from non-empty table {0}. {1} will need stories."
-    WARNING_TEXT_PROBLEMS_EXIST = "WARNING: The following table types have problems:"
-    WARNING_TEXT_POTENTIAL_PROBLEMS = "NOTE: The following table types might cause problems later:"
-    NOTE_TEXT_NO_CHANGES = "You have made no changes."
-    NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}"
-
-    def make_table_entry(self, name: str, table: Mapping) -> TableEntry:
-        if table.get("ignore", False):
-            return TableCmdTableEntry(name, TableType.IGNORE, TableType.IGNORE)
-        if table.get("vocabulary_table", False):
-            return TableCmdTableEntry(name, TableType.VOCABULARY, TableType.VOCABULARY)
-        if table.get("primary_private", False):
-            return TableCmdTableEntry(name, TableType.PRIVATE, TableType.PRIVATE)
-        if table.get("num_rows_per_pass", 1) == 0:
-            return TableCmdTableEntry(name, TableType.EMPTY, TableType.EMPTY)
-        return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE)
-
-    def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping):
-        super().__init__(src_dsn, src_schema, metadata, config)
-        self.set_prompt()
-
-    def set_prompt(self):
-        if self.table_index < len(self.table_entries):
-            entry = self.table_entries[self.table_index]
-            self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name)
-        else:
-            self.prompt = "(table) "
-    def set_type(self, t_type: TableType):
-        if self.table_index < len(self.table_entries):
-            entry = self.table_entries[self.table_index]
-            entry.new_type = t_type
-    def _copy_entries(self) -> None:
-        for entry in self.table_entries:
-            entry: TableCmdTableEntry
-            if entry.old_type != entry.new_type:
-                table = self.get_table_config(entry.name)
-                if entry.old_type == TableType.EMPTY and table.get("num_rows_per_pass", 1) == 0:
-                    table["num_rows_per_pass"] = 1
-                if entry.new_type == TableType.IGNORE:
-                    table["ignore"] = True
-                    table.pop("vocabulary_table", None)
-                    table.pop("primary_private", None)
-                elif entry.new_type == TableType.VOCABULARY:
-                    table.pop("ignore", None)
-                    table["vocabulary_table"] = True
-                    table.pop("primary_private", None)
-                elif entry.new_type == TableType.PRIVATE:
-                    table.pop("ignore", None)
-                    table.pop("vocabulary_table", None)
-                    table["primary_private"] = True
-                elif entry.new_type == TableType.EMPTY:
-                    table.pop("ignore", None)
-                    table.pop("vocabulary_table", None)
-                    table.pop("primary_private", None)
-                    table["num_rows_per_pass"] = 0
-                else:
-                    table.pop("ignore", None)
-                    table.pop("vocabulary_table", None)
-                    table.pop("primary_private", None)
-                self.set_table_config(entry.name, table)
-
-    def _get_referenced_tables(self, from_table_name: str) -> set[str]:
-        from_meta = self.metadata.tables[from_table_name]
-        return {
-            fk.column.table.name
-            for col in from_meta.columns
-            for fk in col.foreign_keys
-        }
-
-    def _sanity_check_failures(self) -> list[tuple[str, str, str]]:
-        """ Find tables that reference each other that should not given their types. """
""" - failures = [] - for from_entry in self.table_entries: - from_entry: TableCmdTableEntry - from_t = from_entry.new_type - if from_t == TableType.VOCABULARY: - referenced = self._get_referenced_tables(from_entry.name) - for ref in referenced: - to_entry = self.find_entry_by_table_name(ref) - if to_entry is not None and to_entry.new_type != TableType.VOCABULARY: - failures.append(( - self.WARNING_TEXT_VOCAB_TO_NON_VOCAB, - from_entry.name, - to_entry.name, - )) - return failures - - def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: - """ Find tables that reference each other that might cause problems given their types. """ - warnings = [] - for from_entry in self.table_entries: - from_entry: TableCmdTableEntry - from_t = from_entry.new_type - if from_t in {TableType.GENERATE, TableType.PRIVATE}: - referenced = self._get_referenced_tables(from_entry.name) - for ref in referenced: - to_entry = self.find_entry_by_table_name(ref) - if to_entry is not None and to_entry.new_type in {TableType.EMPTY, TableType.IGNORE}: - warnings.append(( - self.WARNING_TEXT_NON_EMPTY_TO_EMPTY, - from_entry.name, - to_entry.name, - )) - return warnings - - - def do_quit(self, _arg): - "Check the updates, save them if desired and quit the configurer." - count = 0 - for entry in self.table_entries: - if entry.old_type != entry.new_type: - count += 1 - self.print( - self.NOTE_TEXT_CHANGING, - entry.name, - entry.old_type.value, - entry.new_type.value, - ) - if count == 0: - self.print(self.NOTE_TEXT_NO_CHANGES) - failures = self._sanity_check_failures() - if failures: - self.print(self.WARNING_TEXT_PROBLEMS_EXIST) - for (text, from_t, to_t) in failures: - self.print(text, from_t, to_t) - warnings = self._sanity_check_warnings() - if warnings: - self.print(self.WARNING_TEXT_POTENTIAL_PROBLEMS) - for (text, from_t, to_t) in warnings: - self.print(text, from_t, to_t) - reply = self.ask_save() - if reply == "yes": - self._copy_entries() - return True - if reply == "no": - return True - return False - def do_tables(self, _arg): - "list the tables with their types" - for entry in self.table_entries: - old = entry.old_type - new = entry.new_type - becomes = " " if old == new else "->" + TYPE_LETTER[new] - self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) - def do_next(self, arg): - "'next' = go to the next table, 'next tablename' = go to table 'tablename'" - if arg: - # Find the index of the table called _arg, if any - index = self.find_entry_index_by_table_name(arg) - if index is None: - self.print(self.ERROR_NO_SUCH_TABLE, arg) - return - self.set_table_index(index) - return - self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text, line, begidx, endidx): - return [ - entry.name - for entry in self.table_entries - if entry.name.startswith(text) - ] - def do_previous(self, _arg): - "Go to the previous table" - if not self.set_table_index(self.table_index - 1): - self.print(self.ERROR_ALREADY_AT_START) - def do_ignore(self, _arg): - "Set the current table as ignored, and go to the next table" - self.set_type(TableType.IGNORE) - self.print("Table {} set as ignored", self.table_name()) - self.next_table() - def do_vocabulary(self, _arg): - "Set the current table as a vocabulary table, and go to the next table" - self.set_type(TableType.VOCABULARY) - self.print("Table {} set to be a vocabulary table", self.table_name()) - self.next_table() - def do_private(self, _arg): - "Set the current table as a primary private table (such as the table of patients)" - 
-        self.set_type(TableType.PRIVATE)
-        self.print("Table {} set to be a primary private table", self.table_name())
-        self.next_table()
-    def do_generate(self, _arg):
-        "Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table"
-        self.set_type(TableType.GENERATE)
-        self.print("Table {} generate", self.table_name())
-        self.next_table()
-    def do_empty(self, _arg):
-        "Set the current table as empty; no generators will be run for it"
-        self.set_type(TableType.EMPTY)
-        self.print("Table {} empty", self.table_name())
-        self.next_table()
-    def do_columns(self, _arg):
-        "Report the column names and metadata"
-        self.report_columns()
-    def do_data(self, arg: str):
-        """
-        Report some data.
-        'data' = report a random ten lines,
-        'data 20' = report a random 20 lines,
-        'data 20 ColumnName' = report a random twenty entries from ColumnName,
-        'data 20 ColumnName 30' = report a random twenty entries from ColumnName of length at least 30,
-        """
-        args = arg.split()
-        column = None
-        number = None
-        arg_index = 0
-        min_length = 0
-        table_metadata = self.table_metadata()
-        if arg_index < len(args) and args[arg_index].isdigit():
-            number = int(args[arg_index])
-            arg_index += 1
-        if arg_index < len(args) and args[arg_index] in table_metadata.columns:
-            column = args[arg_index]
-            arg_index += 1
-        if arg_index < len(args) and args[arg_index].isdigit():
-            min_length = int(args[arg_index])
-            arg_index += 1
-        if arg_index != len(args):
-            self.print(
-                """Did not understand these arguments
-The format is 'data [entries] [column-name [minimum-length]]' where [] means optional text.
-Type 'columns' to find out valid column names for this table.
-Type 'help data' for examples."""
-            )
-            return
-        if column is None:
-            if number is None:
-                number = 10
-            self.print_row_data(number)
-        else:
-            if number is None:
-                number = 48
-            self.print_column_data(column, number, min_length)
-    def complete_data(self, text, line, begidx, endidx):
-        previous_parts = line[:begidx - 1].split()
-        if len(previous_parts) != 2:
-            return []
-        table_metadata = self.table_metadata()
-        return [
-            k for k in table_metadata.columns.keys()
-            if k.startswith(text)
-        ]
-
-    def print_column_data(self, column: str, count: int, min_length: int):
-        where = f"WHERE {column} IS NOT NULL"
-        if 0 < min_length:
-            where = "WHERE LENGTH({column}) >= {len}".format(
-                column=column,
-                len=min_length,
-            )
-        with self.engine.connect() as connection:
-            result = connection.execute(
-                text("SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format(
-                    table=self.table_name(),
-                    column=column,
-                    count=count,
-                    where=where,
-                ))
-            )
-            self.columnize([str(x[0]) for x in result.all()])
-
-    def print_row_data(self, count: int):
-        with self.engine.connect() as connection:
-            result = connection.execute(
-                text("SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format(
-                    table=self.table_name(),
-                    count=count,
-                ))
-            )
-            if result is None:
-                self.print("No rows in this table!")
-                return
-            self.print_results(result)
-
-
-def update_config_tables(src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping):
-    with TableCmd(src_dsn, src_schema, metadata, config) as tc:
-        tc.cmdloop()
-    return tc.config
-
-
-@dataclass
-class MissingnessType:
-    SAMPLED="column_presence.sampled"
-    SAMPLED_QUERY=(
-        "SELECT COUNT(*) AS row_count, {result_names} FROM "
-        "(SELECT {column_is_nulls} FROM {table} ORDER BY RANDOM() LIMIT {count})"
-        " AS __t GROUP BY {result_names}"
-    )
-    name: str
-    query: str
-    comment: str
-    columns: list[str]
-    @classmethod
-    def sampled_query(cls, table, count, column_names) -> str:
-        result_names = ", ".join([
-            "{0}__is_null".format(c)
-            for c in column_names
-        ])
-        column_is_nulls = ", ".join([
-            "{0} IS NULL AS {0}__is_null".format(c)
-            for c in column_names
-        ])
-        return cls.SAMPLED_QUERY.format(
-            result_names=result_names,
-            column_is_nulls=column_is_nulls,
-            table=table,
-            count=count,
-        )
-
-
-@dataclass
-class MissingnessCmdTableEntry(TableEntry):
-    old_type: MissingnessType
-    new_type: MissingnessType
-
-
-class MissingnessCmd(DbCmd):
-    intro = "Interactive missingness configuration. Type ? for help.\n"
-    doc_leader = """Use commands 'sampled' and 'none' to choose the missingness style for
-the current table. Use commands 'next' and 'previous' to change the
-current table. Use 'tables' to list the tables and 'count' to show
-how many NULLs exist in each column. Use 'peek' or 'select' to see
-data from the database. Use 'quit' to exit this tool."""
-    prompt = "(missingness) "
-    file = None
-    PATTERN_RE = re.compile(r'SRC_STATS\["([^"]*)"\]')
-
-    def find_missingness_query(self, missingness_generator: Mapping) -> tuple[str | None, str | None] | None:
-        """ Find query and comment from src-stats for the passed missingness generator. """
-        kwargs = missingness_generator.get("kwargs", {})
-        patterns = kwargs.get("patterns", "")
-        pattern_match = self.PATTERN_RE.match(patterns)
-        if pattern_match:
-            key = pattern_match.group(1)
-            for src_stat in self.config["src-stats"]:
-                if src_stat.get("name") == key:
-                    return (src_stat.get("query", None), src_stat.get("comment", None))
-        return None
-    def make_table_entry(self, name: str, table: Mapping) -> TableEntry:
-        if table.get("ignore", False):
-            return None
-        if table.get("vocabulary_table", False):
-            return None
-        if table.get("num_rows_per_pass", 1) == 0:
-            return None
-        mgs = table.get("missingness_generators", [])
-        old = None
-        nonnull_columns = self.get_nonnull_columns(name)
-        if not nonnull_columns:
-            return None
-        if not mgs:
-            old = MissingnessType(
-                name="none",
-                query="",
-                comment="",
-                columns=[],
-            )
-        elif len(mgs) == 1:
-            mg = mgs[0]
-            mg_name = mg.get("name", None)
-            if mg_name is not None:
-                query_comment = self.find_missingness_query(mg)
-                if query_comment is not None:
-                    (query, comment) = query_comment
-                    old = MissingnessType(
-                        name=mg_name,
-                        query=query,
-                        comment=comment,
-                        columns=mg.get("columns_assigned", []),
-                    )
-        if old is None:
-            return None
-        return MissingnessCmdTableEntry(
-            name=name,
-            old_type=old,
-            new_type=old,
-        )
-
-    def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping):
-        super().__init__(src_dsn, src_schema, metadata, config)
-        self.set_prompt()
-
-    def set_prompt(self):
-        if self.table_index < len(self.table_entries):
-            entry: MissingnessCmdTableEntry = self.table_entries[self.table_index]
-            nt = entry.new_type
-            if nt is None:
-                self.prompt = "(missingness for {0}) ".format(entry.name)
-            else:
-                self.prompt = "(missingness for {0}: {1}) ".format(entry.name, nt.name)
-        else:
-            self.prompt = "(missingness) "
-    def set_type(self, t_type: TableType):
-        if self.table_index < len(self.table_entries):
-            entry = self.table_entries[self.table_index]
-            entry.new_type = t_type
-    def _copy_entries(self) -> None:
-        src_stats = self._remove_prefix_src_stats("missing_auto__")
-        for entry in self.table_entries:
-            entry: MissingnessCmdTableEntry
-            table = self.get_table_config(entry.name)
-            if entry.new_type is None or entry.new_type.name == "none":
table.pop("missingness_generators", None) - else: - src_stat_key = "missing_auto__{0}__0".format(entry.name) - table["missingness_generators"] = [{ - "name": entry.new_type.name, - "kwargs": {"patterns": 'SRC_STATS["{0}"]["results"]'.format(src_stat_key)}, - "columns": entry.new_type.columns, - }] - src_stats.append({ - "name": src_stat_key, - "query": entry.new_type.query, - "comments": [] if entry.new_type.comment is None else [entry.new_type.comment], - }) - self.set_table_config(entry.name, table) - - def do_quit(self, _arg): - "Check the updates, save them if desired and quit the configurer." - count = 0 - for entry in self.table_entries: - if entry.old_type != entry.new_type: - count += 1 - if entry.old_type is None: - self.print("Putting generator {0} on table {1}", entry.name, entry.new_type.name) - elif entry.new_type is None: - self.print("Deleting generator {1} from table {0}", entry.name, entry.old_type.name) - else: - self.print( - "Changing {0} from {1} to {2}", - entry.name, - entry.old_type.name, - entry.new_type.name, - ) - if count == 0: - self.print("You have made no changes.") - reply = self.ask_save() - if reply == "yes": - self._copy_entries() - return True - if reply == "no": - return True - return False - def do_tables(self, arg): - "list the tables with their types" - for entry in self.table_entries: - old = "-" if entry.old_type is None else entry.old_type.name - new = "-" if entry.new_type is None else entry.new_type.name - desc = new if old == new else "{0}->{1}".format(old, new) - self.print("{0} {1}", entry.name, desc) - def do_next(self, arg): - "'next' = go to the next table, 'next tablename' = go to table 'tablename'" - if arg: - # Find the index of the table called _arg, if any - index = next((i for i,entry in enumerate(self.table_entries) if entry.name == arg), None) - if index is None: - self.print(self.ERROR_NO_SUCH_TABLE, arg) - return - self.set_table_index(index) - return - self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text, line, begidx, endidx): - return [ - entry.name - for entry in self.table_entries - if entry.name.startswith(text) - ] - def do_previous(self, _arg): - "Go to the previous table" - if not self.set_table_index(self.table_index - 1): - self.print(self.ERROR_ALREADY_AT_START) - def _set_type(self, name, query, comment): - if len(self.table_entries) <= self.table_index: - return - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] - entry.new_type = MissingnessType( - name=name, - query=query, - comment=comment, - columns=self.get_nonnull_columns(entry.name), - ) - def _set_none(self): - if len(self.table_entries) <= self.table_index: - return - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] - entry.new_type = None - def do_sampled(self, arg: str): - """ - Set the current table missingness as 'sampled', and go to the next table. - "sampled 3000" means sample 3000 rows at random and choose the missingness - to be the same as one of those 3000 at random. - "sampled" means the same, but with a default number of rows sampled (1000). - """ - if len(self.table_entries) <= self.table_index: - self.print("Error! not on a table") - return - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] - if arg == "": - count = 1000 - elif arg.isdecimal(): - count = int(arg) - else: - self.print("Error: sampled can be used alone or with an integer argument. 
{0} is not permitted", arg) - return - self._set_type( - MissingnessType.SAMPLED, - MissingnessType.sampled_query( - entry.name, - count, - self.get_nonnull_columns(entry.name), - ), - f"The missingness patterns and how often they appear in a sample of {count} from table {entry.name}" - ) - self.print("Table {} set to sampled missingness", self.table_name()) - self.next_table() - def do_none(self, _arg): - "Set the current table to have no missingness, and go to the next table" - self._set_none() - self.print("Table {} set to have no missingness", self.table_name()) - self.next_table() - - -def update_missingness(src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): - with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: - mc.cmdloop() - return mc.config - - -@dataclass -class GeneratorInfo: - columns: list[str] - gen: Generator | None - -@dataclass -class GeneratorCmdTableEntry(TableEntry): - old_generators: list[GeneratorInfo] - new_generators: list[GeneratorInfo] - -class GeneratorCmd(DbCmd): - intro = "Interactive generator configuration. Type ? for help.\n" - doc_leader = """Use command 'propose' for a list of generators applicable to the -current column, then command 'compare' to see how these perform -against the source data, then command 'set' to choose your favourite. -Use 'unset' to remove the column's generator. Use commands 'next' and -'previous' to change which column we are examining. Use 'info' -for useful information about the current column. Use 'tables' and -'list' to see available tables and columns. Use 'columns' to see -information about the columns in the current table. Use 'peek', -'count' or 'select' to fetch data from the source database. Use -'quit' to exit this program.""" - prompt = "(generatorconf) " - file = None - - PROPOSE_SOURCE_SAMPLE_TEXT = "Sample of actual source data: {0}..." - PROPOSE_SOURCE_EMPTY_TEXT = "Source database has no data in this column." - PROPOSE_GENERATOR_SAMPLE_TEXT = "{index}. {name}: {fit} {sample} ..." - PRIMARY_PRIVATE_TEXT = "Primary Private" - SECONDARY_PRIVATE_TEXT = "Secondary Private on columns {0}" - NOT_PRIVATE_TEXT = "Not private" - ERROR_NO_SUCH_TABLE = "No such (non-vocabulary, non-ignored) table name {0}" - ERROR_NO_SUCH_COLUMN = "No such column {0} in this table" - ERROR_COLUMN_ALREADY_MERGED = "Column {0} is already merged" - ERROR_COLUMN_ALREADY_UNMERGED = "Column {0} is not merged" - ERROR_CANNOT_UNMERGE_ALL = "You cannot unmerge all the generator's columns" - PROPOSE_NOTHING = "No proposed generators, sorry." 
- - SRC_STAT_RE = re.compile(r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?') - - def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None: - if table.get("ignore", False): - return None - if table.get("vocabulary_table", False): - return None - if table.get("num_rows_per_pass", 1) == 0: - return None - metadata_table = self.metadata.tables[table_name] - columns = [str(colname) for colname in metadata_table.columns.keys()] - column_set = frozenset(columns) - columns_assigned_so_far = set() - new_generator_infos: list[GeneratorInfo] = [] - old_generator_infos: list[GeneratorInfo] = [] - for rg in table.get("row_generators", []): - gen_name = rg.get("name", None) - if gen_name: - ca = rg.get("columns_assigned", []) - collist: list[str] = [ca] if isinstance(ca, str) else [str(c) for c in ca] - colset: set[str] = set(collist) - for unknown in colset - column_set: - logger.warning( - "table '%s' has '%s' assigned to column '%s' which is not in this table", - table_name, gen_name, unknown - ) - for mult in columns_assigned_so_far & colset: - logger.warning( - "table '%s' has column '%s' assigned to multiple times", table_name, mult - ) - actual_collist = [c for c in collist if c in columns] - if actual_collist: - gen = PredefinedGenerator(table, rg, self.config) - new_generator_infos.append(GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - )) - old_generator_infos.append(GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - )) - columns_assigned_so_far |= colset - for colname in columns: - if colname not in columns_assigned_so_far: - new_generator_infos.append(GeneratorInfo( - columns=[colname], - gen=None, - )) - if len(new_generator_infos) == 0: - return None - return GeneratorCmdTableEntry( - name=table_name, - old_generators=old_generator_infos, - new_generators=new_generator_infos, - ) - - def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): - super().__init__(src_dsn, src_schema, metadata, config) - self.generator_index = 0 - self.generators_valid_columns = None - self.set_prompt() - - def set_table_index(self, index): - ret = super().set_table_index(index) - if ret: - self.generator_index = 0 - self.set_prompt() - return ret - - def previous_table(self): - ret = self.set_table_index(self.table_index - 1) - if ret: - table = self.get_table() - if table is None: - self.print("Internal error! 
table {0} does not have any generators!", self.table_index) - return False - self.generator_index = len(table.new_generators) - 1 - else: - self.print(self.ERROR_ALREADY_AT_START) - return ret - - def get_table(self) -> GeneratorCmdTableEntry | None: - if self.table_index < len(self.table_entries): - return self.table_entries[self.table_index] - return None - - def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: - if self.table_index < len(self.table_entries): - entry: GeneratorCmdTableEntry = self.table_entries[self.table_index] - if self.generator_index < len(entry.new_generators): - return (entry.name, entry.new_generators[self.generator_index]) - return (entry.name, None) - return (None, None) - - def get_column_names(self) -> list[str]: - (_, generator_info) = self.get_table_and_generator() - return generator_info.columns if generator_info else [] - - def column_metadata(self) -> list[Column]: - table = self.table_metadata() - if table is None: - return [] - return [ - table.columns[name] - for name in self.get_column_names() - ] - - def set_prompt(self): - (table_name, gen_info) = self.get_table_and_generator() - if table_name is None: - self.prompt = "(generators) " - return - if gen_info is None: - self.prompt = "({table}) ".format(table=table_name) - return - table = self.table_metadata() - columns = [ - c + "[pk]" if table.columns[c].primary_key else c - for c in gen_info.columns - ] - gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" - self.prompt = f"({table_name}.{','.join(columns)}{gen}) " - - def _remove_auto_src_stats(self) -> list[dict[str, any]]: - return self._remove_prefix_src_stats("auto__") - - def _copy_entries(self) -> None: - src_stats = self._remove_auto_src_stats() - tes: list[GeneratorCmdTableEntry] = self.table_entries - for entry in tes: - rgs = [] - new_gens: list[Generator] = [] - for generator in entry.new_generators: - if generator.gen is not None: - new_gens.append(generator.gen) - cqs = generator.gen.custom_queries() - for cq_key, cq in cqs.items(): - src_stats.append({ - "name": cq_key, - "query": cq["query"], - "comments": [cq["comment"]] if "comment" in cq and cq["comment"] else [], - }) - rg = { - "name": generator.gen.function_name(), - "columns_assigned": generator.columns, - } - kwn = generator.gen.nominal_kwargs() - if kwn: - rg["kwargs"] = kwn - rgs.append(rg) - aq = self._get_aggregate_query(new_gens, entry.name) - if aq: - src_stats.append({ - "name": f"auto__{entry.name}", - "query": aq, - "comments": [ - q["comment"] - for gen in new_gens - for q in gen.select_aggregate_clauses().values() - if "comment" in q and q["comment"] is not None - ], - }) - table_config = self.get_table_config(entry.name) - if rgs: - table_config["row_generators"] = rgs - elif "row_generators" in table_config: - del table_config["row_generators"] - self.set_table_config(entry.name, table_config) - self.config["src-stats"] = src_stats - - def _find_old_generator(self, entry: GeneratorCmdTableEntry, columns) -> Generator | None: - """ Find any generator that previously assigned to these exact same columns. """ - fc = frozenset(columns) - for gen in entry.old_generators: - if frozenset(gen.columns) == fc: - return gen.gen - return None - - def do_quit(self, arg): - "Check the updates, save them if desired and quit the configurer." 
- count = 0 - for entry in self.table_entries: - header_shown = False - g_entry: GeneratorCmdTableEntry = entry - for gen in g_entry.new_generators: - old_gen = self._find_old_generator(g_entry, gen.columns) - new_gen = None if gen is None else gen.gen - if old_gen != new_gen: - if not header_shown: - header_shown = True - self.print("Table {0}:", entry.name) - count += 1 - self.print( - "...changing {0} from {1} to {2}", - ", ".join(gen.columns), - old_gen.name() if old_gen else "nothing", - gen.gen.name() if gen.gen else "nothing", - ) - if count == 0: - self.print("You have made no changes.") - if arg in {"yes", "no"}: - reply = arg - else: - reply = self.ask_save() - if reply == "yes": - self._copy_entries() - return True - if reply == "no": - return True - return False - - def do_tables(self, arg): - "list the tables" - for entry in self.table_entries: - gen_count = len(entry.new_generators) - how_many = "one generator" if gen_count == 1 else f"{gen_count} generators" - self.print("{0} ({1})", entry.name, how_many) - - def do_list(self, arg): - "list the generators in the current table" - if len(self.table_entries) <= self.table_index: - self.print("Error: no table {0}", self.table_index) - return - g_entry: GeneratorCmdTableEntry = self.table_entries[self.table_index] - table = self.table_metadata() - for gen in g_entry.new_generators: - old_gen = self._find_old_generator(g_entry, gen.columns) - old = "" if old_gen is None else old_gen.name() - if old_gen == gen.gen: - becomes = "" - if old == "": - old = "(not set)" - elif gen.gen is None: - becomes = "(delete)" - else: - becomes = f"->{gen.gen.name()}" - primary = "" - if len(gen.columns) == 1 and table.columns[gen.columns[0]].primary_key: - primary = "[primary-key]" - self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns) - - def do_columns(self, _arg): - "Report the column names and metadata" - self.report_columns() - - def do_info(self, _arg): - "Show information about the current column" - for cm in self.column_metadata(): - self.print( - "Column {0} in table {1} has type {2} ({3}).", - cm.name, - cm.table.name, - str(cm.type), - "nullable" if cm.nullable else "not nullable", - ) - if cm.primary_key: - self.print("It is a primary key, which usually does not need a generator (it will auto-increment)") - if cm.foreign_keys: - fk_names = [fk_column_name(fk) for fk in cm.foreign_keys] - self.print("It is a foreign key referencing column {0}", ", ".join(fk_names)) - if len(fk_names) == 1 and not cm.primary_key: - self.print("You do not need a generator if you just want a uniform choice over the referenced table's rows") - - def _get_table_index(self, table_name: str) -> int | None: - for n, entry in enumerate(self.table_entries): - if entry.name == table_name: - return n - return None - - def _get_generator_index(self, table_index, column_name): - entry: GeneratorCmdTableEntry = self.table_entries[table_index] - for n, gen in enumerate(entry.new_generators): - if column_name in gen.columns: - return n - return None - - def go_to(self, target): - parts = target.split(".", 1) - table_index = self._get_table_index(parts[0]) - if table_index is None: - if len(parts) == 1: - gen_index = self._get_generator_index(self.table_index, parts[0]) - if gen_index is not None: - self.generator_index = gen_index - self.set_prompt() - return True - self.print(self.ERROR_NO_SUCH_TABLE_OR_COLUMN, parts[0]) - return False - gen_index = None - if 1 < len(parts) and parts[1]: - gen_index = self._get_generator_index(table_index, parts[1]) - if 
gen_index is None: - self.print("we cannot set the generator for column {0}", parts[1]) - return False - self.set_table_index(table_index) - if gen_index is not None: - self.generator_index = gen_index - self.set_prompt() - return True - - def do_next(self, arg): - """ - Go to the next generator. - Or go to a named table: 'next tablename'. - Or go to a column: 'next tablename.columnname'. - Or go to a column within this table: 'next columnname'. - """ - if arg: - self.go_to(arg) - else: - self._go_next() - - def do_n(self, arg): - """ Synonym for next """ - self.do_next(arg) - - def complete_n(self, text: str, line: str, begidx: int, endidx: int): - return self.complete_next(text, line, begidx, endidx) - - def _go_next(self): - table = self.get_table() - if table is None: - self.print("No more tables") - next_gi = self.generator_index + 1 - if next_gi == len(table.new_generators): - self.next_table(self.INFO_NO_MORE_TABLES) - return - self.generator_index = next_gi - self.set_prompt() - - def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int): - parts = text.split(".", 1) - first_part = parts[0] - if 1 < len(parts): - column_name = parts[1] - table_index = self._get_table_index(first_part) - if table_index is None: - return [] - table_entry: GeneratorCmdTableEntry = self.table_entries[table_index] - return [ - f"{first_part}.{column}" - for gen in table_entry.new_generators - for column in gen.columns - if column.startswith(column_name) - ] - table_names = [ - entry.name - for entry in self.table_entries - if entry.name.startswith(first_part) - ] - if first_part in table_names: - table_names.append(f"{first_part}.") - current_table = self.get_table() - if current_table: - column_names = [ - col - for gen in current_table.new_generators - for col in gen.columns - if col.startswith(first_part) - ] - else: - column_names = [] - return table_names + column_names - - def do_previous(self, _arg): - """ Go to the previous generator """ - if self.generator_index == 0: - self.previous_table() - else: - self.generator_index -= 1 - self.set_prompt() - - def do_b(self, arg): - """ Synonym for previous """ - self.do_previous(arg) - - def _generators_valid(self) -> bool: - return self.generators_valid_columns == (self.table_index, self.get_column_names()) - - def _get_generator_proposals(self) -> list[Generator]: - if not self._generators_valid(): - self.generators = None - if self.generators is None: - columns = self.column_metadata() - gens = everything_factory().get_generators(columns, self.engine) - gens.sort(key=lambda g: g.fit(9999)) - self.generators = gens - self.generators_valid_columns = (self.table_index, self.get_column_names().copy()) - return self.generators - - def _print_privacy(self): - table = self.table_metadata() - if table is None: - return - if table_is_private(self.config, table.name): - self.print(self.PRIMARY_PRIVATE_TEXT) - return - pfks = primary_private_fks(self.config, table) - if not pfks: - self.print(self.NOT_PRIVATE_TEXT) - return - self.print(self.SECONDARY_PRIVATE_TEXT, pfks) - - def do_compare(self, arg: str): - """ - Compare the real data with some generators. - - 'compare': just look at some source data from this column. - 'compare 5 6 10': compare a sample of the source data with a sample - from generators 5, 6 and 10. You can find out which numbers - correspond to which generators using the 'propose' command. 
- """ - self._print_privacy() - args = arg.split() - limit = 20 - comparison = { - "source": [ - x[0] if len(x) == 1 else ", ".join(x) - for x in self._get_column_data(limit, to_str=str) - ] - } - gens: list[Generator] = self._get_generator_proposals() - table_name = self.table_name() - for argument in args: - if argument.isdigit(): - n = int(argument) - if 0 < n and n <= len(gens): - gen = gens[n - 1] - comparison[f"{n}. {gen.name()}"] = gen.generate_data(limit) - self._print_values_queried(table_name, n, gen) - self.print_table_by_columns(comparison) - - def do_c(self, arg): - """ Synonym for compare. """ - self.do_compare(arg) - - def _print_values_queried(self, table_name: str, n: int, gen: Generator): - """ - Print the values queried from the database for this generator. - """ - if not gen.select_aggregate_clauses() and not gen.custom_queries(): - self.print( - "{0}. {1} requires no data from the source database.", - n, - gen.name(), - ) - else: - self.print( - "{0}. {1} requires the following data from the source database:", - n, - gen.name(), - ) - self._print_select_aggregate_query(table_name, gen) - self._print_custom_queries(gen) - - def _print_custom_queries(self, gen: Generator) -> None: - """ - Print all the custom queries and all the values they get in this case. - """ - cqs = gen.custom_queries() - if not cqs: - return - cq_key2args = {} - nominal = gen.nominal_kwargs() - actual = gen.actual_kwargs() - self._get_custom_queries_from( - cq_key2args, - nominal, - actual, - ) - for cq_key, cq in cqs.items(): - self.print("{0}; providing the following values: {1}", cq["query"], cq_key2args[cq_key]) - - def _get_custom_queries_from(self, out, nominal, actual): - if type(nominal) is str: - src_stat_groups = self.SRC_STAT_RE.search(nominal) - if src_stat_groups: - cq_key = src_stat_groups.group(1) - if cq_key not in out: - out[cq_key] = [] - sub = src_stat_groups.group(3) - if sub: - actual = {sub: actual} - out[cq_key].append(actual) - elif type(nominal) is list and type(actual) is list: - for i in range(min(len(nominal), len(actual))): - self._get_custom_queries_from(out, nominal[i], actual[i]) - elif type(nominal) is dict and type(actual) is dict: - for k, v in nominal.items(): - if k in actual: - self._get_custom_queries_from(out, v, actual[k]) - - def _get_aggregate_query(self, gens: list[Generator], table_name: str) -> str | None: - clauses = [ - f'{q["clause"]} AS {n}' - for gen in gens - for n, q in or_default(gen.select_aggregate_clauses(), {}).items() - ] - if not clauses: - return None - return f"SELECT {', '.join(clauses)} FROM {table_name}" - - def _print_select_aggregate_query(self, table_name, gen: Generator) -> None: - """ - Prints the select aggregate query and all the values it gets in this case. 
- """ - sacs = gen.select_aggregate_clauses() - if not sacs: - return - kwa = gen.actual_kwargs() - vals = [] - src_stat2kwarg = { v: k for k, v in gen.nominal_kwargs().items() } - for n in sacs.keys(): - src_stat = f'SRC_STATS["auto__{table_name}"]["results"][0]["{n}"]' - if src_stat in src_stat2kwarg: - ak = src_stat2kwarg[src_stat] - if ak in kwa: - vals.append(kwa[ak]) - else: - logger.warning("actual_kwargs for %s does not report %s", gen.name(), ak) - else: - logger.warning('nominal_kwargs for %s does not have a value SRC_STATS["auto__%s"]["results"][0]["%s"]', gen.name(), table_name, n) - select_q = self._get_aggregate_query([gen], table_name) - self.print("{0}; providing the following values: {1}", select_q, vals) - - def _get_column_data(self, count: int, to_str=repr): - columns = self.get_column_names() - columns_string = ", ".join(columns) - pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) - with self.engine.connect() as connection: - result = connection.execute( - text(f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}") - ) - return [ - [to_str(x) for x in xs] - for xs in result.all() - ] - - def do_propose(self, _arg): - """ - Display a list of possible generators for this column. - - They will be listed in order of fit, the most likely matches first. - The results can be compared (against a sample of the real data in - the column and against each other) with the 'compare' command. - """ - limit = 5 - gens = self._get_generator_proposals() - sample = self._get_column_data(limit) - if sample: - rep = [ - x[0] if len(x) == 1 else ",".join(x) - for x in sample - ] - self.print(self.PROPOSE_SOURCE_SAMPLE_TEXT, "; ".join(rep)) - else: - self.print(self.PROPOSE_SOURCE_EMPTY_TEXT) - if not gens: - self.print(self.PROPOSE_NOTHING) - for index, gen in enumerate(gens): - fit = gen.fit() - if fit is None: - fit_s = "(no fit)" - elif fit < 100: - fit_s = f"(fit: {fit:.3g})" - else: - fit_s = f"(fit: {fit:.0f})" - self.print( - self.PROPOSE_GENERATOR_SAMPLE_TEXT, - index=index + 1, - name=gen.name(), - fit=fit_s, - sample="; ".join(map(repr, gen.generate_data(limit))) - ) - - def do_p(self, arg): - """ Synonym for propose """ - self.do_propose(arg) - - def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: - for gen in self._get_generator_proposals(): - if gen.name() == gen_name: - return gen - return None - - def do_set(self, arg: str): - """ - Set one of the proposals as a generator. - Takes a single integer argument. 
- """ - if arg.isdigit() and not self._generators_valid(): - self.print("Please run 'propose' before 'set '") - return - gens = self._get_generator_proposals() - if arg.isdigit(): - index = int(arg) - if index < 1: - self.print("set's integer argument must be at least 1") - return - if len(gens) < index: - self.print( - "There are currently only {0} generators proposed, please select one of them.", - len(gens), - ) - return - new_gen = gens[index - 1] - else: - new_gen = self.get_proposed_generator_by_name(arg) - if new_gen is None: - self.print("'{0}' is not an appropriate generator for this column", arg) - return - self.set_generator(new_gen) - self._go_next() - - def set_generator(self, gen: Generator): - (table, gen_info) = self.get_table_and_generator() - if table is None: - self.print("Error: no table") - return - if gen_info is None: - self.print("Error: no column") - return - gen_info.gen = gen - - def do_s(self, arg): - """ Synonym for set """ - self.do_set(arg) - - def do_unset(self, _arg): - """ - Remove any generator set for this column. - """ - self.set_generator(None) - self._go_next() - - def merge_columns(self, arg: str) -> bool: - """ - Add this column(s) to the specified column(s), so one generator covers them all. - :return: True if everything worked, False if there is an error - """ - cols = arg.split() - if not cols: - self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry = self.get_table() - if table_entry is None: - self.print(self.ERROR_NO_SUCH_TABLE) - return False - cols_available = functools.reduce(lambda x, y: x | y, [ - frozenset(gen.columns) - for gen in table_entry.new_generators - ]) - cols_to_merge = frozenset(cols) - unknown_cols = cols_to_merge - cols_available - if unknown_cols: - for uc in unknown_cols: - self.print(self.ERROR_NO_SUCH_COLUMN, uc) - return False - gen_info = table_entry.new_generators[self.generator_index] - current_columns = frozenset(gen_info.columns) - stated_current_columns = cols_to_merge & current_columns - if stated_current_columns: - for c in stated_current_columns: - self.print(self.ERROR_COLUMN_ALREADY_MERGED, c) - return False - # Remove cols_to_merge from each generator - new_new_generators: list[GeneratorInfo] = [] - for gen in table_entry.new_generators: - if gen is gen_info: - # Add columns to this generator - self.generator_index = len(new_new_generators) - new_new_generators.append( - GeneratorInfo( - columns=gen.columns + cols, - gen=None, - ) - ) - else: - # Remove columns if applicable - new_columns = [c for c in gen.columns if c not in cols_to_merge] - is_changed = len(new_columns) != len(gen.columns) - if new_columns: - # We have not removed this generator completely - new_new_generators.append( - GeneratorInfo( - columns=new_columns, - gen=None if is_changed else gen.gen, - ) - ) - table_entry.new_generators = new_new_generators - self.set_prompt() - return True - - def do_merge(self, arg: str): - """ Add this column(s) to the specified column(s), so one generator covers them all. 
""" - self.merge_columns(arg) - - def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int): - last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry = self.get_table() - if table_entry is None: - return [] - return [ - column - for i, gen in enumerate(table_entry.new_generators) - if i != self.generator_index - for column in gen.columns - if column.startswith(last_arg) - ] - - def do_unmerge(self, arg: str): - """ Remove this column(s) from this generator, make them a separate generator. """ - cols = arg.split() - if not cols: - self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry = self.get_table() - if table_entry is None: - self.print(self.ERROR_NO_SUCH_TABLE) - return - gen_info = table_entry.new_generators[self.generator_index] - current_columns = frozenset(gen_info.columns) - cols_to_unmerge = frozenset(cols) - unknown_cols = cols_to_unmerge - current_columns - if unknown_cols: - for uc in unknown_cols: - self.print(self.ERROR_NO_SUCH_COLUMN, uc) - return - stated_unmerged_columns = cols_to_unmerge - current_columns - if stated_unmerged_columns: - for c in stated_unmerged_columns: - self.print(self.ERROR_COLUMN_ALREADY_UNMERGED, c) - return - if cols_to_unmerge == current_columns: - self.print(self.ERROR_CANNOT_UNMERGE_ALL) - return - # Remove unmerged columns - for um in cols_to_unmerge: - gen_info.columns.remove(um) - # The existing generator will not work - gen_info.gen = None - # And put them into a new (empty) generator - table_entry.new_generators.insert(self.generator_index + 1, GeneratorInfo( - columns=cols, - gen=None, - )) - self.set_prompt() - - def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int): - last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry = self.get_table() - if table_entry is None: - return [] - return [ - column - for column in table_entry.new_generators[self.generator_index].columns - if column.startswith(last_arg) - ] - - def get_current_columns(self) -> set[str]: - table_entry: GeneratorCmdTableEntry = self.get_table() - gen_info = table_entry.new_generators[self.generator_index] - return set(gen_info.columns) - - def set_merged_columns(self, first_col: str, other_cols: str) -> bool: - """ - Merge columns, after unmerging everything we don't want - :param first_col: The first column we want in the merge, must already - be in this column set. - :param other_cols: all the columns we want merged other than - first_col, in order, space-separated. 
-        :return: True if the merge worked, false if there was an error
-        """
-        existing = self.get_current_columns()
-        existing.discard(first_col)
-        for to_remove in existing:
-            self.do_unmerge(to_remove)
-        return self.merge_columns(other_cols)
-
-
-def try_setting_generator(gc: GeneratorCmd, gens: Iterable[str]) -> bool:
-    for gen in gens:
-        new_gen = gc.get_proposed_generator_by_name(gen)
-        if new_gen is not None:
-            gc.set_generator(new_gen)
-            return True
-    return False
-
-
-def update_config_generators(
-    src_dsn: str,
-    src_schema: str,
-    metadata: MetaData,
-    config: Mapping,
-    spec_path: Path | None,
-):
-    with GeneratorCmd(src_dsn, src_schema, metadata, config) as gc:
-        if spec_path is None:
-            gc.cmdloop()
-            return gc.config
-        spec = spec_path.open()
-        line_no = 0
-        for line in csv.reader(spec):
-            line_no += 1
-            if line:
-                if len(line) < 3:
-                    logger.error("line {0} of file {1} has fewer than three values", line_no, spec_path)
-                cols = line[1].split(maxsplit=1)
-                if gc.go_to(f"{line[0]}.{cols[0]}"):
-                    if len(cols) == 1 or gc.set_merged_columns(cols[0], cols[1]):
-                        try_setting_generator(gc, itertools.islice(line, 2, None))
-                else:
-                    logger.warning("no such column {0}[{1}]", line[0], line[1])
-        gc.do_quit("yes")
-        return gc.config
diff --git a/datafaker/interactive/__init__.py b/datafaker/interactive/__init__.py
new file mode 100644
index 00000000..c279720f
--- /dev/null
+++ b/datafaker/interactive/__init__.py
@@ -0,0 +1,100 @@
+"""Interactive configuration commands."""
+import csv
+import itertools
+from collections.abc import Mapping, MutableMapping
+from pathlib import Path
+from typing import Any
+
+from sqlalchemy import MetaData
+
+from datafaker.interactive.generators import GeneratorCmd, try_setting_generator
+from datafaker.interactive.missingness import MissingnessCmd
+from datafaker.interactive.table import TableCmd
+from datafaker.utils import logger
+
+# Monkey patch pyreadline3 v3.5 so that it works with Python 3.13.
+# Windows users can install pyreadline3 to get tab completion working.
+# See https://github.com/pyreadline3/pyreadline3/issues/37
+try:
+    import readline
+
+    if not hasattr(readline, "backend"):
+        setattr(readline, "backend", "readline")
+except ImportError:
+    pass
+
+
+def update_config_tables(
+    src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping
+) -> Mapping[str, Any]:
+    """Ask the user to specify what should happen to each table."""
+    with TableCmd(src_dsn, src_schema, metadata, config) as tc:
+        tc.cmdloop()
+        return tc.config
+
+
+def update_missingness(
+    src_dsn: str,
+    src_schema: str | None,
+    metadata: MetaData,
+    config: MutableMapping[str, Any],
+) -> Mapping[str, Any]:
+    """
+    Ask the user to update the missingness information in ``config.yaml``.
+
+    :param src_dsn: The connection string for the source database.
+    :param src_schema: The name of the source database schema (or None
+        for the default).
+    :param metadata: The SQLAlchemy metadata object from ``orm.yaml``.
+    :param config: The starting configuration.
+    :return: The updated configuration.
+    """
+    with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc:
+        mc.cmdloop()
+        return mc.config
+
+
+def update_config_generators(
+    src_dsn: str,
+    src_schema: str | None,
+    metadata: MetaData,
+    config: MutableMapping[str, Any],
+    spec_path: Path | None,
+) -> Mapping[str, Any]:
+    """
+    Update the generator configuration, interactively or from a CSV specification.
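+
+    For example (illustrative table, column and generator names), the line
+    ``patients,date_of_birth,dist_gen.normal,dist_gen.uniform`` sets the
+    first of those generators that is actually proposed for that column.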
+
+    The specification is a headerless CSV file with columns: table name,
+    column name (or space-separated list of column names), first-choice
+    generator name, second-choice generator name, third-choice generator
+    name, and so on.
+    :param src_dsn: Address of the source database
+    :param src_schema: Name of the source database schema to read from
+    :param metadata: SQLAlchemy representation of the source database
+    :param config: Existing configuration (will be destructively updated)
+    :param spec_path: The path of the CSV file containing the specification
+    :return: Updated configuration.
+    """
+    with GeneratorCmd(src_dsn, src_schema, metadata, config) as gc:
+        if spec_path is None:
+            gc.cmdloop()
+            return gc.config
+        with spec_path.open(newline="") as spec:
+            line_no = 0
+            for line in csv.reader(spec):
+                line_no += 1
+                if not line:
+                    continue
+                if len(line) < 3:
+                    logger.error(
+                        "line %d of file %s has fewer than three values",
+                        line_no,
+                        spec_path,
+                    )
+                    continue
+                cols = line[1].split(maxsplit=1)
+                if gc.go_to(f"{line[0]}.{cols[0]}"):
+                    if len(cols) == 1 or gc.set_merged_columns(cols[0], cols[1]):
+                        try_setting_generator(gc, itertools.islice(line, 2, None))
+                else:
+                    logger.warning("no such column %s[%s]", line[0], line[1])
+        gc.do_quit("yes")
+        return gc.config
diff --git a/datafaker/interactive/base.py b/datafaker/interactive/base.py
new file mode 100644
index 00000000..9d612a7c
--- /dev/null
+++ b/datafaker/interactive/base.py
@@ -0,0 +1,410 @@
+"""Base configuration command shells."""
+import cmd
+from abc import ABC, abstractmethod
+from collections.abc import Mapping, MutableMapping, Sequence
+from dataclasses import dataclass
+from enum import Enum
+from types import TracebackType
+from typing import Any, Optional, Type
+
+import sqlalchemy
+from prettytable import PrettyTable
+from sqlalchemy import Engine, ForeignKey, MetaData, Table
+from typing_extensions import Self
+
+from datafaker.utils import (
+    T,
+    create_db_engine,
+    fk_refers_to_ignored_table,
+    get_sync_engine,
+)
+
+
+def or_default(v: T | None, d: T) -> T:
+    """Return v if it isn't None, otherwise d."""
+    return d if v is None else v
+
+
+class TableType(Enum):
+    """Types of table to be configured."""
+
+    GENERATE = "generate"
+    IGNORE = "ignore"
+    VOCABULARY = "vocabulary"
+    PRIVATE = "private"
+    EMPTY = "empty"
+
+
+TYPE_LETTER = {
+    TableType.GENERATE: "G",
+    TableType.IGNORE: "I",
+    TableType.VOCABULARY: "V",
+    TableType.PRIVATE: "P",
+    TableType.EMPTY: "e",
+}
+
+TYPE_PROMPT = {
+    TableType.GENERATE: "(table: {}) ",
+    TableType.IGNORE: "(table: {} (ignore)) ",
+    TableType.VOCABULARY: "(table: {} (vocab)) ",
+    TableType.PRIVATE: "(table: {} (private)) ",
+    TableType.EMPTY: "(table: {} (empty)) ",
+}
+
+
+@dataclass
+class TableEntry:
+    """Base class for table entries for interactive commands."""
+
+    name: str  # name of the table
+
+
+class AskSaveCmd(cmd.Cmd):
+    """Interactive shell for whether to save and quit."""
+
+    intro = "Do you want to save this configuration?"
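+    # The do_* handlers below record the user's reply in self.result
+    # ("yes", "no" or "cancel") and return True to close the shell.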
+    prompt = "(yes/no/cancel) "
+    file = None
+
+    def __init__(self) -> None:
+        """Initialise a save command."""
+        super().__init__()
+        self.result = ""
+
+    def do_yes(self, _arg: str) -> bool:
+        """Save the new config.yaml."""
+        self.result = "yes"
+        return True
+
+    def do_no(self, _arg: str) -> bool:
+        """Exit without saving."""
+        self.result = "no"
+        return True
+
+    def do_cancel(self, _arg: str) -> bool:
+        """Do not exit."""
+        self.result = "cancel"
+        return True
+
+
+def fk_column_name(fk: ForeignKey) -> str:
+    """Display name for a foreign key."""
+    if fk_refers_to_ignored_table(fk):
+        return f"{fk.target_fullname} (ignored)"
+    return str(fk.target_fullname)
+
+
+class DbCmd(ABC, cmd.Cmd):
+    """Base class for interactive configuration commands."""
+
+    INFO_NO_MORE_TABLES = "There are no more tables"
+    ERROR_ALREADY_AT_START = "Error: Already at the start"
+    ERROR_NO_SUCH_TABLE = "Error: '{0}' is not the name of a table in this database"
+    ERROR_NO_SUCH_TABLE_OR_COLUMN = (
+        "Error: '{0}' is not the name of a table"
+        " in this database or a column in this table"
+    )
+    ROW_COUNT_MSG = "Total row count: {}"
+
+    @abstractmethod
+    def make_table_entry(
+        self, table_name: str, table_config: Mapping
+    ) -> TableEntry | None:
+        """
+        Make a table entry suitable for this interactive command.
+
+        :param table_name: The name of the table to make an entry for.
+        :param table_config: The part of the ``config.yaml`` referring to this table.
+        :return: The table entry, or None if this table should not be interacted with.
+        """
+
+    def __init__(
+        self,
+        src_dsn: str,
+        src_schema: str | None,
+        metadata: MetaData,
+        config: MutableMapping[str, Any],
+    ):
+        """Initialise a DbCmd."""
+        super().__init__()
+        self.config: MutableMapping[str, Any] = config
+        self.metadata = metadata
+        self._table_entries: list[TableEntry] = []
+        tables_config: MutableMapping = config.get("tables", {})
+        if not isinstance(tables_config, MutableMapping):
+            tables_config = {}
+        for name in metadata.tables.keys():
+            table_config = tables_config.get(name, {})
+            if not isinstance(table_config, MutableMapping):
+                table_config = {}
+            entry = self.make_table_entry(name, table_config)
+            if entry is not None:
+                self._table_entries.append(entry)
+        self.table_index = 0
+        self.engine = create_db_engine(src_dsn, schema_name=src_schema)
+
+    @property
+    def sync_engine(self) -> Engine:
+        """Get the synchronous version of the engine."""
+        return get_sync_engine(self.engine)
+
+    def __enter__(self) -> Self:
+        """Enter a ``with`` statement."""
+        return self
+
+    def __exit__(
+        self,
+        _exc_type: Optional[Type[BaseException]],
+        _exc_val: Optional[BaseException],
+        _exc_tb: Optional[TracebackType],
+    ) -> None:
+        """Dispose of this object."""
+        self.engine.dispose()
+
+    def print(self, text: str, *args: Any, **kwargs: Any) -> None:
+        """Print text, formatted with positional and keyword arguments."""
+        print(text.format(*args, **kwargs))
+
+    def print_table(
+        self, headings: Sequence[str], rows: Sequence[Sequence[Any]]
+    ) -> None:
+        """
+        Print a table.
+
+        :param headings: List of headings for the table.
+        :param rows: List of rows of values.
+        """
+        output = PrettyTable()
+        output.field_names = headings
+        for row in rows:
+            # Hopefully PrettyTable will accept a Sequence in the future;
+            # for now it requires a list.
+            output.add_row(list(row))
+        print(output)
+
+    def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None:
+        """
+        Print a table, column by column.
+
+        :param columns: Dict of column names to the values in the column.
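+
+        For example (illustrative data), ``{"name": ["a", "b"], "age": [1]}``
+        renders a two-row table, padding the shorter ``age`` column with None.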
+ """ + output = PrettyTable() + row_count = max(len(col) for col in columns.values()) + for field_name, data in columns.items(): + output.add_column(field_name, list(data) + [None] * (row_count - len(data))) + print(output) + + def print_results(self, result: sqlalchemy.CursorResult) -> None: + """Print the rows resulting from a database query.""" + self.print_table(list(result.keys()), [list(row) for row in result.all()]) + + def ask_save(self) -> str: + """ + Ask the user if they want to save. + + :return: ``yes``, ``no`` or ``cancel``. + """ + ask = AskSaveCmd() + ask.cmdloop() + return ask.result + + @abstractmethod + def set_prompt(self) -> None: + """Set the prompt according to the current state.""" + + def _set_table_index(self, index: int) -> bool: + """ + Move to a different table. + + :param index: Index of the table to move to. + :return: True if there is a table with such an index to move to. + """ + if 0 <= index < len(self._table_entries): + self.table_index = index + self.set_prompt() + return True + return False + + def next_table(self, report: str = "No more tables") -> bool: + """ + Move to the next table. + + :param report: The text to print if there is no next table. + :return: True if there is another table to move to. + """ + if not self._set_table_index(self.table_index + 1): + self.print(report) + return False + return True + + def table_name(self) -> str: + """Get the name of the current table.""" + return str(self._table_entries[self.table_index].name) + + def table_metadata(self) -> Table: + """Get the metadata of the current table.""" + return self.metadata.tables[self.table_name()] + + def _get_column_names(self) -> list[str]: + """Get the names of the current columns.""" + return [col.name for col in self.table_metadata().columns] + + def report_columns(self) -> None: + """Print information about the current columns.""" + self.print_table( + ["name", "type", "primary", "nullable", "foreign key"], + [ + [ + name, + str(col.type), + col.primary_key, + col.nullable, + ", ".join([fk_column_name(fk) for fk in col.foreign_keys]), + ] + for name, col in self.table_metadata().columns.items() + ], + ) + + def get_table_config(self, table_name: str) -> MutableMapping[str, Any]: + """Get the configuration of the named table.""" + ts = self.config.get("tables", None) + if not isinstance(ts, MutableMapping): + return {} + t = ts.get(table_name) + return t if isinstance(t, MutableMapping) else {} + + def set_table_config( + self, table_name: str, config: MutableMapping[str, Any] + ) -> None: + """Set the configuration of the named table.""" + ts = self.config.get("tables", None) + if not isinstance(ts, MutableMapping): + self.config["tables"] = {table_name: config} + return + ts[table_name] = config + + def _remove_prefix_src_stats(self, prefix: str) -> list[MutableMapping[str, Any]]: + """Remove all source stats with the given prefix from the configuration.""" + src_stats = self.config.get("src-stats", []) + new_src_stats = [] + for stat in src_stats: + if not stat.get("name", "").startswith(prefix): + new_src_stats.append(stat) + self.config["src-stats"] = new_src_stats + return new_src_stats + + def get_nonnull_columns(self, table_name: str) -> list[str]: + """Get the names of the nullable columns in the named table.""" + metadata_table = self.metadata.tables[table_name] + return [ + str(name) + for name, column in metadata_table.columns.items() + if column.nullable + ] + + def find_entry_index_by_table_name(self, table_name: str) -> int | None: + """Get the index of 
the table entry of the named table."""
+        return next(
+            (
+                i
+                for i, entry in enumerate(self._table_entries)
+                if entry.name == table_name
+            ),
+            None,
+        )
+
+    def _find_entry_by_table_name(self, table_name: str) -> TableEntry | None:
+        """Get the table entry of the named table."""
+        for e in self._table_entries:
+            if e.name == table_name:
+                return e
+        return None
+
+    def do_counts(self, _arg: str) -> None:
+        """Report the column names with the counts of nulls in them."""
+        if len(self._table_entries) <= self.table_index:
+            return
+        table_name = self.table_name()
+        nonnull_columns = self.get_nonnull_columns(table_name)
+        colcounts = [f", COUNT({nnc}) AS {nnc}" for nnc in nonnull_columns]
+        with self.sync_engine.connect() as connection:
+            result = (
+                connection.execute(
+                    sqlalchemy.text(
+                        f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}"
+                    )
+                )
+                .mappings()
+                .first()
+            )
+        if result is None:
+            self.print("Could not count rows in table {0}", table_name)
+            return
+        row_count = result.get("row_count", 0)
+        self.print(self.ROW_COUNT_MSG, row_count)
+        self.print_table(
+            ["Column", "NULL count"],
+            [
+                [name, row_count - count]
+                for name, count in result.items()
+                if name != "row_count"
+            ],
+        )
+
+    def do_select(self, arg: str) -> None:
+        """Run a select query over the database and show the first 50 results."""
+        max_select_rows = 50
+        with self.sync_engine.connect() as connection:
+            try:
+                result = connection.execute(sqlalchemy.text("SELECT " + arg))
+            except sqlalchemy.exc.DatabaseError as exc:
+                self.print("Failed to execute: {}", exc)
+                return
+            row_count = result.rowcount
+            self.print(self.ROW_COUNT_MSG, row_count)
+            if max_select_rows < row_count:
+                self.print("Showing the first {} rows", max_select_rows)
+            fields = list(result.keys())
+            rows = result.fetchmany(max_select_rows)
+        self.print_table(fields, rows)
+
+    def do_peek(self, arg: str) -> None:
+        """
+        View some data from the current table.
+
+        Use 'peek col1 col2 col3' to see a sample of values from
+        columns col1, col2 and col3 in the current table.
+        Use 'peek' to see a sample of the current column(s).
+        Rows that are entirely null are suppressed.
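+
+        The underlying query is roughly 'SELECT cols FROM table WHERE
+        col1 IS NOT NULL OR col2 IS NOT NULL ... ORDER BY RANDOM() LIMIT 25'.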
+        """
+        max_peek_rows = 25
+        if len(self._table_entries) <= self.table_index:
+            return
+        table_name = self.table_name()
+        col_names = arg.split()
+        if not col_names:
+            col_names = self._get_column_names()
+        nonnulls = [cn + " IS NOT NULL" for cn in col_names]
+        with self.sync_engine.connect() as connection:
+            cols = ",".join(col_names)
+            where = "WHERE" if nonnulls else ""
+            nonnull = " OR ".join(nonnulls)
+            query = sqlalchemy.text(
+                f"SELECT {cols} FROM {table_name} {where} {nonnull}"
+                f" ORDER BY RANDOM() LIMIT {max_peek_rows}"
+            )
+            try:
+                result = connection.execute(query)
+            except sqlalchemy.exc.SQLAlchemyError as exc:
+                self.print('SQL query "{0}" caused exception {1}', query, exc)
+                return
+            self.print_table(list(result.keys()), result.fetchmany(max_peek_rows))
+
+    def complete_peek(
+        self, text: str, _line: str, _begidx: int, _endidx: int
+    ) -> list[str]:
+        """Completions for the ``peek`` command."""
+        if len(self._table_entries) <= self.table_index:
+            return []
+        return [
+            col for col in self.table_metadata().columns.keys() if col.startswith(text)
+        ]
diff --git a/datafaker/interactive/generators.py b/datafaker/interactive/generators.py
new file mode 100644
index 00000000..24ccb529
--- /dev/null
+++ b/datafaker/interactive/generators.py
@@ -0,0 +1,1024 @@
+"""Generator configuration shell."""  # pylint: disable=too-many-lines
+import functools
+import re
+from collections.abc import Iterable, Mapping, MutableMapping, Sequence
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, cast
+
+import sqlalchemy
+from sqlalchemy import Column, MetaData
+
+from datafaker.generators import everything_factory
+from datafaker.generators.base import Generator, PredefinedGenerator
+from datafaker.interactive.base import DbCmd, TableEntry, fk_column_name, or_default
+from datafaker.utils import (
+    get_columns_assigned,
+    get_row_generators,
+    logger,
+    primary_private_fks,
+    table_is_private,
+)
+
+
+@dataclass
+class GeneratorInfo:
+    """A generator and the columns it assigns to."""
+
+    columns: list[str]
+    gen: Generator | None
+
+
+@dataclass
+class GeneratorCmdTableEntry(TableEntry):
+    """
+    List of generators set for a table.
+
+    Includes the original setting and the currently configured
+    generators.
+    """
+
+    old_generators: list[GeneratorInfo]
+    new_generators: list[GeneratorInfo]
+
+
+# pylint: disable=too-many-public-methods
+class GeneratorCmd(DbCmd):
+    """Interactive command shell for setting generators."""
+
+    intro = "Interactive generator configuration. Type ? for help.\n"
+    doc_leader = """Use command 'propose' for a list of generators applicable to the
+current column, then command 'compare' to see how these perform
+against the source data, then command 'set' to choose your favourite.
+Use 'unset' to remove the column's generator. Use commands 'next' and
+'previous' to change which column we are examining. Use 'info'
+for useful information about the current column. Use 'tables' and
+'list' to see available tables and columns. Use 'columns' to see
+information about the columns in the current table. Use 'peek',
+'counts' or 'select' to fetch data from the source database. Use
+'quit' to exit this program."""
+    prompt = "(generatorconf) "
+    file = None
+
+    PROPOSE_SOURCE_SAMPLE_TEXT = "Sample of actual source data: {0}..."
+    PROPOSE_SOURCE_EMPTY_TEXT = "Source database has no data in this column."
+    PROPOSE_GENERATOR_SAMPLE_TEXT = "{index}. {name}: {fit} {sample} ..."
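+    # For illustration, PROPOSE_GENERATOR_SAMPLE_TEXT renders like
+    # "2. dist_gen.normal: (fit: 12.3) 0.5; 1.7; 0.9 ..." (hypothetical
+    # generator name and sample values).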
+    PRIMARY_PRIVATE_TEXT = "Primary Private"
+    SECONDARY_PRIVATE_TEXT = "Secondary Private on columns {0}"
+    NOT_PRIVATE_TEXT = "Not private"
+    ERROR_NO_SUCH_TABLE = "No such (non-vocabulary, non-ignored) table name {0}"
+    ERROR_NO_SUCH_COLUMN = "No such column {0} in this table"
+    ERROR_COLUMN_ALREADY_MERGED = "Column {0} is already merged"
+    ERROR_COLUMN_ALREADY_UNMERGED = "Column {0} is not merged"
+    ERROR_CANNOT_UNMERGE_ALL = "You cannot unmerge all the generator's columns"
+    PROPOSE_NOTHING = "No proposed generators, sorry."
+
+    SRC_STAT_RE = re.compile(
+        r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?'
+    )
+
+    def make_table_entry(
+        self, table_name: str, table_config: Mapping
+    ) -> GeneratorCmdTableEntry | None:
+        """
+        Make a table entry.
+
+        :param table_name: The name of the table.
+        :param table_config: The portion of the ``config.yaml`` file describing this table.
+        :return: The newly constructed table entry, or None if this table is to be ignored.
+        """
+        if table_config.get("ignore", False):
+            return None
+        if table_config.get("vocabulary_table", False):
+            return None
+        if table_config.get("num_rows_per_pass", 1) == 0:
+            return None
+        columns = [
+            str(colname) for colname in self.metadata.tables[table_name].columns.keys()
+        ]
+        column_set = frozenset(columns)
+        columns_assigned_so_far: set[str] = set()
+
+        new_generator_infos: list[GeneratorInfo] = []
+        for gen_name, rg in get_row_generators(table_config):
+            colset: set[str] = set(get_columns_assigned(rg))
+            for unknown in colset - column_set:
+                logger.warning(
+                    "table '%s' has '%s' assigned to column '%s' which is not in this table",
+                    table_name,
+                    gen_name,
+                    unknown,
+                )
+            for mult in columns_assigned_so_far & colset:
+                logger.warning(
+                    "table '%s' has column '%s' assigned more than once",
+                    table_name,
+                    mult,
+                )
+            actual_collist = [c for c in columns if c in colset]
+            if actual_collist:
+                new_generator_infos.append(
+                    GeneratorInfo(
+                        columns=actual_collist.copy(),
+                        gen=PredefinedGenerator(table_name, rg, self.config),
+                    )
+                )
+            columns_assigned_so_far |= colset
+        old_generator_infos = [
+            GeneratorInfo(columns=gi.columns.copy(), gen=gi.gen)
+            for gi in new_generator_infos
+        ]
+        for colname in columns:
+            if colname not in columns_assigned_so_far:
+                new_generator_infos.append(
+                    GeneratorInfo(
+                        columns=[colname],
+                        gen=None,
+                    )
+                )
+        if len(new_generator_infos) == 0:
+            return None
+
+        return GeneratorCmdTableEntry(
+            name=table_name,
+            old_generators=old_generator_infos,
+            new_generators=new_generator_infos,
+        )
+
+    def __init__(
+        self,
+        src_dsn: str,
+        src_schema: str | None,
+        metadata: MetaData,
+        config: MutableMapping[str, Any],
+    ) -> None:
+        """
+        Initialise a ``GeneratorCmd``.
+
+        :param src_dsn: connection address for source database
+        :param src_schema: database schema name
+        :param metadata: SQLAlchemy metadata for the source database
+        :param config: Configuration loaded from ``config.yaml``
+        """
+        super().__init__(src_dsn, src_schema, metadata, config)
+        self.generators: list[Generator] | None = None
+        self.generator_index = 0
+        self.generators_valid_columns: Optional[tuple[int, list[str]]] = None
+        self.set_prompt()
+
+    @property
+    def table_entries(self) -> list[GeneratorCmdTableEntry]:
+        """Get the table entries, cast to ``GeneratorCmdTableEntry``."""
+        return cast(list[GeneratorCmdTableEntry], self._table_entries)
+
+    def _find_entry_by_table_name(
+        self, table_name: str
+    ) -> GeneratorCmdTableEntry | None:
+        """
+        Find the table entry by name.
+
+        :param table_name: The name of the table to find.
+        :return: The table entry, or None if no such table name exists.
+        """
+        entry = super()._find_entry_by_table_name(table_name)
+        if entry is None:
+            return None
+        return cast(GeneratorCmdTableEntry, entry)
+
+    def _set_table_index(self, index: int) -> bool:
+        """
+        Move to a new table.
+
+        :param index: table index to move to.
+        """
+        ret = super()._set_table_index(index)
+        if ret:
+            self.generator_index = 0
+            self.set_prompt()
+        return ret
+
+    def _previous_table(self) -> bool:
+        """
+        Move to the table before the current one.
+
+        :return: True if there is a previous table to go to.
+        """
+        ret = self._set_table_index(self.table_index - 1)
+        if ret:
+            table = self.get_table()
+            if table is None:
+                self.print(
+                    "Internal error! table {0} does not have any generators!",
+                    self.table_index,
+                )
+                return False
+            self.generator_index = len(table.new_generators) - 1
+        else:
+            self.print(self.ERROR_ALREADY_AT_START)
+        return ret
+
+    def get_table(self) -> GeneratorCmdTableEntry | None:
+        """Get the current table entry."""
+        if self.table_index < len(self.table_entries):
+            return self.table_entries[self.table_index]
+        return None
+
+    def _get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]:
+        """Get a pair: the table name, then the generator information."""
+        if self.table_index < len(self.table_entries):
+            entry = self.table_entries[self.table_index]
+            if self.generator_index < len(entry.new_generators):
+                return (entry.name, entry.new_generators[self.generator_index])
+            return (entry.name, None)
+        return (None, None)
+
+    def _get_column_names(self) -> list[str]:
+        """Get the (unqualified) names for all the current columns."""
+        (_, generator_info) = self._get_table_and_generator()
+        return generator_info.columns if generator_info else []
+
+    def _column_metadata(self) -> list[Column]:
+        """Get the metadata for all the current columns."""
+        table = self.table_metadata()
+        if table is None:
+            return []
+        return [table.columns[name] for name in self._get_column_names()]
+
+    def set_prompt(self) -> None:
+        """Set the prompt according to the current table, column and generator."""
+        (table_name, gen_info) = self._get_table_and_generator()
+        if table_name is None:
+            self.prompt = "(generators) "
+            return
+        if gen_info is None:
+            self.prompt = f"({table_name}) "
+            return
+        table = self.table_metadata()
+        columns = [
+            c + "[pk]" if table.columns[c].primary_key else c for c in gen_info.columns
+        ]
+        gen = f" ({gen_info.gen.name()})" if gen_info.gen else ""
+        self.prompt = f"({table_name}.{','.join(columns)}{gen}) "
+
+    def _remove_auto_src_stats(self) -> list[MutableMapping[str, Any]]:
+        """
+        Remove all automatic source stats.
+
+        We assume that every source stats query whose name begins with
+        ``auto__`` was generated automatically.
+        :return: The new ``src_stats`` configuration.
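+
+        For example (illustrative names), a stat called ``auto__patients``
+        would be removed, while a hand-written one called ``my_stats`` is kept.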
+        """
+        return self._remove_prefix_src_stats("auto__")
+
+    def _copy_entries(self) -> None:
+        """Set generator and query information in the configuration."""
+        src_stats = self._remove_auto_src_stats()
+        for entry in self.table_entries:
+            rgs = []
+            new_gens: list[Generator] = []
+            for generator in entry.new_generators:
+                if generator.gen is not None:
+                    new_gens.append(generator.gen)
+                    cqs = generator.gen.custom_queries()
+                    for cq_key, cq in cqs.items():
+                        src_stats.append(
+                            {
+                                "name": cq_key,
+                                "query": cq["query"],
+                                "comments": [cq["comment"]]
+                                if "comment" in cq and cq["comment"]
+                                else [],
+                            }
+                        )
+                    rg: dict[str, Any] = {
+                        "name": generator.gen.function_name(),
+                        "columns_assigned": generator.columns,
+                    }
+                    kwn = generator.gen.nominal_kwargs()
+                    if kwn:
+                        rg["kwargs"] = kwn
+                    rgs.append(rg)
+            aq = self._get_aggregate_query(new_gens, entry.name)
+            if aq:
+                src_stats.append(
+                    {
+                        "name": f"auto__{entry.name}",
+                        "query": aq,
+                        "comments": [
+                            q["comment"]
+                            for gen in new_gens
+                            for q in gen.select_aggregate_clauses().values()
+                            if "comment" in q and q["comment"] is not None
+                        ],
+                    }
+                )
+            table_config = self.get_table_config(entry.name)
+            if rgs:
+                table_config["row_generators"] = rgs
+            elif "row_generators" in table_config:
+                del table_config["row_generators"]
+            self.set_table_config(entry.name, table_config)
+        self.config["src-stats"] = src_stats
+
+    def _find_old_generator(
+        self, entry: GeneratorCmdTableEntry, columns: Iterable[str]
+    ) -> Generator | None:
+        """Find any generator that was previously assigned to these exact columns."""
+        fc = frozenset(columns)
+        for gen in entry.old_generators:
+            if frozenset(gen.columns) == fc:
+                return gen.gen
+        return None
+
+    def do_quit(self, arg: str) -> bool:
+        """Check the updates, save them if desired and quit the configurer."""
+        count = 0
+        for entry in self.table_entries:
+            header_shown = False
+            g_entry = cast(GeneratorCmdTableEntry, entry)
+            for gen in g_entry.new_generators:
+                old_gen = self._find_old_generator(g_entry, gen.columns)
+                new_gen = gen.gen
+                if old_gen != new_gen:
+                    if not header_shown:
+                        header_shown = True
+                        self.print("Table {0}:", entry.name)
+                    count += 1
+                    self.print(
+                        "...changing {0} from {1} to {2}",
+                        ", ".join(gen.columns),
+                        old_gen.name() if old_gen else "nothing",
+                        gen.gen.name() if gen.gen else "nothing",
+                    )
+        if count == 0:
+            self.print("You have made no changes.")
+        if arg in {"yes", "no"}:
+            reply = arg
+        else:
+            reply = self.ask_save()
+        if reply == "yes":
+            self._copy_entries()
+            return True
+        if reply == "no":
+            return True
+        return False
+
+    def do_tables(self, _arg: str) -> None:
+        """List the tables."""
+        for t_entry in self.table_entries:
+            entry = cast(GeneratorCmdTableEntry, t_entry)
+            gen_count = len(entry.new_generators)
+            how_many = "one generator" if gen_count == 1 else f"{gen_count} generators"
+            self.print("{0} ({1})", entry.name, how_many)
+
+    def do_list(self, _arg: str) -> None:
+        """List the generators in the current table."""
+        if len(self.table_entries) <= self.table_index:
+            self.print("Error: no table {0}", self.table_index)
+            return
+        g_entry = cast(GeneratorCmdTableEntry, self.table_entries[self.table_index])
+        table = self.table_metadata()
+        for gen in g_entry.new_generators:
+            old_gen = self._find_old_generator(g_entry, gen.columns)
+            old = "" if old_gen is None else old_gen.name()
+            if old_gen == gen.gen:
+                becomes = ""
+                if old == "":
+                    old = "(not set)"
+            elif gen.gen is None:
+                becomes = "(delete)"
+            else:
+                becomes = f"->{gen.gen.name()}"
f"->{gen.gen.name()}" + primary = "" + if len(gen.columns) == 1 and table.columns[gen.columns[0]].primary_key: + primary = "[primary-key]" + self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns) + + def do_columns(self, _arg: str) -> None: + """Report the column names and metadata.""" + self.report_columns() + + def do_info(self, _arg: str) -> None: + """Show information about the current column.""" + for cm in self._column_metadata(): + self.print( + "Column {0} in table {1} has type {2} ({3}).", + cm.name, + cm.table.name, + str(cm.type), + "nullable" if cm.nullable else "not nullable", + ) + if cm.primary_key: + self.print( + "It is a primary key, which usually does not" + " need a generator (it will auto-increment)" + ) + if cm.foreign_keys: + fk_names = [fk_column_name(fk) for fk in cm.foreign_keys] + self.print( + "It is a foreign key referencing column {0}", ", ".join(fk_names) + ) + if len(fk_names) == 1 and not cm.primary_key: + self.print( + "You do not need a generator if you just want" + " a uniform choice over the referenced table's rows" + ) + + def _get_table_index(self, table_name: str) -> int | None: + """Get the index of the named table in the table entries list.""" + for n, entry in enumerate(self.table_entries): + if entry.name == table_name: + return n + return None + + def _get_generator_index(self, table_index: int, column_name: str) -> int | None: + """ + Get the index number of a column within the list of generators in this table. + + :param table_index: The index of the table in which to search. + :param column_name: The name of the column to search for. + :return: The index in the ``new_generators`` attribute of the table entry + containing the specified column, or None if this does not exist. + """ + entry = self.table_entries[table_index] + for n, gen in enumerate(entry.new_generators): + if column_name in gen.columns: + return n + return None + + def go_to(self, target: str) -> bool: + """ + Go to a particular column. + + :return: True on success. + """ + parts = target.split(".", 1) + table_index = self._get_table_index(parts[0]) + if table_index is None: + if len(parts) == 1: + gen_index = self._get_generator_index(self.table_index, parts[0]) + if gen_index is not None: + self.generator_index = gen_index + self.set_prompt() + return True + self.print(self.ERROR_NO_SUCH_TABLE_OR_COLUMN, parts[0]) + return False + gen_index = None + if 1 < len(parts) and parts[1]: + gen_index = self._get_generator_index(table_index, parts[1]) + if gen_index is None: + self.print("we cannot set the generator for column {0}", parts[1]) + return False + self._set_table_index(table_index) + if gen_index is not None: + self.generator_index = gen_index + self.set_prompt() + return True + + def do_next(self, arg: str) -> None: + """ + Go to the next generator. or a specified generator. + + Go to a named table: 'next tablename', + go to a column: 'next tablename.columnname', + or go to a column within this table: 'next columnname'. 
+ """ + if arg: + self.go_to(arg) + else: + self._go_next() + + def do_n(self, arg: str) -> None: + """Go to the next generator, or a specified generator.""" + self.do_next(arg) + + def complete_n(self, text: str, line: str, begidx: int, endidx: int) -> list[str]: + """Complete the ``n`` command's arguments.""" + return self.complete_next(text, line, begidx, endidx) + + def _go_next(self) -> None: + """Go to the next column.""" + table = self.get_table() + if table is None: + self.print("No more tables") + return + next_gi = self.generator_index + 1 + if next_gi == len(table.new_generators): + self.next_table(self.INFO_NO_MORE_TABLES) + return + self.generator_index = next_gi + self.set_prompt() + + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Completions for the arguments of the ``next`` command.""" + parts = text.split(".", 1) + first_part = parts[0] + if 1 < len(parts): + column_name = parts[1] + table_index = self._get_table_index(first_part) + if table_index is None: + return [] + table_entry = self.table_entries[table_index] + return [ + f"{first_part}.{column}" + for gen in table_entry.new_generators + for column in gen.columns + if column.startswith(column_name) + ] + table_names = [ + entry.name + for entry in self.table_entries + if entry.name.startswith(first_part) + ] + if first_part in table_names: + table_names.append(f"{first_part}.") + current_table = self.get_table() + if current_table: + column_names = [ + col + for gen in current_table.new_generators + for col in gen.columns + if col.startswith(first_part) + ] + else: + column_names = [] + return table_names + column_names + + def do_previous(self, _arg: str) -> None: + """Go to the previous generator.""" + if self.generator_index == 0: + self._previous_table() + else: + self.generator_index -= 1 + self.set_prompt() + + def do_b(self, arg: str) -> None: + """Synonym for previous.""" + self.do_previous(arg) + + def _generators_valid(self) -> bool: + """Test if ``self.generators`` is still correct for the current columns.""" + return self.generators_valid_columns == ( + self.table_index, + self._get_column_names(), + ) + + def _get_generator_proposals(self) -> list[Generator]: + """Get a list of acceptable generators, sorted by decreasing fit to the actual data.""" + if not self._generators_valid(): + self.generators = None + if self.generators is None: + columns = self._column_metadata() + gens = everything_factory().get_generators(columns, self.sync_engine) + sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) + self.generators = sorted_gens + self.generators_valid_columns = ( + self.table_index, + self._get_column_names().copy(), + ) + return self.generators + + def _print_privacy(self) -> None: + """Print the privacy status of the current table.""" + table = self.table_metadata() + if table is None: + return + if table_is_private(self.config, table.name): + self.print(self.PRIMARY_PRIVATE_TEXT) + return + pfks = primary_private_fks(self.config, table) + if not pfks: + self.print(self.NOT_PRIVATE_TEXT) + return + self.print(self.SECONDARY_PRIVATE_TEXT, pfks) + + def do_compare(self, arg: str) -> None: + """ + Compare the real data with some generators. + + 'compare': just look at some source data from this column. + 'compare 5 6 10': compare a sample of the source data with a sample + from generators 5, 6 and 10. You can find out which numbers + correspond to which generators using the 'propose' command. 
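+
+        For example, after 'propose' has listed its candidates,
+        'compare 1 2' shows a sample of the source data side by side with
+        samples from the first two proposed generators.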
+        """
+        self._print_privacy()
+        args = arg.split()
+        limit = 20
+        comparison = {
+            "source": [
+                x[0] if len(x) == 1 else ", ".join(x)
+                for x in self._get_column_data(limit, to_str=str)
+            ]
+        }
+        gens: list[Generator] = self._get_generator_proposals()
+        table_name = self.table_name()
+        for argument in args:
+            if argument.isdigit():
+                n = int(argument)
+                if 0 < n <= len(gens):
+                    gen = gens[n - 1]
+                    comparison[f"{n}. {gen.name()}"] = gen.generate_data(limit)
+                    self._print_values_queried(table_name, n, gen)
+        self.print_table_by_columns(comparison)
+
+    def do_c(self, arg: str) -> None:
+        """Synonym for compare."""
+        self.do_compare(arg)
+
+    def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None:
+        """
+        Print the values queried from the database for this generator.
+
+        :param table_name: The name of the table the generator applies to.
+        :param n: A number to print at the start of the output.
+        :param gen: The generator to report.
+        """
+        if not gen.select_aggregate_clauses() and not gen.custom_queries():
+            self.print(
+                "{0}. {1} requires no data from the source database.",
+                n,
+                gen.name(),
+            )
+        else:
+            self.print(
+                "{0}. {1} requires the following data from the source database:",
+                n,
+                gen.name(),
+            )
+            self._print_select_aggregate_query(table_name, gen)
+            self._print_custom_queries(gen)
+
+    def _print_custom_queries(self, gen: Generator) -> None:
+        """
+        Print all the custom queries and all the values they get in this case.
+
+        :param gen: The generator to print the custom queries for.
+        """
+        cqs = gen.custom_queries()
+        if not cqs:
+            return
+        cq_key2args: dict[str, Any] = {}
+        nominal = gen.nominal_kwargs()
+        actual = gen.actual_kwargs()
+        self._get_custom_queries_from(
+            cq_key2args,
+            nominal,
+            actual,
+        )
+        for cq_key, cq in cqs.items():
+            self.print(
+                "{0}; providing the following values: {1}",
+                cq["query"],
+                cq_key2args[cq_key],
+            )
+
+    def _get_custom_queries_from(
+        self, out: dict[str, Any], nominal: Any, actual: Any
+    ) -> None:
+        """
+        Recursively match ``nominal`` kwargs against ``actual`` kwargs.
+
+        Fills ``out`` with the actual values provided by each custom
+        query that ``nominal`` references through SRC_STATS.
+        """
+        if isinstance(nominal, str):
+            src_stat_groups = self.SRC_STAT_RE.search(nominal)
+            # Do we have a SRC_STAT reference?
+            if src_stat_groups:
+                # Get its name
+                cq_key = src_stat_groups.group(1)
+                # Are we pulling a specific part of this result?
+                sub = src_stat_groups.group(3)
+                if sub:
+                    # Record the value under the specific part referenced.
+                    out.setdefault(cq_key, {})[sub] = actual
+                else:
+                    out[cq_key] = actual
+        elif isinstance(nominal, Sequence) and isinstance(actual, Sequence):
+            for i in range(min(len(nominal), len(actual))):
+                self._get_custom_queries_from(out, nominal[i], actual[i])
+        elif isinstance(nominal, Mapping) and isinstance(actual, Mapping):
+            for k, v in nominal.items():
+                if k in actual:
+                    self._get_custom_queries_from(out, v, actual[k])
+
+    def _get_aggregate_query(
+        self, gens: list[Generator], table_name: str
+    ) -> str | None:
+        """Build a single SELECT for every aggregate clause the given generators need."""
+        clauses = [
+            f'{q["clause"]} AS {n}'
+            for gen in gens
+            for n, q in or_default(gen.select_aggregate_clauses(), {}).items()
+        ]
+        if not clauses:
+            return None
+        return f"SELECT {', '.join(clauses)} FROM {table_name}"
+
+    def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None:
+        """
+        Print the select aggregate query and all the values it gets in this case.
+
+        This is not the entire query that will be executed, but only the part of it
+        that is required by a certain generator.
+
+        :param table_name: The table name.
+        :param gen: The generator to limit the aggregate query to.
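+
+        The output looks something like this (hypothetical clause name,
+        column and value)::
+
+            SELECT AVG(age) AS mean__age FROM patients; providing the following values: [41.7]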
+        """
+        sacs = gen.select_aggregate_clauses()
+        if not sacs:
+            return
+        kwa = gen.actual_kwargs()
+        vals = []
+        src_stat2kwarg = {v: k for k, v in gen.nominal_kwargs().items()}
+        for n in sacs.keys():
+            src_stat = f'SRC_STATS["auto__{table_name}"]["results"][0]["{n}"]'
+            if src_stat in src_stat2kwarg:
+                ak = src_stat2kwarg[src_stat]
+                if ak in kwa:
+                    vals.append(kwa[ak])
+                else:
+                    logger.warning(
+                        "actual_kwargs for %s does not report %s", gen.name(), ak
+                    )
+            else:
+                logger.warning(
+                    (
+                        "nominal_kwargs for %s does not have a value"
+                        ' SRC_STATS["auto__%s"]["results"][0]["%s"]'
+                    ),
+                    gen.name(),
+                    table_name,
+                    n,
+                )
+        select_q = self._get_aggregate_query([gen], table_name)
+        self.print("{0}; providing the following values: {1}", select_q, vals)
+
+    def _get_column_data(
+        self, count: int, to_str: Callable[[Any], str] = repr
+    ) -> list[list[str]]:
+        """Fetch a random sample of non-NULL values from the current columns."""
+        columns = self._get_column_names()
+        columns_string = ", ".join(columns)
+        pred = " AND ".join(f"{column} IS NOT NULL" for column in columns)
+        with self.sync_engine.connect() as connection:
+            result = connection.execute(
+                sqlalchemy.text(
+                    f"SELECT {columns_string} FROM {self.table_name()}"
+                    f" WHERE {pred} ORDER BY RANDOM() LIMIT {count}"
+                )
+            )
+            return [[to_str(x) for x in xs] for xs in result.all()]
+
+    def do_propose(self, _arg: str) -> None:
+        """
+        Display a list of possible generators for this column.
+
+        They will be listed in order of fit, the most likely matches first.
+        The results can be compared (against a sample of the real data in
+        the column and against each other) with the 'compare' command.
+        """
+        limit = 5
+        gens = self._get_generator_proposals()
+        sample = self._get_column_data(limit)
+        if sample:
+            rep = [x[0] if len(x) == 1 else ",".join(x) for x in sample]
+            self.print(self.PROPOSE_SOURCE_SAMPLE_TEXT, "; ".join(rep))
+        else:
+            self.print(self.PROPOSE_SOURCE_EMPTY_TEXT)
+        if not gens:
+            self.print(self.PROPOSE_NOTHING)
+        for index, gen in enumerate(gens):
+            fit = gen.fit(-1)
+            if fit == -1:
+                fit_s = "(no fit)"
+            elif fit < 100:
+                fit_s = f"(fit: {fit:.3g})"
+            else:
+                fit_s = f"(fit: {fit:.0f})"
+            self.print(
+                self.PROPOSE_GENERATOR_SAMPLE_TEXT,
+                index=index + 1,
+                name=gen.name(),
+                fit=fit_s,
+                sample="; ".join(map(repr, gen.generate_data(limit))),
+            )
+
+    def do_p(self, arg: str) -> None:
+        """Synonym for propose."""
+        self.do_propose(arg)
+
+    def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None:
+        """Find a generator by name from the list of proposals."""
+        for gen in self._get_generator_proposals():
+            if gen.name() == gen_name:
+                return gen
+        return None
+
+    def do_set(self, arg: str) -> None:
+        """Set one of the proposals as a generator."""
+        if arg.isdigit() and not self._generators_valid():
+            self.print("Please run 'propose' before 'set <number>'")
+            return
+        gens = self._get_generator_proposals()
+        new_gen: Generator | None
+        if arg.isdigit():
+            index = int(arg)
+            if index < 1:
+                self.print("set's integer argument must be at least 1")
+                return
+            if len(gens) < index:
+                self.print(
+                    "There are currently only {0} generators proposed, please select one of them.",
+                    len(gens),
+                )
+                return
+            new_gen = gens[index - 1]
+        else:
+            new_gen = self.get_proposed_generator_by_name(arg)
+            if new_gen is None:
+                self.print("'{0}' is not an appropriate generator for this column", arg)
+                return
+        self.set_generator(new_gen)
+        self._go_next()
+
+    def set_generator(self, gen: Generator | None) -> None:
+        """Set the current column's generator."""
+        (table, gen_info) = self._get_table_and_generator()
+        if table is None:
+            self.print("Error: no table")
+            return
+        if gen_info is None:
+            self.print("Error: no column")
+            return
+        gen_info.gen = gen
+
+    def do_s(self, arg: str) -> None:
+        """Synonym for set."""
+        self.do_set(arg)
+
+    def do_unset(self, _arg: str) -> None:
+        """Remove any generator set for this column."""
+        self.set_generator(None)
+        self._go_next()
+
+    def merge_columns(self, arg: str) -> bool:
+        """
+        Merge the specified column(s) into the current column set.
+
+        After this, one generator will cover them all.
+        """
+        cols = arg.split()
+        if not cols:
+            self.print("Error: merge requires a column argument")
+            return False
+        table_entry: GeneratorCmdTableEntry | None = self.get_table()
+        if table_entry is None:
+            self.print(self.ERROR_NO_SUCH_TABLE)
+            return False
+        cols_available = functools.reduce(
+            lambda x, y: x | y,
+            [frozenset(gen.columns) for gen in table_entry.new_generators],
+        )
+        cols_to_merge = frozenset(cols)
+        unknown_cols = cols_to_merge - cols_available
+        if unknown_cols:
+            for uc in unknown_cols:
+                self.print(self.ERROR_NO_SUCH_COLUMN, uc)
+            return False
+        gen_info = table_entry.new_generators[self.generator_index]
+        stated_current_columns = cols_to_merge & frozenset(gen_info.columns)
+        if stated_current_columns:
+            for c in stated_current_columns:
+                self.print(self.ERROR_COLUMN_ALREADY_MERGED, c)
+            return False
+        # Remove cols_to_merge from each generator
+        new_new_generators: list[GeneratorInfo] = []
+        for gen in table_entry.new_generators:
+            if gen is gen_info:
+                # Add columns to this generator
+                self.generator_index = len(new_new_generators)
+                new_new_generators.append(
+                    GeneratorInfo(
+                        columns=gen.columns + cols,
+                        gen=None,
+                    )
+                )
+            else:
+                # Remove columns if applicable
+                new_columns = [c for c in gen.columns if c not in cols_to_merge]
+                is_changed = len(new_columns) != len(gen.columns)
+                if new_columns:
+                    # We have not removed this generator completely
+                    new_new_generators.append(
+                        GeneratorInfo(
+                            columns=new_columns,
+                            gen=None if is_changed else gen.gen,
+                        )
+                    )
+        table_entry.new_generators = new_new_generators
+        self.set_prompt()
+        return True
+
+    def do_merge(self, arg: str) -> None:
+        """Merge the specified column(s) into this one, so one generator covers them all."""
+        self.merge_columns(arg)
+
+    def complete_merge(
+        self, text: str, _line: str, _begidx: int, _endidx: int
+    ) -> list[str]:
+        """Complete column names."""
+        # text may be empty when completing immediately after the command
+        parts = text.split()
+        last_arg = parts[-1] if parts else ""
+        table_entry: GeneratorCmdTableEntry | None = self.get_table()
+        if table_entry is None:
+            return []
+        return [
+            column
+            for i, gen in enumerate(table_entry.new_generators)
+            if i != self.generator_index
+            for column in gen.columns
+            if column.startswith(last_arg)
+        ]
+
+    def do_unmerge(self, arg: str) -> None:
+        """Remove the specified column(s) from this generator, making them a separate generator."""
+        cols = arg.split()
+        if not cols:
+            self.print("Error: unmerge requires a column argument")
+            return
+        table_entry: GeneratorCmdTableEntry | None = self.get_table()
+        if table_entry is None:
+            self.print(self.ERROR_NO_SUCH_TABLE)
+            return
+        gen_info = table_entry.new_generators[self.generator_index]
+        current_columns = frozenset(gen_info.columns)
+        cols_to_unmerge = frozenset(cols)
+        unknown_cols = cols_to_unmerge - current_columns
+        if unknown_cols:
+            for uc in unknown_cols:
+                self.print(self.ERROR_NO_SUCH_COLUMN, uc)
+            return
+        if cols_to_unmerge == current_columns:
+            self.print(self.ERROR_CANNOT_UNMERGE_ALL)
+            return
+        # Remove unmerged columns
+        for um in cols_to_unmerge:
+            gen_info.columns.remove(um)
+        # The existing generator will not work
+        gen_info.gen = None
+        # And put them into a new (empty) generator
+        table_entry.new_generators.insert(
+            self.generator_index + 1,
+            GeneratorInfo(
+                columns=cols,
+                gen=None,
+            ),
+        )
+        self.set_prompt()
+
+    def complete_unmerge(
+        self, text: str, _line: str, _begidx: int, _endidx: int
+    ) -> list[str]:
+        """Complete column names to unmerge."""
+        # text may be empty when completing immediately after the command
+        parts = text.split()
+        last_arg = parts[-1] if parts else ""
+        table_entry: GeneratorCmdTableEntry | None = self.get_table()
+        if table_entry is None:
+            return []
+        return [
+            column
+            for column in table_entry.new_generators[self.generator_index].columns
+            if column.startswith(last_arg)
+        ]
+
+    def get_current_columns(self) -> set[str]:
+        """Get the current columns."""
+        table_entry: GeneratorCmdTableEntry | None = self.get_table()
+        if table_entry is None:
+            return set()
+        gen_info = table_entry.new_generators[self.generator_index]
+        return set(gen_info.columns)
+
+    def set_merged_columns(self, first_col: str, other_cols: str) -> bool:
+        """
+        Merge columns, after unmerging everything we don't want.
+
+        :param first_col: The first column we want in the merge; must already
+            be in this column set.
+        :param other_cols: All the columns we want merged other than
+            first_col, in order, space-separated.
+        :return: True if the merge worked, False if there was an error.
+        """
+        existing = self.get_current_columns()
+        existing.discard(first_col)
+        for to_remove in existing:
+            self.do_unmerge(to_remove)
+        return self.merge_columns(other_cols)
+
+
+def try_setting_generator(gc: GeneratorCmd, gens: Iterable[str]) -> bool:
+    """
+    Set the current generator by name if possible.
+
+    :param gc: The interactive ``GeneratorCmd`` to use.
+    :param gens: A list of names of generators to try, in order.
+    :return: True if one of the generators was successfully set, False otherwise.
+    """
+    for gen in gens:
+        new_gen = gc.get_proposed_generator_by_name(gen)
+        if new_gen is not None:
+            gc.set_generator(new_gen)
+            return True
+    return False
diff --git a/datafaker/interactive/missingness.py b/datafaker/interactive/missingness.py
new file mode 100644
index 00000000..bd845c1d
--- /dev/null
+++ b/datafaker/interactive/missingness.py
@@ -0,0 +1,356 @@
+"""Missingness configuration shell."""
+import re
+from collections.abc import Iterable, Mapping, MutableMapping
+from dataclasses import dataclass
+from typing import cast
+
+from sqlalchemy import MetaData
+
+from datafaker.interactive.base import DbCmd, TableEntry
+
+
+@dataclass
+class MissingnessType:
+    """The functions required for applying missingness."""
+
+    SAMPLED = "column_presence.sampled"
+    SAMPLED_QUERY = (
+        "SELECT COUNT(*) AS row_count, {result_names} FROM "
+        "(SELECT {column_is_nulls} FROM {table} ORDER BY RANDOM() LIMIT {count})"
+        " AS __t GROUP BY {result_names}"
+    )
+    name: str
+    query: str
+    comment: str
+    columns: list[str]
+
+    @classmethod
+    def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str:
+        """
+        Construct a query sampling the named columns of the table.
+
+        :param table: The name of the table to sample.
+        :param count: The number of samples to get.
+        :param column_names: The columns to fetch.
+        :return: The SQL query to do the sampling.
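+
+        For example (hypothetical table and column names),
+        ``sampled_query("patients", 1000, ["age", "sex"])`` returns the
+        following query (wrapped here for readability)::
+
+            SELECT COUNT(*) AS row_count, age__is_null, sex__is_null FROM
+            (SELECT age IS NULL AS age__is_null, sex IS NULL AS sex__is_null
+            FROM patients ORDER BY RANDOM() LIMIT 1000)
+            AS __t GROUP BY age__is_null, sex__is_null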
+        """
+        result_names = ", ".join([f"{c}__is_null" for c in column_names])
+        column_is_nulls = ", ".join(
+            [f"{c} IS NULL AS {c}__is_null" for c in column_names]
+        )
+        return cls.SAMPLED_QUERY.format(
+            result_names=result_names,
+            column_is_nulls=column_is_nulls,
+            table=table,
+            count=count,
+        )
+
+
+@dataclass
+class MissingnessCmdTableEntry(TableEntry):
+    """Table entry for the missingness command shell."""
+
+    old_type: MissingnessType
+    new_type: MissingnessType | None
+
+
+class MissingnessCmd(DbCmd):
+    """
+    Interactive shell for the user to set missingness.
+
+    Can only be used for Missingness Completely At Random.
+    """
+
+    intro = "Interactive missingness configuration. Type ? for help.\n"
+    doc_leader = """Use commands 'sampled' and 'none' to choose the missingness style for
+the current table. Use commands 'next' and 'previous' to change the
+current table. Use 'tables' to list the tables and 'count' to show
+how many NULLs exist in each column. Use 'peek' or 'select' to see
+data from the database. Use 'quit' to exit this tool."""
+    prompt = "(missingness) "
+    file = None
+    PATTERN_RE = re.compile(r'SRC_STATS\["([^"]*)"\]')
+
+    def find_missingness_query(
+        self, missingness_generator: Mapping
+    ) -> tuple[str, str] | None:
+        """Find query and comment from src-stats for the passed missingness generator."""
+        kwargs = missingness_generator.get("kwargs", {})
+        patterns = kwargs.get("patterns", "")
+        pattern_match = self.PATTERN_RE.match(patterns)
+        if pattern_match:
+            key = pattern_match.group(1)
+            for src_stat in self.config["src-stats"]:
+                if src_stat.get("name") == key:
+                    query = src_stat.get("query", None)
+                    if not isinstance(query, str):
+                        return None
+                    comments = src_stat.get("comments", [])
+                    return (query, comments[0] if comments else "")
+        return None
+
+    def make_table_entry(
+        self, table_name: str, table_config: Mapping
+    ) -> MissingnessCmdTableEntry | None:
+        """
+        Make a table entry for a particular table.
+
+        :param table_name: The name of the table to make an entry for.
+        :param table_config: The part of ``config.yaml`` relating to this table.
+        :return: The newly-constructed table entry.
+        """
+        if table_config.get("ignore", False):
+            return None
+        if table_config.get("vocabulary_table", False):
+            return None
+        if table_config.get("num_rows_per_pass", 1) == 0:
+            return None
+        mgs = table_config.get("missingness_generators", [])
+        old = None
+        nonnull_columns = self.get_nonnull_columns(table_name)
+        if not nonnull_columns:
+            return None
+        if not mgs:
+            old = MissingnessType(
+                name="none",
+                query="",
+                comment="",
+                columns=[],
+            )
+        elif len(mgs) == 1:
+            mg = mgs[0]
+            mg_name = mg.get("name", None)
+            if isinstance(mg_name, str):
+                query_comment = self.find_missingness_query(mg)
+                if query_comment is not None:
+                    (query, comment) = query_comment
+                    old = MissingnessType(
+                        name=mg_name,
+                        query=query,
+                        comment=comment,
+                        columns=mg.get("columns_assigned", []),
+                    )
+        if old is None:
+            return None
+        return MissingnessCmdTableEntry(
+            name=table_name,
+            old_type=old,
+            new_type=old,
+        )
+
+    def __init__(
+        self,
+        src_dsn: str,
+        src_schema: str | None,
+        metadata: MetaData,
+        config: MutableMapping,
+    ):
+        """
+        Initialise a MissingnessCmd.
+
+        :param src_dsn: connection string for the source database.
+        :param src_schema: schema name for the source database.
+        :param metadata: SQLAlchemy metadata for the source database.
+        :param config: Configuration from the ``config.yaml`` file.
+        """
+        super().__init__(src_dsn, src_schema, metadata, config)
+        self.set_prompt()
+
+    @property
+    def table_entries(self) -> list[MissingnessCmdTableEntry]:
+        """Get the table entries list."""
+        return cast(list[MissingnessCmdTableEntry], self._table_entries)
+
+    def _find_entry_by_table_name(
+        self, table_name: str
+    ) -> MissingnessCmdTableEntry | None:
+        """Find the table entry given the table name."""
+        entry = super()._find_entry_by_table_name(table_name)
+        if entry is None:
+            return None
+        return cast(MissingnessCmdTableEntry, entry)
+
+    def set_prompt(self) -> None:
+        """Set the prompt according to the current table and missingness."""
+        if self.table_index < len(self.table_entries):
+            entry: MissingnessCmdTableEntry = self.table_entries[self.table_index]
+            nt = entry.new_type
+            if nt is None:
+                self.prompt = f"(missingness for {entry.name}) "
+            else:
+                self.prompt = f"(missingness for {entry.name}: {nt.name}) "
+        else:
+            self.prompt = "(missingness) "
+
+    def set_type(self, t_type: MissingnessType) -> None:
+        """Set the missingness of the current table."""
+        if self.table_index < len(self.table_entries):
+            entry = self.table_entries[self.table_index]
+            entry.new_type = t_type
+
+    def _copy_entries(self) -> None:
+        """Set the new missingness into the configuration."""
+        src_stats = self._remove_prefix_src_stats("missing_auto__")
+        for entry in self.table_entries:
+            table = self.get_table_config(entry.name)
+            if entry.new_type is None or entry.new_type.name == "none":
+                table.pop("missingness_generators", None)
+            else:
+                src_stat_key = f"missing_auto__{entry.name}__0"
+                table["missingness_generators"] = [
+                    {
+                        "name": entry.new_type.name,
+                        "kwargs": {
+                            "patterns": f'SRC_STATS["{src_stat_key}"]["results"]'
+                        },
+                        "columns_assigned": entry.new_type.columns,
+                    }
+                ]
+                src_stats.append(
+                    {
+                        "name": src_stat_key,
+                        "query": entry.new_type.query,
+                        "comments": []
+                        if entry.new_type.comment is None
+                        else [entry.new_type.comment],
+                    }
+                )
+            self.set_table_config(entry.name, table)
+
+    def do_quit(self, _arg: str) -> bool:
+        """Check the updates, save them if desired and quit the configurer."""
+        count = 0
+        for entry in self.table_entries:
+            if entry.old_type != entry.new_type:
+                count += 1
+                if entry.old_type is None:
+                    self.print(
+                        "Putting generator {1} on table {0}",
+                        entry.name,
+                        entry.new_type.name,
+                    )
+                elif entry.new_type is None:
+                    self.print(
+                        "Deleting generator {1} from table {0}",
+                        entry.name,
+                        entry.old_type.name,
+                    )
+                else:
+                    self.print(
+                        "Changing {0} from {1} to {2}",
+                        entry.name,
+                        entry.old_type.name,
+                        entry.new_type.name,
+                    )
+        if count == 0:
+            self.print("You have made no changes.")
+        reply = self.ask_save()
+        if reply == "yes":
+            self._copy_entries()
+            return True
+        if reply == "no":
+            return True
+        return False
+
+    def do_tables(self, _arg: str) -> None:
+        """List the tables with their types."""
+        for entry in self.table_entries:
+            old = "-" if entry.old_type is None else entry.old_type.name
+            new = "-" if entry.new_type is None else entry.new_type.name
+            desc = new if old == new else f"{old}->{new}"
+            self.print("{0} {1}", entry.name, desc)
+
+    def do_next(self, arg: str) -> None:
+        """
+        Go to the next table, or a specified table.
+ + 'next' = go to the next table, 'next tablename' = go to table 'tablename' + """ + if arg: + # Find the index of the table called _arg, if any + index = next( + (i for i, entry in enumerate(self.table_entries) if entry.name == arg), + None, + ) + if index is None: + self.print(self.ERROR_NO_SUCH_TABLE, arg) + return + self._set_table_index(index) + return + self.next_table(self.INFO_NO_MORE_TABLES) + + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Get completions for tables and columns.""" + return [ + entry.name for entry in self.table_entries if entry.name.startswith(text) + ] + + def do_previous(self, _arg: str) -> None: + """Go to the previous table.""" + if not self._set_table_index(self.table_index - 1): + self.print(self.ERROR_ALREADY_AT_START) + + def _set_type(self, name: str, query: str, comment: str) -> None: + """Set the current table entry's query.""" + if len(self.table_entries) <= self.table_index: + return + entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] + entry.new_type = MissingnessType( + name=name, + query=query, + comment=comment, + columns=self.get_nonnull_columns(entry.name), + ) + + def _set_none(self) -> None: + """Set the current table to have no missingness applied.""" + if len(self.table_entries) <= self.table_index: + return + self.table_entries[self.table_index].new_type = None + + def do_sampled(self, arg: str) -> None: + """ + Set the current table missingness as 'sampled', and go to the next table. + + 'sampled 3000' means sample 3000 rows at random and choose the + missingness to be the same as one of those 3000 at random. + 'sampled' means the same, but with a default number of rows sampled (1000). + """ + if len(self.table_entries) <= self.table_index: + self.print("Error! not on a table") + return + entry = self.table_entries[self.table_index] + if arg == "": + count = 1000 + elif arg.isdecimal(): + count = int(arg) + else: + self.print( + ( + "Error: sampled can be used alone or with" + " an integer argument. 
{0} is not permitted" + ), + arg, + ) + return + self._set_type( + MissingnessType.SAMPLED, + MissingnessType.sampled_query( + entry.name, + count, + self.get_nonnull_columns(entry.name), + ), + ( + "The missingness patterns and how often they appear in a" + f" sample of {count} from table {entry.name}" + ), + ) + self.print("Table {} set to sampled missingness", self.table_name()) + self.next_table() + + def do_none(self, _arg: str) -> None: + """Set the current table to have no missingness, and go to the next table.""" + self._set_none() + self.print("Table {} set to have no missingness", self.table_name()) + self.next_table() diff --git a/datafaker/interactive/table.py b/datafaker/interactive/table.py new file mode 100644 index 00000000..d763a14b --- /dev/null +++ b/datafaker/interactive/table.py @@ -0,0 +1,383 @@ +"""Table configuration command shell.""" +from collections.abc import Mapping, MutableMapping +from dataclasses import dataclass +from typing import Any, cast + +import sqlalchemy +from sqlalchemy import MetaData + +from datafaker.interactive.base import ( + TYPE_LETTER, + TYPE_PROMPT, + DbCmd, + TableEntry, + TableType, +) + + +@dataclass +class TableCmdTableEntry(TableEntry): + """Table entry for the table command shell.""" + + old_type: TableType + new_type: TableType + + +class TableCmd(DbCmd): + """Command shell allowing the user to set the type of each table.""" + + intro = ( + "Interactive table configuration (ignore," + " vocabulary, private, generate or empty). Type ? for help.\n" + ) + doc_leader = """Use the commands 'ignore', 'vocabulary', +'private', 'empty' or 'generate' to set the table's type. Use 'next' or +'previous' to change table. Use 'tables' and 'columns' for +information about the database. Use 'data', 'peek', 'select' or +'count' to see some data contained in the current table. Use 'quit' +to exit this program.""" + prompt = "(tableconf) " + file = None + WARNING_TEXT_VOCAB_TO_NON_VOCAB = ( + "Vocabulary table {0} references non-vocabulary table {1}" + ) + WARNING_TEXT_NON_EMPTY_TO_EMPTY = ( + "Empty table {1} referenced from non-empty table {0}. {1} will need stories." + ) + WARNING_TEXT_PROBLEMS_EXIST = "WARNING: The following table types have problems:" + WARNING_TEXT_POTENTIAL_PROBLEMS = ( + "NOTE: The following table types might cause problems later:" + ) + NOTE_TEXT_NO_CHANGES = "You have made no changes." + NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" + + def make_table_entry( + self, table_name: str, table_config: Mapping + ) -> TableCmdTableEntry | None: + """ + Make a table entry for the named table. + + :param name: The name of the table. + :param table: The part of ``config.yaml`` corresponding to this table. + :return: The newly-constructed table entry. 
+ """ + if table_config.get("ignore", False): + return TableCmdTableEntry(table_name, TableType.IGNORE, TableType.IGNORE) + if table_config.get("vocabulary_table", False): + return TableCmdTableEntry( + table_name, TableType.VOCABULARY, TableType.VOCABULARY + ) + if table_config.get("primary_private", False): + return TableCmdTableEntry(table_name, TableType.PRIVATE, TableType.PRIVATE) + if table_config.get("num_rows_per_pass", 1) == 0: + return TableCmdTableEntry(table_name, TableType.EMPTY, TableType.EMPTY) + return TableCmdTableEntry(table_name, TableType.GENERATE, TableType.GENERATE) + + def __init__( + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], + ) -> None: + """Initialise a TableCmd.""" + super().__init__(src_dsn, src_schema, metadata, config) + self.set_prompt() + + @property + def table_entries(self) -> list[TableCmdTableEntry]: + """Get the list of table entries.""" + return cast(list[TableCmdTableEntry], self._table_entries) + + def _find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: + """Get the table entry of the table with the given name.""" + entry = super()._find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(TableCmdTableEntry, entry) + + def set_prompt(self) -> None: + """Set the prompt according to the current table and its type.""" + if self.table_index < len(self.table_entries): + entry = self.table_entries[self.table_index] + self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name) + else: + self.prompt = "(table) " + + def set_type(self, t_type: TableType) -> None: + """Set the type of the current table.""" + if self.table_index < len(self.table_entries): + entry = self.table_entries[self.table_index] + entry.new_type = t_type + + def _copy_entries(self) -> None: + """Alter the configuration to match the new table entries.""" + for entry in self.table_entries: + if entry.old_type != entry.new_type: + table = self.get_table_config(entry.name) + if ( + entry.old_type == TableType.EMPTY + and table.get("num_rows_per_pass", 1) == 0 + ): + table["num_rows_per_pass"] = 1 + if entry.new_type == TableType.IGNORE: + table["ignore"] = True + table.pop("vocabulary_table", None) + table.pop("primary_private", None) + elif entry.new_type == TableType.VOCABULARY: + table.pop("ignore", None) + table["vocabulary_table"] = True + table.pop("primary_private", None) + elif entry.new_type == TableType.PRIVATE: + table.pop("ignore", None) + table.pop("vocabulary_table", None) + table["primary_private"] = True + elif entry.new_type == TableType.EMPTY: + table.pop("ignore", None) + table.pop("vocabulary_table", None) + table.pop("primary_private", None) + table["num_rows_per_pass"] = 0 + else: + table.pop("ignore", None) + table.pop("vocabulary_table", None) + table.pop("primary_private", None) + self.set_table_config(entry.name, table) + + def _get_referenced_tables(self, from_table_name: str) -> set[str]: + """Get all the tables referenced by this table's foreign keys.""" + from_meta = self.metadata.tables[from_table_name] + return { + fk.column.table.name for col in from_meta.columns for fk in col.foreign_keys + } + + def _sanity_check_failures(self) -> list[tuple[str, str, str]]: + """Find tables that reference each other that should not given their types.""" + failures = [] + for from_entry in self.table_entries: + from_t = from_entry.new_type + if from_t == TableType.VOCABULARY: + referenced = self._get_referenced_tables(from_entry.name) + for ref in 
referenced: + to_entry = self._find_entry_by_table_name(ref) + if ( + to_entry is not None + and to_entry.new_type != TableType.VOCABULARY + ): + failures.append( + ( + self.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + from_entry.name, + to_entry.name, + ) + ) + return failures + + def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: + """Find tables that reference each other that might cause problems given their types.""" + warnings = [] + for from_entry in self.table_entries: + from_t = from_entry.new_type + if from_t in {TableType.GENERATE, TableType.PRIVATE}: + referenced = self._get_referenced_tables(from_entry.name) + for ref in referenced: + to_entry = self._find_entry_by_table_name(ref) + if to_entry is not None and to_entry.new_type in { + TableType.EMPTY, + TableType.IGNORE, + }: + warnings.append( + ( + self.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + from_entry.name, + to_entry.name, + ) + ) + return warnings + + def do_quit(self, _arg: str) -> bool: + """Check the updates, save them if desired and quit the configurer.""" + count = 0 + for entry in self.table_entries: + if entry.old_type != entry.new_type: + count += 1 + self.print( + self.NOTE_TEXT_CHANGING, + entry.name, + entry.old_type.value, + entry.new_type.value, + ) + if count == 0: + self.print(self.NOTE_TEXT_NO_CHANGES) + failures = self._sanity_check_failures() + if failures: + self.print(self.WARNING_TEXT_PROBLEMS_EXIST) + for text, from_t, to_t in failures: + self.print(text, from_t, to_t) + warnings = self._sanity_check_warnings() + if warnings: + self.print(self.WARNING_TEXT_POTENTIAL_PROBLEMS) + for text, from_t, to_t in warnings: + self.print(text, from_t, to_t) + reply = self.ask_save() + if reply == "yes": + self._copy_entries() + return True + if reply == "no": + return True + return False + + def do_tables(self, _arg: str) -> None: + """List the tables with their types.""" + for entry in self.table_entries: + old = entry.old_type + new = entry.new_type + becomes = " " if old == new else "->" + TYPE_LETTER[new] + self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) + + def do_next(self, arg: str) -> None: + """'next' = go to the next table, 'next tablename' = go to table 'tablename'.""" + if arg: + # Find the index of the table called _arg, if any + index = self.find_entry_index_by_table_name(arg) + if index is None: + self.print(self.ERROR_NO_SUCH_TABLE, arg) + return + self._set_table_index(index) + return + self.next_table(self.INFO_NO_MORE_TABLES) + + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Get the completions for tables and columns.""" + return [ + entry.name for entry in self.table_entries if entry.name.startswith(text) + ] + + def do_previous(self, _arg: str) -> None: + """Go to the previous table.""" + if not self._set_table_index(self.table_index - 1): + self.print(self.ERROR_ALREADY_AT_START) + + def do_ignore(self, _arg: str) -> None: + """Set the current table as ignored, and go to the next table.""" + self.set_type(TableType.IGNORE) + self.print("Table {} set as ignored", self.table_name()) + self.next_table() + + def do_vocabulary(self, _arg: str) -> None: + """Set the current table as a vocabulary table, and go to the next table.""" + self.set_type(TableType.VOCABULARY) + self.print("Table {} set to be a vocabulary table", self.table_name()) + self.next_table() + + def do_private(self, _arg: str) -> None: + """Set the current table as a primary private table (such as the table of patients).""" + self.set_type(TableType.PRIVATE) + 
self.print("Table {} set to be a primary private table", self.table_name()) + self.next_table() + + def do_generate(self, _arg: str) -> None: + """Set the current table as to be generated, and go to the next table.""" + self.set_type(TableType.GENERATE) + self.print("Table {} generate", self.table_name()) + self.next_table() + + def do_empty(self, _arg: str) -> None: + """Set the current table as empty; no generators will be run for it.""" + self.set_type(TableType.EMPTY) + self.print("Table {} empty", self.table_name()) + self.next_table() + + def do_columns(self, _arg: str) -> None: + """Report the column names and metadata.""" + self.report_columns() + + def do_data(self, arg: str) -> None: + """ + Report some data. + + 'data' = report a random ten lines, + 'data 20' = report a random 20 lines, + 'data 20 ColumnName' = report a random twenty entries from ColumnName, + 'data 20 ColumnName 30' = report a random twenty entries from + ColumnName of length at least 30, + """ + args = arg.split() + column = None + number = None + arg_index = 0 + min_length = 0 + table_metadata = self.table_metadata() + if arg_index < len(args) and args[arg_index].isdigit(): + number = int(args[arg_index]) + arg_index += 1 + if arg_index < len(args) and args[arg_index] in table_metadata.columns: + column = args[arg_index] + arg_index += 1 + if arg_index < len(args) and args[arg_index].isdigit(): + min_length = int(args[arg_index]) + arg_index += 1 + if arg_index != len(args): + self.print( + """Did not understand these arguments +The format is 'data [entries] [column-name [minimum-length]]' where [] means optional text. +Type 'columns' to find out valid column names for this table. +Type 'help data' for examples.""" + ) + return + if column is None: + if number is None: + number = 10 + self.print_row_data(number) + else: + if number is None: + number = 48 + self.print_column_data(column, number, min_length) + + def complete_data( + self, text: str, line: str, begidx: int, _endidx: int + ) -> list[str]: + """Get completions for arguments to ``data``.""" + previous_parts = line[: begidx - 1].split() + if len(previous_parts) != 2: + return [] + table_metadata = self.table_metadata() + return [k for k in table_metadata.columns.keys() if k.startswith(text)] + + def print_column_data(self, column: str, count: int, min_length: int) -> None: + """ + Print a sample of data from a certain column of the current table. + + :param column: The name of the column to report on. + :param count: The number of rows to sample. + :param min_length: The minimum length of text to choose from (0 for any text). + """ + where = f"WHERE {column} IS NOT NULL" + if 0 < min_length: + where = f"WHERE LENGTH({column}) >= {min_length}" + with self.sync_engine.connect() as connection: + result = connection.execute( + sqlalchemy.text( + f"SELECT {column} FROM {self.table_name()}" + f" {where} ORDER BY RANDOM() LIMIT {count}" + ) + ) + self.columnize([str(x[0]) for x in result.all()]) + + def print_row_data(self, count: int) -> None: + """ + Print a sample or rows from the current table. + + :param count: The number of rows to report. 
+ """ + with self.sync_engine.connect() as connection: + result = connection.execute( + sqlalchemy.text( + f"SELECT * FROM {self.table_name()} ORDER BY RANDOM() LIMIT {count}" + ) + ) + if result is None: + self.print("No rows in this table!") + return + self.print_results(result) diff --git a/datafaker/main.py b/datafaker/main.py index 5b79831e..5a89d5e4 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -1,22 +1,24 @@ """Entrypoint for the datafaker package.""" import asyncio -from enum import Enum +import importlib +import io import json import sys -from importlib import metadata +from enum import Enum from pathlib import Path -from typing import Final, Optional +from typing import Any, Final, Optional import yaml from jsonschema.exceptions import ValidationError from jsonschema.validators import validate +from sqlalchemy import MetaData from typer import Argument, Exit, Option, Typer from datafaker.create import create_db_data, create_db_tables, create_db_vocab from datafaker.dump import dump_db_tables from datafaker.interactive import ( - update_config_tables, update_config_generators, + update_config_tables, update_missingness, ) from datafaker.make import ( @@ -68,38 +70,54 @@ def _require_src_db_dsn(settings: Settings) -> str: return src_dsn -def load_metadata_config(orm_file_name, config: dict | None=None): - with open(orm_file_name) as orm_fh: +def load_metadata_config( + orm_file_name: str, config: dict | None = None +) -> dict[str, Any]: + """ + Load the ``orm.yaml`` file, returning a dict representation. + + :param orm_file_name: The name of the file to load. + :param config: The ``config.yaml`` file object. Ignored tables will be + excluded from the output. + :return: A dict representing the ``orm.yaml`` file, with the tables + the ``config`` says to ignore removed. + """ + with open(orm_file_name, encoding="utf-8") as orm_fh: meta_dict = yaml.load(orm_fh, yaml.Loader) + if not isinstance(meta_dict, dict): + return {} tables_dict = meta_dict.get("tables", {}) if config is not None and "tables" in config: # Remove ignored tables - for (name, table_config) in config.get("tables", {}).items(): + for name, table_config in config.get("tables", {}).items(): if get_flag(table_config, "ignore"): tables_dict.pop(name, None) return meta_dict -def load_metadata(orm_file_name, config: dict | None=None): +def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: + """ + Load metadata from ``orm.yaml``. + + :param orm_file_name: ``orm.yaml`` or alternative name to load metadata from. + :param config: Used to exclude tables that are marked as ``ignore: true``. + :return: SQLAlchemy MetaData object representing the database described by the loaded file. + """ meta_dict = load_metadata_config(orm_file_name, config) return dict_to_metadata(meta_dict, None) -def load_metadata_for_output(orm_file_name, config: dict | None=None): - """ - Load metadata excluding any foreign keys pointing to ignored tables. - """ +def load_metadata_for_output(orm_file_name: str, config: dict | None = None) -> Any: + """Load metadata excluding any foreign keys pointing to ignored tables.""" meta_dict = load_metadata_config(orm_file_name, config) return dict_to_metadata(meta_dict, config) @app.callback() -def main(verbose: bool = Option( - False, - "--verbose", - "-v", - help="Print more information." 
-)): +def main( + verbose: bool = Option(False, "--verbose", "-v", help="Print more information.") +) -> None: + """Set the global parameters.""" conf_logger(verbose) @@ -108,7 +126,7 @@ def create_data( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), df_file: str = Option( DF_FILENAME, - help="The name of the generators file. Must be in the current working directory." + help="The name of the generators file. Must be in the current working directory.", ), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), num_passes: int = Option(1, help="Number of passes (rows or stories) to make"), @@ -135,17 +153,17 @@ def create_data( config = read_config_file(config_file) if config_file is not None else {} orm_metadata = load_metadata_for_output(orm_file, config) df_module = import_file(df_file) - table_generator_dict = df_module.table_generator_dict - story_generator_list = df_module.story_generator_list try: row_counts = create_db_data( sorted_non_vocabulary_tables(orm_metadata, config), - table_generator_dict, - story_generator_list, + df_module, num_passes, + orm_metadata, ) logger.debug( - "Data created in %s %s.", num_passes, "pass" if num_passes == 1 else "passes" + "Data created in %s %s.", + num_passes, + "pass" if num_passes == 1 else "passes", ) for table_name, row_count in row_counts.items(): logger.debug( @@ -203,16 +221,18 @@ def create_tables( def create_generators( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), df_file: str = Option(DF_FILENAME, help="Path to write Python generators to."), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), + config_file: str = Option(CONFIG_FILENAME, help="The configuration file"), stats_file: Optional[str] = Option( None, help=( "Statistics file (output of make-stats); default is src-stats.yaml if the " "config file references SRC_STATS, or None otherwise." ), - show_default=False + show_default=False, + ), + force: bool = Option( + False, "--force", "-f", help="Overwrite any existing Python generators file." ), - force: bool = Option(False, "--force", "-f", help="Overwrite any existing Python generators file."), ) -> None: """Make a datafaker file of generator classes. @@ -249,7 +269,12 @@ def create_generators( def make_vocab( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - force: bool = Option(False, "--force/--no-force", "-f/+f", help="Overwrite any existing vocabulary file."), + force: bool = Option( + False, + "--force/--no-force", + "-f/+f", + help="Overwrite any existing vocabulary file.", + ), compress: bool = Option(False, help="Compress file to .gz"), only: list[str] = Option([], help="Only download this table."), ) -> None: @@ -276,10 +301,11 @@ def make_vocab( @app.command() def make_stats( - orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), stats_file: str = Option(STATS_FILENAME), - force: bool = Option(False, "--force", "-f", help="Overwrite any existing vocabulary file."), + force: bool = Option( + False, "--force", "-f", help="Overwrite any existing vocabulary file." + ), ) -> None: """Compute summary statistics from the source database. 
@@ -295,13 +321,12 @@ def make_stats( _check_file_non_existence(stats_file_path) config = read_config_file(config_file) if config_file is not None else {} - orm_metadata = load_metadata(orm_file, config) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(src_dsn, config, orm_metadata, settings.src_schema) + make_src_stats(src_dsn, config, settings.src_schema) ) stats_file_path.write_text(yaml.dump(src_stats), encoding="utf-8") logger.debug("%s created.", stats_file) @@ -309,9 +334,17 @@ def make_stats( @app.command() def make_tables( - config_file: Optional[str] = Option(None, help="The configuration file, used if you want an orm.yaml lacking data for the ignored tables"), + config_file: Optional[str] = Option( + None, + help=( + "The configuration file, used if you want" + " an orm.yaml lacking data for the ignored tables" + ), + ), orm_file: str = Option(ORM_FILENAME, help="Path to write the ORM yaml file to"), - force: bool = Option(False, "--force", "-f", help="Overwrite any existing orm yaml file."), + force: bool = Option( + False, "--force", "-f", help="Overwrite any existing orm yaml file." + ), ) -> None: """Make a YAML file representing the tables in the schema. @@ -335,22 +368,26 @@ def make_tables( @app.command() def configure_tables( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path to write the configuration file to"), + config_file: str = Option( + CONFIG_FILENAME, help="Path to write the configuration file to" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), -): - """ - Interactively set tables to ignored, vocabulary or primary private. - """ +) -> None: + """Interactively set tables to ignored, vocabulary or primary private.""" logger.debug("Configuring tables in %s.", config_file) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) config_file_path = Path(config_file) config = {} if config_file_path.exists(): - config = yaml.load(config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader) + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) # we don't pass config here so that no tables are ignored metadata = load_metadata(orm_file) - config_updated = update_config_tables(src_dsn, settings.src_schema, metadata, config) + config_updated = update_config_tables( + src_dsn, settings.src_schema, metadata, config + ) if config_updated is None: logger.debug("Cancelled") return @@ -361,19 +398,23 @@ def configure_tables( @app.command() def configure_missing( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path to write the configuration file to"), + config_file: str = Option( + CONFIG_FILENAME, help="Path to write the configuration file to" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), -): - """ - Interactively set the missingness of the generated data. 
- """ +) -> None: + """Interactively set the missingness of the generated data.""" logger.debug("Configuring missingness in %s.", config_file) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) config_file_path = Path(config_file) - config = {} + config: dict[str, Any] = {} if config_file_path.exists(): - config = yaml.load(config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader) + config_any = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) + if isinstance(config_any, dict): + config = config_any metadata = load_metadata(orm_file, config) config_updated = update_missingness(src_dsn, settings.src_schema, metadata, config) if config_updated is None: @@ -386,22 +427,32 @@ def configure_missing( @app.command() def configure_generators( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path of the configuration file to alter"), + config_file: str = Option( + CONFIG_FILENAME, help="Path of the configuration file to alter" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - spec: Path = Option(None, help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively") -): - """ - Interactively set generators for column data. - """ + spec: Path = Option( + None, + help=( + "CSV file (headerless) with fields table-name," + " column-name, generator-name to set non-interactively" + ), + ), +) -> None: + """Interactively set generators for column data.""" logger.debug("Configuring generators in %s.", config_file) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) config_file_path = Path(config_file) config = {} if config_file_path.exists(): - config = yaml.load(config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader) + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) metadata = load_metadata(orm_file, config) - config_updated = update_config_generators(src_dsn, settings.src_schema, metadata, config, spec_path=spec) + config_updated = update_config_generators( + src_dsn, settings.src_schema, metadata, config, spec_path=spec + ) if config_updated is None: logger.debug("Cancelled") return @@ -412,22 +463,25 @@ def configure_generators( @app.command() def dump_data( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path of the configuration file to alter"), + config_file: Optional[str] = Option( + CONFIG_FILENAME, help="Path of the configuration file to alter" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), table: str = Argument(help="The table to dump"), output: str | None = Option(None, help="output CSV file name"), -): - """ Dump a whole table as a CSV file (or to the console) from the destination database. """ +) -> None: + """Dump a whole table as a CSV file (or to the console) from the destination database.""" settings = get_settings() dst_dsn: str = settings.dst_dsn or "" assert dst_dsn != "", "Missing DST_DSN setting." 
schema_name = settings.dst_schema config = read_config_file(config_file) if config_file is not None else {} metadata = load_metadata_for_output(orm_file, config) - if output == None: - dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) + if output is None: + if isinstance(sys.stdout, io.TextIOBase): + dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) return - with open(output, 'wt', newline='') as out: + with open(output, "wt", newline="", encoding="utf-8") as out: dump_db_tables(metadata, dst_dsn, schema_name, table, out) @@ -452,7 +506,9 @@ def validate_config( def remove_data( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - yes: bool = Option(False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"), + yes: bool = Option( + False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" + ), ) -> None: """Truncate non-vocabulary tables in the destination schema.""" if yes: @@ -469,7 +525,9 @@ def remove_data( def remove_vocab( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - yes: bool = Option(False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"), + yes: bool = Option( + False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" + ), ) -> None: """Truncate vocabulary tables in the destination schema.""" if yes: @@ -486,8 +544,15 @@ def remove_vocab( @app.command() def remove_tables( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - yes: bool = Option(False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"), + config_file: str = Option(CONFIG_FILENAME, help="The configuration file"), + # pylint: disable=redefined-builtin + all: bool = Option( + False, + help="Don't use the ORM file, delete all tables in the destination schema", + ), + yes: bool = Option( + False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" + ), ) -> None: """Drop all tables in the destination schema. 
@@ -495,25 +560,30 @@ def remove_tables( """ if yes: logger.debug("Dropping tables.") - config = read_config_file(config_file) if config_file is not None else {} - metadata = load_metadata_for_output(orm_file, config) - remove_db_tables(metadata) + if all: + remove_db_tables(None) + else: + config = read_config_file(config_file) + metadata = load_metadata_for_output(orm_file, config) + remove_db_tables(metadata) logger.debug("Tables dropped.") else: logger.info("Would remove tables if called with --yes.") class TableType(str, Enum): - all = "all" - vocab = "vocab" - generated = "generated" + """Types of tables for the ``list-tables`` command.""" + + ALL = "all" + VOCAB = "vocab" + GENERATED = "generated" @app.command() def list_tables( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - tables: TableType = Option(TableType.generated, help="Which tables to list"), + tables: TableType = Option(TableType.GENERATED, help="Which tables to list"), ) -> None: """List the names of tables described in the metadata file.""" config = read_config_file(config_file) if config_file is not None else {} @@ -524,9 +594,9 @@ def list_tables( for (table_name, table_config) in config.get("tables", {}).items() if get_flag(table_config, "vocabulary_table") } - if tables == TableType.all: + if tables == TableType.ALL: names = all_table_names - elif tables == TableType.generated: + elif tables == TableType.GENERATED: names = all_table_names - vocab_table_names else: names = vocab_table_names @@ -540,7 +610,7 @@ def version() -> None: logger.info( "%s version %s", __package__, - metadata.version(__package__), + importlib.metadata.version(__package__), ) diff --git a/datafaker/make.py b/datafaker/make.py index 1284672f..6f4cc9bd 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -5,30 +5,34 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import ( - Any, Callable, Final, Mapping, Optional, Sequence, Tuple -) -import yaml +from types import TracebackType +from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple, Type import pandas as pd import snsql +import yaml from black import FileMode, format_str from jinja2 import Environment, FileSystemLoader, Template from mimesis.providers.base import BaseProvider -from sqlalchemy import Engine, MetaData, UniqueConstraint, text +from sqlalchemy import CursorResult, Engine, MetaData, UniqueConstraint, text from sqlalchemy.dialects import postgresql -from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from sqlalchemy.schema import Column, Table -from sqlalchemy.sql import sqltypes, type_api +from sqlalchemy.sql import Executable, sqltypes +from typing_extensions import Self from datafaker import providers from datafaker.settings import get_settings from datafaker.utils import ( + MaybeAsyncEngine, create_db_engine, download_table, - get_property, + get_columns_assigned, get_flag, + get_property, get_related_table_names, + get_row_generators, get_sync_engine, get_vocabulary_table_names, logger, @@ -73,7 +77,8 @@ class RowGeneratorInfo: @dataclass class ColumnChoice: - """ Chooses columns based on a random number in [0,1) """ + """Choose columns based on a random number in [0,1).""" + function_name: str argument_values: list[str] @@ -81,19 +86,36 @@ class ColumnChoice: def 
make_column_choices( table_config: Mapping[str, Any], ) -> list[ColumnChoice]: + """ + Convert ``missingness_generators`` from ``config.yaml`` into functions to call. + + :param table_config: The ``tables`` part of ``config.yaml``. + :return: A list of ``ColumnChoice`` objects; that is, descriptions of + functions and their arguments to call to reveal a list of columns that + should have values generated for them. + """ return [ ColumnChoice( function_name=mg["name"], - argument_values=[ - f"{k}={v}" - for k, v in mg.get("kwargs", {}).items() - ] + argument_values=[f"{k}={v}" for k, v in mg.get("kwargs", {}).items()], ) for mg in table_config.get("missingness_generators", []) if "name" in mg ] +@dataclass +class _PrimaryConstraint: + """ + Describes a Uniqueness constraint for a multi-column primary key. + + Not a real constraint, but enough to write df.py. + """ + + columns: list[Column] + name: str + + @dataclass class TableGeneratorInfo: """Contains the df.py content related to regular tables.""" @@ -104,7 +126,9 @@ class TableGeneratorInfo: column_choices: list[ColumnChoice] rows_per_pass: int row_gens: list[RowGeneratorInfo] = field(default_factory=list) - unique_constraints: list[UniqueConstraint] = field(default_factory=list) + unique_constraints: Sequence[UniqueConstraint | _PrimaryConstraint] = field( + default_factory=list + ) @dataclass @@ -116,14 +140,16 @@ class StoryGeneratorInfo: num_stories_per_pass: int -def _render_value(v) -> str: - if type(v) is list: +def _render_value(v: Any) -> str: + if isinstance(v, list): return "[" + ", ".join(_render_value(x) for x in v) + "]" - if type(v) is set: + if isinstance(v, set): return "{" + ", ".join(_render_value(x) for x in v) + "}" - if type(v) is dict: - return "{" + ", ".join(f"{repr(k)}:{_render_value(x)}" for k, x in v.items()) + "}" - if type(v) is str: + if isinstance(v, dict): + return ( + "{" + ", ".join(f"{repr(k)}:{_render_value(x)}" for k, x in v.items()) + "}" + ) + if isinstance(v, str): return v return str(v) @@ -152,27 +178,15 @@ def _get_row_generator( ) -> tuple[list[RowGeneratorInfo], list[str]]: """Get the row generators information, for the given table.""" row_gen_info: list[RowGeneratorInfo] = [] - config: list[dict[str, Any]] = get_property(table_config, "row_generators", []) columns_covered = [] - for gen_conf in config: - name: str = gen_conf["name"] - columns_assigned = gen_conf["columns_assigned"] + for name, gen_conf in get_row_generators(table_config): + columns_assigned = list(get_columns_assigned(gen_conf)) keyword_arguments: Mapping[str, Any] = gen_conf.get("kwargs", {}) positional_arguments: Sequence[str] = gen_conf.get("args", []) - - if isinstance(columns_assigned, str): - columns_assigned = [columns_assigned] - - variable_names: list[str] = columns_assigned - try: - columns_covered += columns_assigned - except TypeError: - # Might be a single string, rather than a list of strings. 
- columns_covered.append(columns_assigned) - + columns_covered += columns_assigned row_gen_info.append( RowGeneratorInfo( - variable_names=variable_names, + variable_names=columns_assigned, function_call=_get_function_call( name, positional_arguments, keyword_arguments ), @@ -181,9 +195,7 @@ def _get_row_generator( return row_gen_info, columns_covered -def _get_default_generator( - column: Column -) -> RowGeneratorInfo: +def _get_default_generator(column: Column) -> RowGeneratorInfo: """Get default generator information, for the given column.""" # If it's a primary key column, we presume that primary keys are populated # automatically. @@ -215,7 +227,8 @@ def _get_default_generator( primary_key=column.primary_key, variable_names=variable_names, function_call=_get_function_call( - function_name=generator_function, positional_arguments=generator_arguments + function_name=generator_function, + positional_arguments=generator_arguments, ), ) @@ -223,52 +236,68 @@ def _get_default_generator( ( variable_names, generator_function, - generator_arguments, + generator_kwargs, ) = _get_provider_for_column(column) return RowGeneratorInfo( primary_key=column.primary_key, variable_names=variable_names, function_call=_get_function_call( - function_name=generator_function, keyword_arguments=generator_arguments + function_name=generator_function, keyword_arguments=generator_kwargs ), ) def _numeric_generator(column: Column) -> tuple[str, dict[str, str]]: """ - Returns the name of a generator and maybe arguments - that limit its range to the permitted scale. + Get the default generator name and arguments. + + :param column: The column to get the generator for. + :return: The name of a generator and its arguments. """ column_type = column.type - if column_type.scale is None: + scale = getattr(column_type, "scale", None) + if scale is None: return ("generic.numeric.float_number", {}) - return ("generic.numeric.float_number", { - "start": 0, - "end": 10 ** column_type.scale - 1, - }) + return ( + "generic.numeric.float_number", + { + "start": "0", + "end": str(10**scale - 1), + }, + ) def _string_generator(column: Column) -> tuple[str, dict[str, str]]: """ - Returns the name of a string generator and maybe arguments - that limit its length. + Get the name of the default string generator for a column. + + :param column: The column to get the generator for. + :return: The name of the generator and its arguments. """ column_size: Optional[int] = getattr(column.type, "length", None) if column_size is None: return ("generic.text.color", {}) - return ("generic.person.password", { "length": str(column_size) }) + return ("generic.person.password", {"length": str(column_size)}) + def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: """ - Returns the name of an integer generator. + Get the name of the default integer generator. + + :param column: The column to get the generator for. + :return: A pair consisting of the name of a generator and its + arguments. 
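+
+    For example (illustrative): a non-primary-key integer column yields
+    ``("generic.numeric.integer_number", {})``, as the code below shows.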
""" if not column.primary_key: return ("generic.numeric.integer_number", {}) - return ("generic.column_value_provider.increment", { - "db_connection": "dst_db_conn", - "column": f'metadata.tables["{column.table.name}"].columns["{column.name}"]', - }) + return ( + "generic.column_value_provider.increment", + { + "db_connection": "dst_db_conn", + "column": f'metadata.tables["{column.table.name}"].columns["{column.name}"]', + }, + ) _YEAR_SUMMARY_QUERY = ( @@ -279,8 +308,10 @@ def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: @dataclass class GeneratorInfo: + """Description of a generator.""" + # Name or function to generate random objects of this type (not using summary data) - generator: str | Callable[[Column], str] + generator: str | Callable[[Column], tuple[str, dict[str, str]]] # SQL query that gets the data to supply as arguments to the generator # ({column} and {table} will be interpolated) summary_query: str | None = None @@ -294,13 +325,19 @@ class GeneratorInfo: choice: bool = False -def get_result_mappings(info: GeneratorInfo, results) -> dict[str, Any]: +def get_result_mappings( + info: GeneratorInfo, results: CursorResult +) -> dict[str, Any] | None: """ - Gets a mapping from the results of a database query as a Python - dictionary converted according to the GeneratorInfo provided. + Get a mapping from the results of a database query. + + :return: A Python dictionary converted according to the GeneratorInfo provided. """ - kw = {} - for k, v in results.mappings().first().items(): + kw: dict[str, Any] = {} + mapping = results.mappings().first() + if mapping is None: + return kw + for k, v in mapping.items(): if v is None: return None conv_fn = info.arg_types.get(k, float) @@ -316,12 +353,12 @@ def get_result_mappings(info: GeneratorInfo, results) -> dict[str, Any]: sqltypes.Date: GeneratorInfo( generator="generic.datetime.date", summary_query=_YEAR_SUMMARY_QUERY, - arg_types={ "start": int, "end": int } + arg_types={"start": int, "end": int}, ), sqltypes.DateTime: GeneratorInfo( generator="generic.datetime.datetime", summary_query=_YEAR_SUMMARY_QUERY, - arg_types={ "start": int, "end": int } + arg_types={"start": int, "end": int}, ), sqltypes.Integer: GeneratorInfo( # must be before Numeric generator=_integer_generator, @@ -345,20 +382,20 @@ def get_result_mappings(info: GeneratorInfo, results) -> dict[str, Any]: sqltypes.String: GeneratorInfo( generator=_string_generator, choice=True, - ) + ), } def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: """ - Gets a generator from a column type. + Get a generator from a column type. Returns either a string representing the callable, or a callable that, given the column.type will return a tuple (string representing generator callable, dict of keyword arguments to pass to the callable). """ if column_t in _COLUMN_TYPE_TO_GENERATOR_INFO: - return _COLUMN_TYPE_TO_GENERATOR_INFO[column_t] + return _COLUMN_TYPE_TO_GENERATOR_INFO[column_t] # Search exhaustively for a superclass to the columns actual type for key, value in _COLUMN_TYPE_TO_GENERATOR_INFO.items(): @@ -368,10 +405,11 @@ def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: return None -def _get_generator_for_column(column_t: type) -> str | Callable[ - [type_api.TypeEngine], tuple[str, dict[str, str]]]: +def _get_generator_for_column( + column_t: type, +) -> str | Callable[[Column], tuple[str, dict[str, str]]] | None: """ - Gets a generator from a column type. + Get a generator from a column type. 
Returns either a string representing the callable, or a callable that, given the column.type will return a tuple (string representing generator @@ -381,10 +419,11 @@ def _get_generator_for_column(column_t: type) -> str | Callable[ return None if info is None else info.generator -def _get_generator_and_arguments(column: Column) -> tuple[str, dict[str, str]]: +def _get_generator_and_arguments(column: Column) -> tuple[str | None, dict[str, str]]: """ - Gets the generator and its arguments from the column type, returning - a tuple of a string representing the generator callable and a dict of + Get the generator and its arguments from the column type. + + :return: A tuple of a string representing the generator callable and a dict of keyword arguments to supply to it. """ generator_function = _get_generator_for_column(type(column.type)) @@ -392,7 +431,7 @@ def _get_generator_and_arguments(column: Column) -> tuple[str, dict[str, str]]: generator_arguments: dict[str, str] = {} if callable(generator_function): (generator_function, generator_arguments) = generator_function(column) - return generator_function,generator_arguments + return generator_function, generator_arguments def _get_provider_for_column(column: Column) -> Tuple[list[str], str, dict[str, str]]: @@ -437,17 +476,6 @@ def _constraint_sort_key(constraint: UniqueConstraint) -> str: ) -class _PrimaryConstraint: - """ - Describes a Uniqueness constraint for when multiple - columns in a table comprise the primary key. Not a - real constraint, but enough to write df.py. - """ - def __init__(self, *columns: Column, name: str): - self.name = name - self.columns = columns - - def _get_generator_for_table( table_config: Mapping[str, Any], table: Table, @@ -461,15 +489,13 @@ def _get_generator_for_table( ), key=_constraint_sort_key, ) - primary_keys = [ - c for c in table.columns - if c.primary_key - ] + primary_keys = [c for c in table.columns if c.primary_key] + constraints: Sequence[UniqueConstraint | _PrimaryConstraint] = unique_constraints if 1 < len(primary_keys): - unique_constraints.append(_PrimaryConstraint( - *primary_keys, - name=f"{table.name}_primary_key" - )) + primary_constraint = _PrimaryConstraint( + columns=primary_keys, name=f"{table.name}_primary_key" + ) + constraints = unique_constraints + [primary_constraint] column_choices = make_column_choices(table_config) if column_choices: nonnull_columns = { @@ -485,7 +511,7 @@ def _get_generator_for_table( nonnull_columns=nonnull_columns, column_choices=column_choices, rows_per_pass=get_property(table_config, "num_rows_per_pass", 1), - unique_constraints=unique_constraints, + unique_constraints=constraints, ) row_gen_info_data, columns_covered = _get_row_generator(table_config) @@ -522,12 +548,9 @@ def make_vocabulary_tables( config: Mapping, overwrite_files: bool, compress: bool, - table_names: set[str] | None=None, -): - """ - Extracts the data from the source database for each - vocabulary table. - """ + table_names: set[str] | None = None, +) -> None: + """Extract the data from the source database for each vocabulary table.""" settings = get_settings() src_dsn: str = settings.src_dsn or "" assert src_dsn != "", "Missing SRC_DSN setting." 
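+    # Each vocabulary table is read from the source database (SRC_DSN) and
+    # written out as a <table>.yaml file, gzip-compressed when requested.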
@@ -539,7 +562,10 @@ def make_vocabulary_tables( else: invalid_names = table_names - vocab_names if invalid_names: - logger.error("The following names are not the names of vocabulary tables: %s", invalid_names) + logger.error( + "The following names are not the names of vocabulary tables: %s", + invalid_names, + ) logger.info("Valid names are: %s", vocab_names) return for table_name in table_names: @@ -567,8 +593,10 @@ def make_table_generators( # pylint: disable=too-many-locals Args: metadata: database ORM config: Configuration to control the generator creation. - orm_filename: "orm.yaml" file path so that the generator file can load the MetaData object - config_filename: "config.yaml" file path so that the generator file can load the MetaData object + orm_filename: "orm.yaml" file path so that the generator + file can load the MetaData object + config_filename: "config.yaml" file path so that the generator + file can load the MetaData object src_stats_filename: A filename for where to read src stats from. Optional, if `None` this feature will be skipped overwrite_files: Whether to overwrite pre-existing vocabulary files @@ -584,7 +612,7 @@ def make_table_generators( # pylint: disable=too-many-locals tables: list[TableGeneratorInfo] = [] vocabulary_tables: list[VocabularyTableGeneratorInfo] = [] vocab_names = get_vocabulary_table_names(config) - for (table_name, table) in metadata.tables.items(): + for table_name, table in metadata.tables.items(): if table_name in vocab_names: related = get_related_table_names(table) related_non_vocab = related.difference(vocab_names) @@ -593,16 +621,18 @@ def make_table_generators( # pylint: disable=too-many-locals "Making table '%s' a vocabulary table requires that also the" " related tables (%s) be also vocabulary tables.", table.name, - related_non_vocab + related_non_vocab, ) vocabulary_tables.append( _get_generator_for_existing_vocabulary_table(table) ) else: - tables.append(_get_generator_for_table( - tables_config.get(table.name, {}), - table, - )) + tables.append( + _get_generator_for_table( + tables_config.get(table.name, {}), + table, + ) + ) story_generators = _get_story_generators(config) @@ -639,9 +669,7 @@ def generate_df_content(template_context: Mapping[str, Any]) -> str: def _get_generator_for_existing_vocabulary_table( table: Table, ) -> VocabularyTableGeneratorInfo: - """ - Turns an existing vocabulary YAML file into a VocabularyTableGeneratorInfo. - """ + """Turn an existing vocabulary YAML file into a VocabularyTableGeneratorInfo.""" return VocabularyTableGeneratorInfo( dictionary_entry=table.name, variable_name=f"{table.name.lower()}_vocab", @@ -653,11 +681,9 @@ def _generate_vocabulary_table( table: Table, engine: Engine, overwrite_files: bool = False, - compress=False, -): - """ - Pulls data out of the source database to make a vocabulary YAML file - """ + compress: bool = False, +) -> None: + """Pull data out of the source database to make a vocabulary YAML file.""" yaml_file_name: str = table.fullname + ".yaml" if compress: yaml_file_name += ".gz" @@ -671,9 +697,7 @@ def _generate_vocabulary_table( def make_tables_file( db_dsn: str, schema_name: Optional[str], config: Mapping[str, Any] ) -> str: - """ - Construct the YAML file representing the schema. 
- """ + """Construct the YAML file representing the schema.""" tables_config = config.get("tables", {}) engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) @@ -705,33 +729,49 @@ def reflect_if(table_name: str, _: Any) -> bool: class DbConnection: - def __init__(self, engine): + """A connection to a database.""" + + def __init__(self, engine: MaybeAsyncEngine) -> None: + """ + Initialise an unopened database connection. + + Could be synchronous or asynchronous. + """ self._engine = engine + self._connection: Connection | AsyncConnection - async def __aenter__(self): + async def __aenter__(self) -> Self: + """Enter the ``with`` section, opening a connection.""" if isinstance(self._engine, AsyncEngine): self._connection = await self._engine.connect() else: self._connection = self._engine.connect() return self - async def __aexit__(self, _type, _value, _tb): - if isinstance(self._engine, AsyncEngine): + async def __aexit__( + self, + _type: Optional[Type[BaseException]], + _value: Optional[BaseException], + _tb: Optional[TracebackType], + ) -> None: + """Exit the ``with`` section, closing the connection.""" + if isinstance(self._connection, AsyncConnection): await self._connection.close() else: self._connection.close() - async def execute_raw_query(self, query): - if isinstance(self._engine, AsyncEngine): + async def execute_raw_query(self, query: Executable) -> CursorResult: + """Execute the query on the owned connection.""" + if isinstance(self._connection, AsyncConnection): return await self._connection.execute(query) - else: - return self._connection.execute(query) + return self._connection.execute(query) - async def table_row_count(self, table_name: str): + async def table_row_count(self, table_name: str) -> int: + """Count the number of rows in the named table.""" with await self.execute_raw_query( text(f"SELECT COUNT(*) FROM {table_name}") ) as result: - return result.scalar_one() + return int(result.scalar_one()) async def execute_query(self, query_block: Mapping[str, Any]) -> Any: """Execute query in query_block.""" @@ -759,41 +799,48 @@ async def execute_query(self, query_block: Mapping[str, Any]) -> Any: return final_result -def fix_type(value): - if type(value) is decimal.Decimal: +def fix_type(value: Any) -> Any: + """Make this value suitable for yaml output.""" + if isinstance(value, decimal.Decimal): return float(value) return value -def fix_types(dics): - return [{ - k: fix_type(v) for k, v in dic.items() - } for dic in dics] +def fix_types(dics: list[dict]) -> list[dict]: + """Make all the items in this list suitable for yaml output.""" + return [{k: fix_type(v) for k, v in dic.items()} for dic in dics] async def make_src_stats( - dsn: str, config: Mapping, metadata: MetaData, schema_name: Optional[str] = None -) -> dict[str, list[dict]]: - """Run the src-stats queries specified by the configuration. + dsn: str, config: Mapping, schema_name: Optional[str] = None +) -> dict[str, dict[str, Any]]: + """ + Run the src-stats queries specified by the configuration. Query the src database with the queries in the src-stats block of the `config` dictionary, using the differential privacy parameters set in the `smartnoise-sql` block of `config`. Record the results in a dictionary and return it. - Args: - dsn: database connection string - config: a dictionary with the necessary configuration - metadata: the database ORM - schema_name: name of the database schema - Returns: - The dictionary of src-stats. 
+ :param dsn: database connection string + :param config: a dictionary with the necessary configuration + :param schema_name: name of the database schema + :return: The dictionary of src-stats. """ use_asyncio = config.get("use-asyncio", False) engine = create_db_engine(dsn, schema_name=schema_name, use_asyncio=use_asyncio) async with DbConnection(engine) as db_conn: - return await make_src_stats_connection(config, db_conn, metadata) + return await make_src_stats_connection(config, db_conn) -async def make_src_stats_connection(config: Mapping, db_conn: DbConnection, metadata: MetaData): + +async def make_src_stats_connection( + config: Mapping, db_conn: DbConnection +) -> dict[str, dict[str, Any]]: + """ + Make the ``src-stats.yaml`` file given the database connection to read from. + + :param config: configuration from ``config.yaml``. + :param db_conn: Source database connection. + """ date_string = datetime.today().strftime("%Y-%m-%d %H:%M:%S") query_blocks = config.get("src-stats", []) results = await asyncio.gather( diff --git a/datafaker/providers.py b/datafaker/providers.py index b07f2b7f..39a9d9ad 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -1,12 +1,18 @@ """This module contains Mimesis Provider sub-classes.""" import datetime as dt +import functools +import math import random -from typing import Any, Optional, Union, cast +from collections.abc import Mapping +from typing import Any, Callable, Generator, Optional, Union, cast +import numpy as np from mimesis import Datetime, Text from mimesis.providers.base import BaseDataProvider, BaseProvider -from sqlalchemy import Connection, Column -from sqlalchemy.sql import functions, select, func +from sqlalchemy import Column, Connection +from sqlalchemy.sql import func, functions, select + +from datafaker.utils import T, logger class ColumnValueProvider(BaseProvider): @@ -29,15 +35,16 @@ def column_value( return getattr(random_row, column_name) return None - def __init__(self, *, seed = None, **kwargs): + def __init__(self, *, seed: int | None = None, **kwargs: Any) -> None: + """Initialise the column value provider.""" super().__init__(seed=seed, **kwargs) self.accumulators: dict[str, int] = {} def increment(self, db_connection: Connection, column: Column) -> int: - """ Return incrementing value for the column specified. 
""" + """Return incrementing value for the column specified.""" name = f"{column.table.name}.{column.name}" result = self.accumulators.get(name, None) - if result == None: + if result is None: row = db_connection.execute(select(func.max(column))).first() result = 0 if row is None or row[0] is None else row[0] value = result + 1 @@ -214,3 +221,367 @@ class Meta: def null() -> None: """Return `None`.""" return None + + +class InappropriateGeneratorException(Exception): + """Exception thrown if a generator is requested that is not appropriate.""" + + +class NothingToGenerateException(Exception): + """Exception thrown when no value can be generated.""" + + def __init__(self, message: str): + """Initialise the exception with a human-readable message.""" + super().__init__(message) + + +@functools.cache +def zipf_weights(size: int) -> list[float]: + """Get the weights of a Zipf distribution of a given size.""" + total = sum(map(lambda n: 1 / n, range(1, size + 1))) + return [1 / (n * total) for n in range(1, size + 1)] + + +def merge_with_constants( + xs: list[T], constants_at: dict[int, T] +) -> Generator[T, None, None]: + """ + Merge a list of items with other items that must be placed at certain indices. + + :param constants_at: A map of indices to objects that must be placed at + those indices. + :param xs: Items that fill in the gaps left by ``constants_at``. + :return: ``xs`` with ``constants_at`` inserted at the appropriate + points. If there are not enough elements in ``xs`` to fill in the gaps + in ``constants_at``, the elements of ``constants_at`` after the gap + are dropped. + """ + outi = 0 + xi = 0 + constant_count = len(constants_at) + while constant_count != 0: + if outi in constants_at: + yield constants_at[outi] + constant_count -= 1 + else: + if xi == len(xs): + return + yield xs[xi] + xi += 1 + outi += 1 + yield from xs[xi:] + + +class DistributionProvider(BaseProvider): + """A Mimesis provider for various distributions.""" + + class Meta: + """Meta-class for various distributions.""" + + name = "distribution_provider" + + root3 = math.sqrt(3) + + def __init__(self, *, seed: int | None = None, **kwargs: Any) -> None: + """Initialise a DistributionProvider.""" + super().__init__(seed=seed, **kwargs) + np_seed = seed if isinstance(seed, int) else None + self.np_gen = np.random.default_rng(seed=np_seed) + + def uniform(self, low: float, high: float) -> float: + """ + Choose a value according to a uniform distribution. + + :param low: The lowest value that can be chosen. + :param high: The highest value that can be chosen. + :return: The output value. + """ + return random.uniform(float(low), float(high)) + + def uniform_ms(self, mean: float, sd: float) -> float: + """ + Choose a value according to a uniform distribution. + + :param mean: The mean of the output values. + :param sd: The standard deviation of the output values. + :return: The output value. + """ + m = float(mean) + h = self.root3 * float(sd) + return random.uniform(m - h, m + h) + + def normal(self, mean: float, sd: float) -> float: + """ + Choose a value according to a Gaussian (normal) distribution. + + :param mean: The mean of the output values. + :param sd: The standard deviation of the output values. + :return: The output value. + """ + return random.normalvariate(float(mean), float(sd)) + + def lognormal(self, logmean: float, logsd: float) -> float: + """ + Choose a value according to a lognormal distribution. + + :param logmean: The mean of the logs of the output values. 
+        :param logsd: The standard deviation of the logs of the output values.
+        :return: The output value.
+        """
+        return random.lognormvariate(float(logmean), float(logsd))
+
+    def choice_direct(self, a: list[T]) -> T:
+        """
+        Choose a value with equal probability.
+
+        :param a: The list of values to output.
+        :return: The chosen value.
+        """
+        return random.choice(a)
+
+    def choice(self, a: list[Mapping[str, T]]) -> T | None:
+        """
+        Choose a value with equal probability.
+
+        :param a: The list of values to output. Each element is a mapping
+            with a ``value`` key; the value stored under that key is what is
+            returned.
+        :return: The chosen value.
+        """
+        return self.choice_direct(a).get("value", None)
+
+    def zipf_choice_direct(self, a: list[T], n: int | None = None) -> T:
+        """
+        Choose a value according to the Zipf distribution.
+
+        The nth value (starting from 1) is chosen with a frequency
+        1/n times as frequently as the first value is chosen.
+
+        :param a: The list of values to output, most frequent first.
+        :param n: The number of weights to use; defaults to ``len(a)``.
+        :return: The chosen value.
+        """
+        if n is None:
+            n = len(a)
+        return random.choices(a, weights=zipf_weights(n))[0]
+
+    def zipf_choice(self, a: list[Mapping[str, T]], n: int | None = None) -> T | None:
+        """
+        Choose a value according to the Zipf distribution.
+
+        The nth value (starting from 1) is chosen with a frequency
+        1/n times as frequently as the first value is chosen.
+
+        :param a: The list of rows to choose between, most frequent first.
+            Each element is a mapping with a ``value`` key; the value stored
+            under that key is what is returned.
+        :return: The chosen value.
+        """
+        c = self.zipf_choice_direct(a, n)
+        return c.get("value", None)
+
+    def weighted_choice(self, a: list[dict[str, Any]]) -> Any:
+        """
+        Choice weighted by the count in the original dataset.
+
+        :param a: a list of dicts, each with a ``value`` key
+            holding the value to be returned and a ``count`` key holding the
+            number of that value found in the original dataset
+        :return: The chosen ``value``.
+        """
+        vs = []
+        counts = []
+        for vc in a:
+            count = vc.get("count", 0)
+            if count:
+                counts.append(count)
+                vs.append(vc.get("value", None))
+        c = random.choices(vs, weights=counts)[0]
+        return c
+
+    def constant(self, value: T) -> T:
+        """Return the same value always."""
+        return value
+
+    def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray:
+        """
+        Return an array of values drawn from the given means and covariances.
+
+        :param cov: Keys are ``rank``: The number of values to output;
+            ``mN``: The mean of variable ``N`` (where ``N`` is between 0 and
+            one less than ``rank``); ``cN_M`` (where 0 <= ``N`` <= ``M`` < ``rank``):
+            the covariance between the ``N``th and the ``M``th variables.
+        :return: A numpy array of results.
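+
+        Example of the expected ``cov`` format (illustrative values only)::
+
+            {"rank": 2, "m0": 0.0, "m1": 1.0,
+             "c0_0": 1.0, "c0_1": 0.2, "c1_1": 0.5}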
+ """ + rank = int(cov["rank"]) + if rank == 0: + return np.empty(shape=(0,)) + mean = [float(cov[f"m{i}"]) for i in range(rank)] + covs = [ + [ + float(cov[f"c{i}_{j}"] if i <= j else cov[f"c{j}_{i}"]) + for i in range(rank) + ] + for j in range(rank) + ] + return self.np_gen.multivariate_normal(mean, covs) + + def _select_group(self, alts: list[dict[str, Any]]) -> Any: + """Choose one of the ``alts`` weighted by their ``"count"`` elements.""" + total = 0 + for alt in alts: + if alt["count"] < 0: + logger.warning( + "Alternative count is %d, but should not be negative", alt["count"] + ) + else: + total += alt["count"] + if total == 0: + raise NothingToGenerateException("No counts in any alternative") + choice = random.randrange(total) + for alt in alts: + choice -= alt["count"] + if choice < 0: + return alt + raise NothingToGenerateException( + "Internal error: ran out of choices in _select_group" + ) + + def _find_constants(self, result: dict[str, Any]) -> dict[int, Any]: + """ + Find all keys ``kN``, returning a dictionary of ``N: kNN``. + + This can be passed into ``merge_with_constants`` as the + ``constants_at`` argument. + """ + out: dict[int, Any] = {} + for k, v in result.items(): + if k.startswith("k") and k[1:].isnumeric(): + out[int(k[1:])] = v + return out + + PERMITTED_SUBGENS = { + "multivariate_lognormal", + "multivariate_normal", + "grouped_multivariate_lognormal", + "grouped_multivariate_normal", + "constant", + "weighted_choice", + "with_constants_at", + } + + def multivariate_normal(self, cov: dict[str, Any]) -> list[float]: + """ + Produce a list of values pulled from a multivariate distribution. + + :param cov: A dict with various keys: ``rank`` is the number of + output values, ``m0``, ``m1``, ... are the means of the + distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ... + are the covariates, ``cN_M`` is the covariate of the ``N``th and + ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. + :return: list of ``rank`` floating point values + """ + out: list[float] = self.multivariate_normal_np(cov).tolist() + return out + + def multivariate_lognormal(self, cov: dict[str, Any]) -> list[float]: + """ + Produce a list of values pulled from a multivariate distribution. + + :param cov: A dict with various keys: ``rank`` is the number of + output values, ``m0``, ``m1``, ... are the means of the + distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ... + are the covariates, ``cN_M`` is the covariate of the ``N``th and + ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. These + are all the means and covariants of the logs of the data. 
+ :return: list of ``rank`` floating point values + """ + out: list[Any] = np.exp(self.multivariate_normal_np(cov)).tolist() + return out + + def grouped_multivariate_normal(self, covs: list[dict[str, Any]]) -> list[Any]: + """Produce a list of values pulled from a set of multivariate distributions.""" + cov = self._select_group(covs) + logger.debug("Multivariate normal group selected: %s", cov) + constants = self._find_constants(cov) + nums = self.multivariate_normal(cov) + return list(merge_with_constants(nums, constants)) + + def grouped_multivariate_lognormal(self, covs: list[dict[str, Any]]) -> list[Any]: + """Produce a list of values pulled from a set of multivariate distributions.""" + cov = self._select_group(covs) + logger.debug("Multivariate lognormal group selected: %s", cov) + constants = self._find_constants(cov) + nums = np.exp(self.multivariate_normal_np(cov)).tolist() + return list(merge_with_constants(nums, constants)) + + def _check_generator_name(self, name: str) -> None: + if name not in self.PERMITTED_SUBGENS: + raise InappropriateGeneratorException( + f"{name} is not a permitted generator" + ) + + def alternatives( + self, + alternative_configs: list[dict[str, Any]], + counts: list[dict[str, int]] | None, + ) -> Any: + """ + Pick between other generators. + + :param alternative_configs: List of alternative generators. + Each alternative has the following keys: "count" -- a weight for + how often to use this alternative; "name" -- which generator + for this partition, for example "composite"; "params" -- the + parameters for this alternative. + :param counts: A list of weights for each alternative. If None, the + "count" value of each alternative is used. Each count is a dict + with a "count" key. + :return: list of values + """ + if counts is not None: + while True: + count = self._select_group(counts) + alt = alternative_configs[count["index"]] + name = alt["name"] + self._check_generator_name(name) + try: + return getattr(self, name)(**alt["params"]) + except NothingToGenerateException: + # Prevent this alternative from being chosen again + count["count"] = 0 + alt = self._select_group(alternative_configs) + name = alt["name"] + self._check_generator_name(name) + return getattr(self, name)(**alt["params"]) + + def with_constants_at( + self, constants_at: dict[int, T], subgen: str, params: dict[str, T] + ) -> list[T]: + """ + Insert constants into the results of a different generator. + + :param constants_at: A dictionary of positions and objects to insert + into the return list at those positions. + :param subgen: The name of the function to call to get the results + that will have the constants inserted into. + :param params: Keyword arguments to the ``subgen`` function. + :return: A list of results from calling ``subgen(**params)`` + with ``constants_at`` inserted in at the appropriate indices. + """ + if subgen not in self.PERMITTED_SUBGENS: + logger.error( + "subgenerator %s is not a valid name. 
Valid names are %s.", + subgen, + self.PERMITTED_SUBGENS, + ) + subout = getattr(self, subgen)(**params) + logger.debug("Merging constants %s", constants_at) + return list(merge_with_constants(subout, constants_at)) + + def truncated_string( + self, subgen_fn: Callable[..., list[T]], params: dict, length: int + ) -> list[T]: + """Call ``subgen_fn(**params)`` and truncate the results to ``length``.""" + result = subgen_fn(**params) + if result is None: + return None + return result[:length] diff --git a/datafaker/remove.py b/datafaker/remove.py index c0a6c47f..3924cdaf 100644 --- a/datafaker/remove.py +++ b/datafaker/remove.py @@ -1,7 +1,7 @@ """Functions and classes to undo the operations in create.py.""" from typing import Any, Mapping -from sqlalchemy import delete, MetaData +from sqlalchemy import MetaData, delete from datafaker.settings import get_settings from datafaker.utils import ( @@ -9,23 +9,18 @@ get_sync_engine, get_vocabulary_table_names, logger, - remove_vocab_foreign_key_constraints, reinstate_vocab_foreign_key_constraints, + remove_vocab_foreign_key_constraints, sorted_non_vocabulary_tables, ) -def remove_db_data( - metadata: MetaData, config: Mapping[str, Any] -) -> None: +def remove_db_data(metadata: MetaData, config: Mapping[str, Any]) -> None: """Truncate the synthetic data tables but not the vocabularies.""" settings = get_settings() assert settings.dst_dsn, "Missing destination database settings" remove_db_data_from( - metadata, - config, - settings.dst_dsn, - schema_name=settings.dst_schema + metadata, config, settings.dst_dsn, schema_name=settings.dst_schema ) @@ -33,9 +28,7 @@ def remove_db_data_from( metadata: MetaData, config: Mapping[str, Any], db_dsn: str, schema_name: str | None ) -> None: """Truncate the synthetic data tables but not the vocabularies.""" - dst_engine = get_sync_engine( - create_db_engine(db_dsn, schema_name=schema_name) - ) + dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) with dst_engine.connect() as dst_conn: for table in reversed(sorted_non_vocabulary_tables(metadata, config)): @@ -44,7 +37,9 @@ def remove_db_data_from( dst_conn.commit() -def remove_db_vocab(metadata: MetaData, meta_dict: Mapping[str, Any], config: Mapping[str, Any]) -> None: +def remove_db_vocab( + metadata: MetaData, meta_dict: Mapping[str, Any], config: Mapping[str, Any] +) -> None: """Truncate the vocabulary tables.""" settings = get_settings() assert settings.dst_dsn, "Missing destination database settings" @@ -61,11 +56,14 @@ def remove_db_vocab(metadata: MetaData, meta_dict: Mapping[str, Any], config: Ma reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_conn) -def remove_db_tables(metadata: MetaData) -> None: +def remove_db_tables(metadata: MetaData | None) -> None: """Drop the tables in the destination schema.""" settings = get_settings() assert settings.dst_dsn, "Missing destination database settings" dst_engine = get_sync_engine( create_db_engine(settings.dst_dsn, schema_name=settings.dst_schema) ) + if metadata is None: + metadata = MetaData() + metadata.reflect(dst_engine) metadata.drop_all(dst_engine) diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 51f4b038..b7f5cf2d 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,53 +1,71 @@ +"""Convert between a Python dict describing a database schema and a SQLAlchemy MetaData.""" +import typing +from functools import partial + import parsy from sqlalchemy import Column, Dialect, Engine, 
ForeignKey, MetaData, Table from sqlalchemy.dialects import oracle, postgresql -from sqlalchemy.sql import sqltypes, schema -from typing import Callable +from sqlalchemy.sql import schema, sqltypes from datafaker.utils import make_foreign_key_name -table_component_t = dict[str, any] -table_t = dict[str, table_component_t] +TableT = dict[str, typing.Any] + + +# We will change this to parsy.Parser when parsy exports its types properly +ParserType = typing.Any + -def simple(type_): +def simple(type_: type) -> ParserType: """ - Parses a simple sqltypes type. + Get a parser for a simple sqltypes type. + For example, simple(sqltypes.UUID) takes the string "UUID" and outputs a UUID class, or fails with any other string. """ return parsy.string(type_.__name__).result(type_) -def integer(): - """ - Parses an integer, outputting that integer. - """ + +def integer() -> ParserType: + """Get a parser for an integer, outputting that integer.""" return parsy.regex(r"-?[0-9]+").map(int) -def integer_arguments(): + +def integer_arguments() -> ParserType: """ - Parses a list of integers. + Get a parser for a list of integers. + The integers are surrounded by brackets and separated by a comma and space. """ - return parsy.string("(") >> ( - integer().sep_by(parsy.string(", ")) - ) << parsy.string(")") + return ( + parsy.string("(") >> (integer().sep_by(parsy.string(", "))) << parsy.string(")") + ) -def numeric_type(type_): + +def numeric_type(type_: type) -> ParserType: """ + Make a parser for a SQL numeric type. + Parses TYPE_NAME, TYPE_NAME(2) or TYPE_NAME(2,3) passing any arguments to the TYPE_NAME constructor. """ - return parsy.string(type_.__name__ - ) >> integer_arguments().optional([]).combine(type_) + return parsy.string(type_.__name__) >> integer_arguments().optional([]).combine( + type_ + ) + + +def string_type(type_: type) -> ParserType: + """ + Make a parser for a SQL string type. + + Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME COLLATE "fr" + or TYPE_NAME(32) COLLATE "fr" + """ -def string_type(type_): @parsy.generate(type_.__name__) - def st_parser(): - """ - Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME COLLATE "fr" - or TYPE_NAME(32) COLLATE "fr" - """ + def st_parser() -> typing.Generator[ParserType, None, typing.Any]: + """Parse the specific type.""" yield parsy.string(type_.__name__) length: int | None = yield ( parsy.string("(") >> integer() << parsy.string(")") @@ -56,32 +74,48 @@ def st_parser(): parsy.string(' COLLATE "') >> parsy.regex(r'[^"]*') << parsy.string('"') ).optional() return type_(length=length, collation=collation) + return st_parser -def time_type(type_, pg_type): + +def time_type(type_: type, pg_type: type) -> ParserType: + """ + Make a parser for a SQL date/time type. + + Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME WITH TIME ZONE + or TYPE_NAME(32) WITH TIME ZONE + + :param type_: The SQLAlchemy type we would like to parse. + :param pg_type: The PostgreSQL type we would like to parse if precision + or timezone is provided. + :return: ``type_`` if neither precision nor timezone are provided in the + parsed text, ``pg_type(precision, timezone)`` otherwise. 
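+
+    For example (illustrative): ``"TIME"`` parses to plain ``sqltypes.TIME``,
+    while ``"TIME(6) WITH TIME ZONE"`` parses to
+    ``postgresql.types.TIME(precision=6, timezone=True)``.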
+ """ + @parsy.generate(type_.__name__) - def pgt_parser(): - """ - Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME WITH TIME ZONE - or TYPE_NAME(32) WITH TIME ZONE - """ + def pgt_parser() -> typing.Generator[ParserType, None, typing.Any]: + """Parse the actual type.""" yield parsy.string(type_.__name__) precision: int | None = yield ( parsy.string("(") >> integer() << parsy.string(")") ).optional() timezone: str | None = yield ( - parsy.string(" WITH") >> ( - parsy.string(" ").result(True) | parsy.string("OUT ").result(False) - ) << parsy.string("TIME ZONE") + parsy.string(" WITH") + >> (parsy.string(" ").result(True) | parsy.string("OUT ").result(False)) + << parsy.string("TIME ZONE") ).optional(False) if precision is None and not timezone: # normal sql type return type_ return pg_type(precision=precision, timezone=timezone) + return pgt_parser + SIMPLE_TYPE_PARSER = parsy.alt( - parsy.string("DOUBLE PRECISION").result(sqltypes.DOUBLE_PRECISION), # must be before DOUBLE + parsy.string("DOUBLE PRECISION").result( + sqltypes.DOUBLE_PRECISION + ), # must be before DOUBLE simple(sqltypes.FLOAT), simple(sqltypes.DOUBLE), simple(sqltypes.INTEGER), @@ -110,15 +144,28 @@ def pgt_parser(): time_type(sqltypes.TIME, postgresql.types.TIME), ) + @parsy.generate -def type_parser(): +def type_parser() -> ParserType: + """ + Make a parser for a simple type or an array. + + Arrays produce a PostgreSQL-specific type. + """ base = yield SIMPLE_TYPE_PARSER dimensions = yield parsy.string("[]").many().map(len) if dimensions == 0: return base return postgresql.ARRAY(base, dimensions=dimensions) -def column_to_dict(column: Column, dialect: Dialect) -> str: + +def column_to_dict(column: Column, dialect: Dialect) -> dict[str, typing.Any]: + """ + Produce a dict description of a column. + + :param column: The SQLAlchemy column to translate. + :param dialect: The SQL dialect in which to render the type name. + """ type_ = column.type if isinstance(type_, postgresql.DOMAIN): # Instead of creating a restricted type, we'll just use the base type. @@ -139,12 +186,28 @@ def column_to_dict(column: Column, dialect: Dialect) -> str: result["foreign_keys"] = foreign_keys return result + def dict_to_column( - table_name, - col_name, + table_name: str, + col_name: str, rep: dict, - ignore_fk: Callable[[str], bool], + ignore_fk: typing.Callable[[str], bool], ) -> Column: + """ + Produce column from aspects of its dict description. + + :param table_name: The name of the table the column appears in. + :param col_name: The name of the column. + :param rep: The dict description of the column. + :ignore_fk: A predicate, called with the name of any foreign key target + (in other words, the name of any table referred to by this column). If it + returns True, this foreign key constraint will not be applied to the + returned column. This is useful in a situation where we want a foreign + key constraint to be present when we are determining what generators + might be appropriate for it, but we don't want the foreign key constraint + actually applied to the destination database because (for example) the + target table will be ignored. 
+ """ type_sql = rep["type"] try: type_ = type_parser.parse(type_sql) @@ -156,7 +219,7 @@ def dict_to_column( ForeignKey( fk, name=make_foreign_key_name(table_name, col_name), - ondelete='CASCADE', + ondelete="CASCADE", ) for fk in rep["foreign_keys"] if not ignore_fk(fk) @@ -171,26 +234,22 @@ def dict_to_column( nullable=rep.get("nullable", None), ) + def dict_to_unique(rep: dict) -> schema.UniqueConstraint: - return schema.UniqueConstraint( - *rep.get("columns", []), - name=rep.get("name", None) - ) + """Make a uniqueness constraint from its dict representation.""" + return schema.UniqueConstraint(*rep.get("columns", []), name=rep.get("name", None)) + def unique_to_dict(constraint: schema.UniqueConstraint) -> dict: + """Render a dict representation of a uniqueness constraint.""" return { "name": constraint.name, - "columns": [ - str(col.name) - for col in constraint.columns - ] + "columns": [str(col.name) for col in constraint.columns], } -def table_to_dict(table: Table, dialect: Dialect) -> table_t: - """ - Converts a SQL Alchemy Table object into a - Python object ready for conversion to YAML. - """ + +def table_to_dict(table: Table, dialect: Dialect) -> TableT: + """Convert a SQL Alchemy Table object into a Python dict.""" return { "columns": { str(column.key): column_to_dict(column, dialect) @@ -203,27 +262,32 @@ def table_to_dict(table: Table, dialect: Dialect) -> table_t: ], } + def dict_to_table( name: str, meta: MetaData, - table_dict: table_t, - ignore_fk: Callable[[str], bool], + table_dict: TableT, + ignore_fk: typing.Callable[[str], bool], ) -> Table: + """Create a Table from its description.""" return Table( name, meta, - *[ dict_to_column(name, colname, col, ignore_fk) + *[ + dict_to_column(name, colname, col, ignore_fk) for (colname, col) in table_dict.get("columns", {}).items() ], - *[ dict_to_unique(constraint) - for constraint in table_dict.get("unique", []) - ], + *[dict_to_unique(constraint) for constraint in table_dict.get("unique", [])], ) -def metadata_to_dict(meta: MetaData, schema_name: str | None, engine: Engine) -> dict[str, table_t]: + +def metadata_to_dict( + meta: MetaData, schema_name: str | None, engine: Engine +) -> dict[str, typing.Any]: """ - Converts a SQL Alchemy MetaData object into - a Python object ready for conversion to YAML. + Convert a metadata object into a Python dict. + + The output will be ready for output to ``orm.yaml``. """ return { "tables": { @@ -235,25 +299,29 @@ def metadata_to_dict(meta: MetaData, schema_name: str | None, engine: Engine) -> } -def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]): +def should_ignore_fk(tables_dict: dict[str, TableT], fk: str) -> bool: """ - Tell if this foreign key should be ignored because it points to an - ignored table. + Test if this foreign key points to an ignored table. + + If so, this foreign key should be ignored. + :param tables_dict: The ``tables`` value from ``config.yaml``. + :param fk: The name of the foreign key. """ fk_bits = fk.split(".", 2) if len(fk_bits) != 2: return True if fk_bits[0] not in tables_dict: return False - return tables_dict[fk_bits[0]].get("ignore", False) + return bool(tables_dict[fk_bits[0]].get("ignore", False)) + + +def _always_false(_: str) -> bool: + return False -def dict_to_metadata( - obj: dict, - config_for_output: dict=None -) -> MetaData: +def dict_to_metadata(obj: dict, config_for_output: dict | None = None) -> MetaData: """ - Converts a dict to a SQL Alchemy MetaData object. + Convert a dict to a SQL Alchemy MetaData object. 
:param config_for_output: The configuration object. Should be None if the metadata object is being used for connecting to the source database. @@ -262,12 +330,13 @@ def dict_to_metadata( constraint to an ignored table. """ tables_dict = obj.get("tables", {}) + ignore_fk: typing.Callable[[str], bool] if config_for_output and "tables" in config_for_output: tables_config = config_for_output["tables"] - ignore_fk = lambda fk: should_ignore_fk(fk, tables_config) + ignore_fk = partial(should_ignore_fk, tables_config) else: - ignore_fk = lambda _: False + ignore_fk = _always_false meta = MetaData() - for (k, td) in tables_dict.items(): + for k, td in tables_dict.items(): dict_to_table(k, meta, td, ignore_fk) return meta diff --git a/datafaker/templates/df.py.j2 b/datafaker/templates/df.py.j2 index 28c95827..0e2616e8 100644 --- a/datafaker/templates/df.py.j2 +++ b/datafaker/templates/df.py.j2 @@ -3,13 +3,13 @@ from mimesis import Generic, Numeric, Person from mimesis.locales import Locale import sqlalchemy import sys -from datafaker.base import FileUploader, TableGenerator, DistributionGenerator, ColumnPresence -from datafaker.main import load_metadata +from datafaker.base import FileUploader, TableGenerator, ColumnPresence +from datafaker.providers import DistributionProvider generic = Generic(locale=Locale.EN_GB) numeric = Numeric() person = Person() -dist_gen = DistributionGenerator() +dist_gen = DistributionProvider() column_presence = ColumnPresence() sys.path.append("") @@ -23,8 +23,6 @@ from datafaker.providers import ( generic.add_provider({{ provider_import }}) {% endfor %} -metadata = load_metadata("{{ orm_file_name }}", "{{ config_file_name }}") - {% if row_generator_module_name is not none %} import {{ row_generator_module_name }} {% endif %} @@ -44,10 +42,6 @@ with open("{{ src_stats_filename }}", "r", encoding="utf-8") as f: SRC_STATS = yaml.unsafe_load(f) {% endif %} -{% for table_data in vocabulary_tables %} -{{ table_data.variable_name }} = FileUploader(metadata.tables["{{ table_data.table_name }}"]) -{% endfor %} - {% for table_data in tables %} class {{ table_data.class_name }}(TableGenerator): num_rows_per_pass = {{ table_data.rows_per_pass }} @@ -55,7 +49,7 @@ class {{ table_data.class_name }}(TableGenerator): def __init__(self): self.initialized = False - def __call__(self, dst_db_conn, get_random): + def __call__(self, dst_db_conn, metadata): if not self.initialized: {% for constraint in table_data.unique_constraints %} query_text = f"SELECT {% @@ -123,13 +117,6 @@ table_generator_dict = { {% endfor %} } - -vocab_dict = { -{% for table_data in vocabulary_tables %} - "{{ table_data.dictionary_entry }}": {{ table_data.variable_name }}, -{% endfor %} -} - {% for gen_data in story_generators %} def {{ gen_data.wrapper_name }}(dst_db_conn): return {{ gen_data.function_call.function_name }}({{ gen_data.function_call.argument_values| join(", ") }}) diff --git a/datafaker/utils.py b/datafaker/utils.py index 33b8c846..7ef91bff 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -1,27 +1,32 @@ """Utility functions.""" import ast +import gzip +import importlib.util +import io import json import logging import sys -import importlib.util +from collections.abc import Mapping, Sequence from pathlib import Path from types import ModuleType -from typing import Any, Final, Mapping, Optional, Union -import gzip +from typing import ( + Any, + Callable, + Final, + Generator, + Generic, + Iterable, + Optional, + TypeVar, + Union, +) +import psycopg2 +import sqlalchemy import yaml from 
jsonschema.exceptions import ValidationError from jsonschema.validators import validate -from psycopg2.errors import UndefinedObject -import sqlalchemy -from sqlalchemy import ( - Connection, - Engine, - ForeignKey, - create_engine, - event, - select, -) +from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select from sqlalchemy.engine.interfaces import DBAPIConnection from sqlalchemy.exc import IntegrityError, ProgrammingError from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine @@ -45,6 +50,18 @@ Path(__file__).parent / "json_schemas/config_schema.json" ) +T = TypeVar("T") + + +class Empty(Generic[T]): + """Generic empty sequences for default arguments.""" + + @classmethod + def iterable(cls) -> Iterable[T]: + """Get an empty iterable.""" + e: list[T] = [] + return (x for x in e) + def read_config_file(path: str) -> dict: """Read a config file, warning if it is invalid. @@ -81,25 +98,48 @@ def import_file(file_path: str) -> ModuleType: ModuleType """ spec = importlib.util.spec_from_file_location("df", file_path) + if spec is None or spec.loader is None: + raise ImportError(f"No loadable module at {file_path}") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module -def open_file(file_name): +def open_file(file_name: str | Path) -> io.BufferedWriter: + """Open a file for writing.""" return Path(file_name).open("wb") -def open_compressed_file(file_name): +def open_compressed_file(file_name: str | Path) -> gzip.GzipFile: + """ + Open a gzip-compressed file for writing. + + :param file_name: The name of the file to open. + :return: A file object; it can be written to as a normal uncompressed + file and it will do the compression. + """ return gzip.GzipFile(file_name, "wb") def table_row_count(table: Table, conn: Connection) -> int: + """ + Count the rows in the table. + + :param table: The table to count. + :param conn: The connection to the database. + :return: The number of rows in the table. + """ return conn.execute( - select(sqlalchemy.func.count()).select_from(sqlalchemy.table( - table.name, - *[sqlalchemy.column(col.name) for col in table.primary_key.columns.values()], - )) + # pylint: disable=not-callable + select(sqlalchemy.func.count()).select_from( + sqlalchemy.table( + table.name, + *[ + sqlalchemy.column(col.name) + for col in table.primary_key.columns.values() + ], + ) + ) ).scalar_one() @@ -117,10 +157,7 @@ def download_table( rowcount = table_row_count(table, conn) count = 0 for row in conn.execute(stmt).mappings(): - result = { - str(col_name): value - for (col_name, value) in row.items() - } + result = {str(col_name): value for (col_name, value) in row.items()} yamlfile.write(yaml.dump([result]).encode()) count += 1 if count % MAKE_VOCAB_PROGRESS_REPORT_EVERY == 0: @@ -128,7 +165,7 @@ def download_table( "written row %d of %d, %.1f%%", count, rowcount, - 100*count/rowcount, + 100 * count / rowcount, ) @@ -211,46 +248,54 @@ def warning_or_higher(record: logging.LogRecord) -> bool: class StdoutHandler(logging.Handler): """ A handler that writes to stdout. 
+ We aren't using StreamHandler because that confuses typer.testing.CliRunner """ - def flush(self): + + def flush(self) -> None: + """Flush the buffer.""" self.acquire() try: sys.stdout.flush() finally: self.release() - def emit(self, record): + def emit(self, record: Any) -> None: + """Write the record.""" try: msg = self.format(record) sys.stdout.write(msg + "\n") sys.stdout.flush() except RecursionError: raise - except Exception: + except Exception: # pylint: disable=broad-exception-caught self.handleError(record) class StderrHandler(logging.Handler): """ A handler that writes to stderr. + We aren't using StreamHandler because that confuses typer.testing.CliRunner """ - def flush(self): + + def flush(self) -> None: + """Flush the buffer.""" self.acquire() try: sys.stderr.flush() finally: self.release() - def emit(self, record): + def emit(self, record: Any) -> None: + """Write the record.""" try: msg = self.format(record) sys.stderr.write(msg + "\n") sys.stderr.flush() except RecursionError: raise - except Exception: + except Exception: # pylint: disable=broad-exception-caught self.handleError(record) @@ -276,23 +321,41 @@ def conf_logger(verbose: bool) -> None: handlers=[stdout_handler, stderr_handler], force=True, ) - logging.getLogger('asyncio').setLevel(logging.WARNING) - logging.getLogger('blib2to3.pgen2.driver').setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.WARNING) -def get_flag(maybe_dict, key): - """Returns maybe_dict[key] or False if that doesn't exist""" - return type(maybe_dict) is dict and maybe_dict.get(key, False) +def get_flag(maybe_dict: Any, key: Any) -> bool: + """ + Get a boolean from a mapping, or False if that does not make sense. + :param maybe_dict: A mapping, or possibly not. + :param key: A key in ``maybe_dict``, or possibly not. + :return: True only if ``maybe_dict`` is a mapping, ``maybe_dict[key]`` + exists and ``maybe_dict[key]`` is truthy. + """ + return isinstance(maybe_dict, Mapping) and maybe_dict.get(key, False) + + +def get_property(maybe_dict: Any, key: Any, default: T) -> T: + """ + Get a specific property from a dict or a default if that does not exist. -def get_property(maybe_dict, key, default): - """Returns maybe_dict[key] or default if that doesn't exist""" - return maybe_dict.get(key, default) if type(maybe_dict) is dict else default + :param maybe_dict: A mapping, or possibly not. + :param key: A key in ``maybe_dict``, or possibly not. + :param default: The return value if ``maybe_dict`` is not a mapping, + or if ``key`` is not a key of ``maybe_dict``. + :return: ``maybe_dict[key]`` if this makes sense, or ``default`` if not. + """ + return maybe_dict.get(key, default) if isinstance(maybe_dict, Mapping) else default -def fk_refers_to_ignored_table(fk: ForeignKey): +def fk_refers_to_ignored_table(fk: ForeignKey) -> bool: """ - Does this foreign key refer to a table that is configured as ignore in config.yaml + Test if this foreign key refers to an ignored table. + + :param fk: The foreign key to test. + :return: True if the table referred to is ignored in ``config.yaml``. 
""" try: fk.column @@ -301,9 +364,12 @@ def fk_refers_to_ignored_table(fk: ForeignKey): return False -def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint): +def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint) -> bool: """ - Does this foreign key constraint refer to a table that is configured as ignore in config.yaml + Test if the constraint refers to a table marked as ignored in ``config.yaml``. + + :param fk: The foreign key constraint. + :return: True if ``fk`` refers to an ignored table. """ try: fk.referred_table @@ -315,6 +381,10 @@ def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint): def get_related_table_names(table: Table) -> set[str]: """ Get the names of all tables for which there exist foreign keys from this table. + + :param table: SQLAlchemy table object. + :return: The set of the names of the tables referred to by foreign keys + in ``table``. """ return { str(fk.referred_table.name) @@ -325,22 +395,30 @@ def get_related_table_names(table: Table) -> set[str]: def table_is_private(config: Mapping, table_name: str) -> bool: """ - Return True if the table with name table_name is a primary private table - according to config. + Test if the named table is private. + + :param config: The ``config.yaml`` object. + :param table_name: The name of the table to test. + :return: True if the table is marked as private in ``config``. """ ts = config.get("tables", {}) - if type(ts) is not dict: + if not isinstance(ts, Mapping): return False t = ts.get(table_name, {}) - return t.get("primary_private", False) + ret = t.get("primary_private", False) + return ret if isinstance(ret, bool) else False def primary_private_fks(config: Mapping, table: Table) -> list[str]: """ - Returns the list of columns in the table that refer to primary private tables. + Get the list of columns in the table that refer to primary private tables. A table that is not primary private but has a non-empty list of primary_private_fks is secondary private. + + :param config: The ``config.yaml`` object. + :param table: The table to examine. + :return: A list of names of columns that refer to private tables. """ return [ str(fk.referred_table.name) @@ -351,9 +429,7 @@ def primary_private_fks(config: Mapping, table: Table) -> list[str]: def get_vocabulary_table_names(config: Mapping) -> set[str]: - """ - Extract the table names with a vocabulary_table: true property. - """ + """Extract the table names with a vocabulary_table: true property.""" return { table_name for (table_name, table_config) in config.get("tables", {}).items() @@ -361,16 +437,71 @@ def get_vocabulary_table_names(config: Mapping) -> set[str]: } +def get_columns_assigned( + row_generator_config: Mapping[str, Any] +) -> Generator[str, None, None]: + """ + Get the columns assigned in a ``row_generators[n]`` stanza. + + :param generator_config: The ``row_generators[n]`` stanza itself. + """ + ca = row_generator_config.get("columns_assigned", None) + if ca is None: + return + if isinstance(ca, str): + yield ca + return + if not hasattr(ca, "__iter__"): + return + for c in ca: + yield str(c) + + +def get_row_generators( + table_config: Mapping[str, Any], +) -> Generator[tuple[str, Mapping[str, Any]], None, None]: + """ + Get the row generators from a table configuration. + + :param table_config: The element from the ``tables:`` stanza of ``config.xml``. + :return: Pair of (name, row generator config). 
+ """ + rgs = table_config.get("row_generators", None) + if isinstance(rgs, str) or not hasattr(rgs, "__iter__"): + return + for rg in rgs: + name = rg.get("name", None) + if name: + yield (name, rg) + + def make_foreign_key_name(table_name: str, col_name: str) -> str: + """Make a suitable foreign key name.""" return f"{table_name}_{col_name}_fkey" -def remove_vocab_foreign_key_constraints(metadata, config, dst_engine): +def remove_vocab_foreign_key_constraints( + metadata: MetaData, + config: Mapping[str, Any], + dst_engine: Connection | Engine, +) -> None: + """ + Remove the foreign key constraints from vocabulary tables. + + This allows vocabulary tables to be loaded without worrying about + topologically sorting them or circular dependencies. + + :param metadata: The SQLAlchemy metadata from ``orm.yaml``. + :param config: The ``config.yaml`` object. + :param dst_engine: The destination database or a connection to it. + """ vocab_tables = get_vocabulary_table_names(config) for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] for fk in vocab_table.foreign_key_constraints: - logger.debug("Dropping constraint %s from table %s", fk.name, vocab_table_name) + logger.debug( + "Dropping constraint %s from table %s", fk.name, vocab_table_name + ) with Session(dst_engine) as session: session.begin() try: @@ -378,21 +509,42 @@ def remove_vocab_foreign_key_constraints(metadata, config, dst_engine): session.commit() except IntegrityError: session.rollback() - logger.exception("Dropping table %s key constraint %s failed:", vocab_table_name, fk.name) + logger.exception( + "Dropping table %s key constraint %s failed:", + vocab_table_name, + fk.name, + ) except ProgrammingError as e: session.rollback() - if type(e.orig) is UndefinedObject: + # pylint: disable=no-member + if isinstance(e.orig, psycopg2.errors.UndefinedObject): logger.debug("Constraint does not exist") else: raise e -def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_engine): +def reinstate_vocab_foreign_key_constraints( + metadata: MetaData, + meta_dict: Mapping[str, Any], + config: Mapping[str, Any], + dst_engine: Connection | Engine, +) -> None: + """ + Put the removed foreign keys back into the destination database. + + :param metadata: The SQLAlchemy metadata for the destination database. + :param meta_dict: The ``orm.yaml`` configuration that ``metadata`` was + created from. + :param config: The ``config.yaml`` data. + :param dst_engine: The connection to the destination database. 
+ """ vocab_tables = get_vocabulary_table_names(config) for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] try: - for (column_name, column_dict) in meta_dict["tables"][vocab_table_name]["columns"].items(): + for column_name, column_dict in meta_dict["tables"][vocab_table_name][ + "columns" + ].items(): fk_targets = column_dict.get("foreign_keys", []) if fk_targets: fk = ForeignKeyConstraint( @@ -400,17 +552,19 @@ def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_eng name=make_foreign_key_name(vocab_table_name, column_name), refcolumns=fk_targets, ) - logger.debug(f"Restoring foreign key constraint {fk.name}") + logger.debug("Restoring foreign key constraint %s", fk.name) with Session(dst_engine) as session: session.begin() vocab_table.append_constraint(fk) session.execute(AddConstraint(fk)) session.commit() except IntegrityError: - logger.exception("Restoring table %s foreign keys failed:", vocab_table_name) + logger.exception( + "Restoring table %s foreign keys failed:", vocab_table_name + ) -def stream_yaml(yaml_file_handle): +def stream_yaml(yaml_file_handle: io.TextIOBase) -> Generator[Any, None, None]: """ Stream a yaml list into an iterator. @@ -424,7 +578,7 @@ def stream_yaml(yaml_file_handle): if not line or line.startswith("-"): if buf: yl = yaml.load(buf, yaml.Loader) - assert type(yl) is list and len(yl) == 1 + assert isinstance(yl, Sequence) and len(yl) == 1 yield yl[0] if not line: return @@ -432,23 +586,24 @@ def stream_yaml(yaml_file_handle): buf += line -def topological_sort(input_nodes, get_dependencies_fn): +def topological_sort( + input_nodes: Iterable[T], get_dependencies_fn: Callable[[T], set[T]] +) -> tuple[list[T], list[list[T]]]: """ Topoligically sort input_nodes and find any cycles. - Returns a pair (sorted, cycles). - - 'sorted' is a list of all the elements of input_nodes sorted + Returns a pair ``(sorted, cycles)``. + + ``sorted`` is a list of all the elements of input_nodes sorted so that dependencies returned by get_dependencies_fn come after nodes that depend on them. Cycles are arbitrarily broken for this. - 'cycles' is a list of lists of dependency cycles. + ``cycles`` is a list of lists of dependency cycles. - arguments: - input_nodes: an iterator of nodes to sort. Duplicates + :param input_nodes: an iterator of nodes to sort. Duplicates are discarded. - get_dependencies_fn: a function that takes an input + :param get_dependencies_fn: a function that takes an input node and returns a list of its dependencies. Any dependencies not in the input_nodes list are ignored. """ @@ -478,27 +633,53 @@ def topological_sort(input_nodes, get_dependencies_fn): elif n in grey: # n is in a cycle cycle_start = grey.index(n) - cycles.append(grey[cycle_start:len(grey)]) + cycles.append(grey[cycle_start : len(grey)]) return (black, cycles) def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Table]: - table_names = set( - metadata.tables.keys() - ).difference( + """ + Get the list of non-vocabulary tables, topologically sorted. + + :param metadata: SQLAlchemy database description. + :param config: The ``config.yaml`` object. + :return: The list of non-vocabulary tables, ordered such that the targets + of all the foreign keys come before their sources. 
+ """ + table_names = set(metadata.tables.keys()).difference( get_vocabulary_table_names(config) ) - (sorted, cycles) = topological_sort( - table_names, - lambda tn: get_related_table_names(metadata.tables[tn]) + (sorted_tables, cycles) = topological_sort( + table_names, lambda tn: get_related_table_names(metadata.tables[tn]) ) for cycle in cycles: - logger.warning(f"Cycle detected between tables: {cycle}") - return [ metadata.tables[tn] for tn in sorted ] + logger.warning("Cycle detected between tables: %s", cycle) + return [metadata.tables[tn] for tn in sorted_tables] + + +def underline_error(e: SyntaxError) -> str: + r""" + Make an underline for this error. + + :return: string beginning ``\n`` then spaces then ``^^^^`` + underlining the error, or a null string if this was not possible. + """ + start = e.offset + if start is None: + return "" + end = e.end_offset + if end is None or end <= start: + end = start + 1 + return "\n" + " " * start + "^" * (end - start) + def generators_require_stats(config: Mapping) -> bool: """ - Returns true if any of the arguments for any of the generators reference SRC_STATS. + Test if the generator references ``SRC_STATS``. + + :param config: ``config.yaml`` object. + :return: True if any of the arguments for any of the generators + reference ``SRC_STATS``. """ ois = { f"object_instantiation.{k}": call @@ -522,38 +703,42 @@ def generators_require_stats(config: Mapping) -> bool: names = ( node.id for node in ast.walk(ast.parse(arg)) - if type(node) is ast.Name + if isinstance(node, ast.Name) ) if any(name == "SRC_STATS" for name in names): stats_required = True except SyntaxError as e: - errors.append(( - "Syntax error in argument %d of %s: %s\n%s\n%s", - n + 1, - where, - e.msg, - arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), - )) + errors.append( + ( + "Syntax error in argument %d of %s: %s\n%s%s", + n + 1, + where, + e.msg, + arg, + underline_error(e), + ) + ) for k, arg in call.get("kwargs", {}).items(): - if type(arg) is str: + if isinstance(arg, str): try: names = ( node.id for node in ast.walk(ast.parse(arg)) - if type(node) is ast.Name + if isinstance(node, ast.Name) ) if any(name == "SRC_STATS" for name in names): stats_required = True except SyntaxError as e: - errors.append(( - "Syntax error in argument %s of %s: %s\n%s\n%s", - k, - where, - e.msg, - arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), - )) + errors.append( + ( + "Syntax error in argument %s of %s: %s\n%s%s", + k, + where, + e.msg, + arg, + underline_error(e), + ) + ) for error in errors: logger.error(*error) return stats_required diff --git a/docs/source/_static/config_schema.html b/docs/source/_static/config_schema.html index ca0baa07..e78949f0 100644 --- a/docs/source/_static/config_schema.html +++ b/docs/source/_static/config_schema.html @@ -1 +1,2759 @@ - Datafaker Config

diff --git a/docs/source/_static/config_schema.html b/docs/source/_static/config_schema.html
index ca0baa07..e78949f0 100644
--- a/docs/source/_static/config_schema.html
+++ b/docs/source/_static/config_schema.html
@@ -1 +1,2759 @@
[Generated file: the single-line minified HTML rendering of the config JSON schema is replaced by a 2,759-line regenerated page; the extracted diff text was unreadable markup residue and is omitted here. Besides the options already documented (source-statistics queries with a name, SQL query, SmartNoise SQL query, differential privacy epsilon and delta, and SmartNoise metadata; story generators with name, args, kwargs and num_stories_per_pass; per-table ignore, vocabulary_table, num_rows_per_pass and row generators), the regenerated page documents the additions in this change: an object for "Objects that need to be instantiated from the row and story generators modules"; a "comments" array of strings on each src-stats query ("Comments to be copied into the src-stats.yaml file describing the query results"); a "primary_private" boolean on table configurations ("Whether the table is a Primary Private table (perhaps a table of patients)"); and a per-table array of missingness generators ("Function to generate a set of nullable columns that should not be null"), each entry naming a built-in or custom function (e.g. column_presence.sampled) with keyword arguments and the "Column names that might be returned".]
diff --git a/docs/source/custom_generators.rst b/docs/source/custom_generators.rst
index c29e340c..73d04b64 100644
--- a/docs/source/custom_generators.rst
+++ b/docs/source/custom_generators.rst
@@ -60,4 +60,4 @@
 Again, you must define your own; ``datafaker`` provides no built-in story generators.

 You can put your story generators in their own Python file, or you can re-use your row generators file if you like.

-A story generator is a Python Generator function (a function that calls ``yield`` to return multiple values rather than ``return`` a single one).
\ No newline at end of file
+A story generator is a Python Generator function (a function that calls ``yield`` to return multiple values rather than ``return`` a single one).
diff --git a/docs/source/docker.rst b/docs/source/docker.rst
index 8954fc6e..62fcfe0a 100644
--- a/docs/source/docker.rst
+++ b/docs/source/docker.rst
@@ -32,7 +32,7 @@
 computer.

 Running the image in this way will give you a command prompt from
 which datafaker can be called. Tab completion can be used. For example, if
-you type ``sq mat`` you will see
+you type ``dataf mat`` you will see
 ``datafaker make-tables``; although you might have to wait a second or
 two after some of the ```` key presses for the completed text to
 appear. Tab completion can also be used for command options such
diff --git a/docs/source/introduction.rst b/docs/source/introduction.rst
index c588c0bc..8cf833a8 100644
--- a/docs/source/introduction.rst
+++ b/docs/source/introduction.rst
@@ -124,7 +124,7 @@ Some of these functions take arguments, that we can assign like this:
 Anyway, we now need to remake the generators (``create-generators``) and re-run them (``create-data``):

 .. code-block:: console
-   
+
    $ datafaker create-generators --force
    $ datafaker create-data --num-passes 15
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index 48e52a03..43722aa6 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -110,7 +110,7 @@ This command will start an interactive command shell. Don't be intimidated, just
     columns  data   help    next      peek      private  select  vocabulary
     counts   empty  ignore  generate  previous  quit     tables
-    (table: myfirsttable) 
+    (table: myfirsttable)

 You can also get help for any of the commands listed; for example to see help for the ``vocabulary`` command type ``? vocabulary`` or ``help vocabulary``:
@@ -134,7 +134,7 @@ Press the Tab key again to see these options:

 .. code-block:: console

     (table: actor) help p
-    peek  previous  private 
+    peek  previous  private
     (table: actor) help p

 Now you can continue with r-i-tab to get ``private``, r-e-tab to get ``previous`` or e-tab to get ``peek``. This can be very useful; try pressing Tab twice on an empty line to see quickly all the possible commands, for example!
@@ -372,7 +372,7 @@ To describe "null-partitioned grouped", let us make the generator much more comp
 | None     | None          | None          | Pencil on tracing paper                                                                          |
 | None     | 18.5          | 24.3          | Lithograph from an illustrated book of poems and four lithographs                                |
 +----------+---------------+---------------+------------------------------------------------------------------------------------------------+
-    (artwork.depth_cm,width_cm,height_cm,medium) 
+    (artwork.depth_cm,width_cm,height_cm,medium)

 Here we can see that Moma understandably does not record depths for 2D artworks so we have many NULLs in that column. If we try to apply the standard normal or lognormal to data with many NULLs, it will ignore those rows with any NULLs.
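The regenerated config_schema.html above is the rendered form of the JSON schema that read_config_file validates against; a hedged sketch of checking a config by hand with jsonschema, using the schema file that datafaker/utils.py points at (the config content here is invented):

    import json
    import yaml
    from jsonschema.validators import validate

    config = yaml.safe_load(
        """
        tables:
          concept:
            vocabulary_table: true
          person:
            primary_private: true
            num_rows_per_pass: 2
        """
    )
    with open("datafaker/json_schemas/config_schema.json", encoding="utf-8") as f:
        validate(config, json.load(f))  # raises ValidationError if the config is bad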
diff --git a/mypy.ini b/mypy.ini index 86ff2fb3..c2ea784f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,7 +1,7 @@ # Global options: [mypy] -python_version = 3.9 +python_version = 3.10 disallow_untyped_defs = True disallow_any_unimported = True no_implicit_optional = True diff --git a/tests/test_base.py b/tests/test_base.py index 3f1e8cd4..411f1c09 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -46,7 +46,7 @@ def test_load(self) -> None: """Test the load method.""" vocab_gen = FileUploader(BaseTable.__table__) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: vocab_gen.load(conn) statement = select(BaseTable) rows = list(conn.execute(statement)) diff --git a/tests/test_create.py b/tests/test_create.py index 0fe1bf3e..933fbe26 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,33 +1,36 @@ """Tests for the create module.""" import itertools as itt -from collections import Counter import os -from pathlib import Path import random -from typing import Any, Generator, Tuple +from collections import Counter +from pathlib import Path +from typing import Any, Generator, Mapping, Tuple from unittest.mock import MagicMock, call, patch from sqlalchemy import Connection, select -from sqlalchemy.schema import Table +from sqlalchemy.schema import MetaData, Table from datafaker.base import TableGenerator -from datafaker.create import ( - create_db_vocab, - populate, -) +from datafaker.create import create_db_vocab, populate from datafaker.remove import remove_db_vocab from datafaker.serialize_metadata import metadata_to_dict from tests.utils import DatafakerTestCase, GeneratesDBTestCase + class TestCreate(GeneratesDBTestCase): """Test the make_table_generators function.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" def test_create_vocab(self) -> None: """Test the create_db_vocab function.""" - with patch.dict(os.environ, {"DST_DSN": self.dsn, "DST_SCHEMA": self.schema_name}, clear=True): + with patch.dict( + os.environ, + {"DST_DSN": self.dsn, "DST_SCHEMA": self.schema_name}, + clear=True, + ): config = { "tables": { "player": { @@ -36,11 +39,13 @@ def test_create_vocab(self) -> None: }, } self.set_configuration(config) - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) + meta_dict = metadata_to_dict( + self.metadata, self.schema_name, self.sync_engine + ) self.remove_data(config) remove_db_vocab(self.metadata, meta_dict, config) create_db_vocab(self.metadata, meta_dict, config, Path("./tests/examples")) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables["player"]) rows = list(conn.execute(stmt).mappings().fetchall()) self.assertEqual(len(rows), 3) @@ -57,9 +62,9 @@ def test_create_vocab(self) -> None: def test_make_table_generators(self) -> None: """Test that we can handle column defaults in stories.""" random.seed(56) - config = {} + config: Mapping[str, Any] = {} self.generate_data(config, num_passes=2) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables["string"]) rows = list(conn.execute(stmt).mappings().fetchall()) a = rows[0] @@ -83,8 +88,9 @@ def test_make_table_generators(self) -> None: class TestPopulate(DatafakerTestCase): - """ Test create.populate. 
""" + """Test create.populate.""" + # pylint: disable=too-many-locals def test_populate(self) -> None: """Test the populate function.""" table_name = "table_name" @@ -106,6 +112,7 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]: mock_dst_conn.execute.return_value.returned_defaults = {} mock_table = MagicMock(spec=Table) mock_table.name = table_name + mock_metadata = MagicMock(spec=MetaData) mock_gen = MagicMock(spec=TableGenerator) mock_gen.num_rows_per_pass = num_rows_per_pass mock_gen.return_value = {} @@ -129,6 +136,7 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]: [mock_table], {table_name: mock_gen}, story_generators, + mock_metadata, ) expected_row_count = ( @@ -160,13 +168,14 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None: mock_table_two.name = "two" mock_table_three = MagicMock(spec=Table) mock_table_three.name = "three" + mock_metadata = MagicMock(spec=MetaData) tables: list[Table] = [mock_table_one, mock_table_two, mock_table_three] row_generators: dict[str, TableGenerator] = { "two": mock_gen_two, "three": mock_gen_three, } - row_counts = populate(mock_dst_conn, tables, row_generators, []) + row_counts = populate(mock_dst_conn, tables, row_generators, [], mock_metadata) self.assertEqual(row_counts, {"two": 1, "three": 1}) self.assertListEqual( [call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list diff --git a/tests/test_dump.py b/tests/test_dump.py index 4293f285..2340a6cc 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -1,27 +1,33 @@ """Tests for the base module.""" -from sqlalchemy.schema import MetaData -from tests.utils import RequiresDBTestCase +import io from unittest.mock import MagicMock, call, patch +from sqlalchemy.schema import MetaData + from datafaker.dump import dump_db_tables +from tests.utils import RequiresDBTestCase + class DumpTests(RequiresDBTestCase): """Testing configure-tables.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @patch("datafaker.dump._make_csv_writer") def test_dump_data(self, make_csv_writer: MagicMock) -> None: - """ Test dump-data. 
""" - TEST_OUTPUT_FILE = "test_output_file_object" + """Test dump-data.""" + test_output_file = io.StringIO() metadata = MetaData() - metadata.reflect(self.engine) - dump_db_tables(metadata, self.dsn, self.schema_name, "player", TEST_OUTPUT_FILE) - make_csv_writer.assert_called_once_with(TEST_OUTPUT_FILE) - make_csv_writer.assert_has_calls([ - call().writerow(["id", "given_name", "family_name"]), - call().writerow((1, 'Mark', 'Samson')), - call().writerow((2, 'Tim', 'Friedman')), - call().writerow((3, 'Pierre', 'Marchmont')), - ]) + metadata.reflect(self.sync_engine) + dump_db_tables(metadata, self.dsn, self.schema_name, "player", test_output_file) + make_csv_writer.assert_called_once_with(test_output_file) + make_csv_writer.assert_has_calls( + [ + call().writerow(["id", "given_name", "family_name"]), + call().writerow((1, "Mark", "Samson")), + call().writerow((2, "Tim", "Friedman")), + call().writerow((3, "Pierre", "Marchmont")), + ] + ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 00f45478..bfb2f096 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,24 +1,26 @@ """Tests for the CLI.""" import os import shutil +import tempfile from pathlib import Path +from typing import Any, Mapping from sqlalchemy import create_engine, inspect -from typer.testing import CliRunner - -from tests.utils import RequiresDBTestCase +from typer.testing import CliRunner, Result from datafaker.main import app +from tests.utils import RequiresDBTestCase # pylint: disable=subprocess-run-check + class DBFunctionalTestCase(RequiresDBTestCase): """End-to-end tests that require a database.""" + dump_file_path = "src.dump" database_name = "src" schema_name = "public" - test_dir = Path("tests/workspace") examples_dir = Path("tests/examples") orm_file_path = Path("orm.yaml") @@ -27,13 +29,10 @@ class DBFunctionalTestCase(RequiresDBTestCase): alt_orm_file_path = Path("my_orm.yaml") alt_datafaker_file_path = Path("my_df.py") - vocabulary_file_paths = tuple( - map(Path, ("concept.yaml", "concept_type.yaml", "mitigation_type.yaml")), - ) generator_file_paths = tuple( map(Path, ("story_generators.py", "row_generators.py")), ) - #dump_file_path = Path("dst.dump") + # dump_file_path = Path("dst.dump") config_file_path = Path("example_config2.yaml") stats_file_path = Path("example_stats.yaml") @@ -65,6 +64,7 @@ def setUp(self) -> None: ) # Copy some of the example files over to the workspace. 
+ self.test_dir = Path(tempfile.mkdtemp(prefix="df-")) for file in self.generator_file_paths + (self.config_file_path,): src = self.examples_dir / file dst = self.test_dir / file @@ -79,6 +79,13 @@ def tearDown(self) -> None: os.chdir(self.start_dir) super().tearDown() + def assert_silent_success(self, completed_process: Result) -> None: + """Assert that the process completed successfully without producing output.""" + self.assertNoException(completed_process) + self.assertSuccess(completed_process) + self.assertEqual(completed_process.stderr, "") + self.assertEqual(completed_process.stdout, "") + def test_workflow_minimal_args(self) -> None: """Test the recommended CLI workflow runs without errors.""" shutil.copy(self.config_file_path, "config.yaml") @@ -86,26 +93,19 @@ def test_workflow_minimal_args(self) -> None: "make-tables", "--force", ) - self.assertNoException(completed_process) - self.assertSuccess(completed_process) - self.assertEqual(completed_process.stderr, "") - self.assertEqual(completed_process.stdout, "") + self.assert_silent_success(completed_process) completed_process = self.invoke( "make-vocab", "--force", ) - self.assertNoException(completed_process) - self.assertSuccess(completed_process) - self.assertEqual(completed_process.stdout, "") + self.assert_silent_success(completed_process) completed_process = self.invoke( "make-stats", "--force", ) - self.assertNoException(completed_process) - self.assertSuccess(completed_process) - self.assertEqual(completed_process.stdout, "") + self.assert_silent_success(completed_process) completed_process = self.invoke( "create-generators", @@ -115,8 +115,18 @@ def test_workflow_minimal_args(self) -> None: self.assertNoException(completed_process) self.assertEqual( { - "Unsupported SQLAlchemy type CIDR for column column_with_unusual_type. Setting this column to NULL always, you may want to configure a row generator for it instead.", - "Unsupported SQLAlchemy type BIT for column column_with_unusual_type_and_length. Setting this column to NULL always, you may want to configure a row generator for it instead.", + ( + "Unsupported SQLAlchemy type CIDR for column " + "column_with_unusual_type. Setting this column to NULL " + "always, you may want to configure a row generator for " + "it instead." + ), + ( + "Unsupported SQLAlchemy type BIT for column " + "column_with_unusual_type_and_length. Setting this column " + "to NULL always, you may want to configure a row generator " + "for it instead." 
+ ), }, set(completed_process.stderr.split("\n")) - {""}, ) @@ -126,27 +136,18 @@ def test_workflow_minimal_args(self) -> None: completed_process = self.invoke( "create-tables", ) - self.assertNoException(completed_process) - self.assertEqual("", completed_process.stderr) - self.assertSuccess(completed_process) - self.assertEqual("", completed_process.stdout) + self.assert_silent_success(completed_process) completed_process = self.invoke( "create-vocab", ) - self.assertNoException(completed_process) - self.assertEqual("", completed_process.stderr) - self.assertSuccess(completed_process) - self.assertEqual("", completed_process.stdout) + self.assert_silent_success(completed_process) completed_process = self.invoke( "make-stats", "--force", ) - self.assertNoException(completed_process) - self.assertEqual("", completed_process.stderr) - self.assertSuccess(completed_process) - self.assertEqual("", completed_process.stdout) + self.assert_silent_success(completed_process) completed_process = self.invoke("create-data") self.assertNoException(completed_process) @@ -307,7 +308,10 @@ def test_workflow_maximal_args(self) -> None: self.assertSetEqual( { "Dropping constraint concept_concept_type_id_fkey from table concept", - "Dropping constraint ref_to_unignorable_table_ref_fkey from table ref_to_unignorable_table", + ( + "Dropping constraint ref_to_unignorable_table_ref_fkey from " + "table ref_to_unignorable_table" + ), "Dropping constraint concept_type_mitigation_type_id_fkey from table concept_type", "Restoring foreign key constraint concept_concept_type_id_fkey", "Restoring foreign key constraint ref_to_unignorable_table_ref_fkey", @@ -334,33 +338,50 @@ def test_workflow_maximal_args(self) -> None: ) self.assertEqual("", completed_process.stderr) self.assertEqual( - { - "Creating data.", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.full_row_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for table 'data_type_test'", - "Generating data for table 'no_pk_test'", - "Generating data for table 'person'", - "Generating data for table 'strange_type_table'", - "Generating data for table 'unique_constraint_test'", - "Generating data for table 'unique_constraint_test2'", - "Generating data for table 'test_entity'", - "Generating data for table 'hospital_visit'", - "Data created in 2 passes.", - f"person: {2*(3+1+2+2)} rows created.", - f"hospital_visit: {2*(2*2+3)} rows created.", - "data_type_test: 2 rows created.", - "no_pk_test: 2 rows created.", - "strange_type_table: 2 rows created.", - "unique_constraint_test: 2 rows created.", - "unique_constraint_test2: 2 rows created.", - "test_entity: 2 rows created.", - }, - set(completed_process.stdout.split("\n")) - {""}, + sorted( + [ + "Creating data.", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.full_row_story'", + "Generating data for story 'story_generators.full_row_story'", + "Generating data for story 
'story_generators.long_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for table 'data_type_test'", + "Generating data for table 'data_type_test'", + "Generating data for table 'no_pk_test'", + "Generating data for table 'no_pk_test'", + "Generating data for table 'person'", + "Generating data for table 'person'", + "Generating data for table 'strange_type_table'", + "Generating data for table 'strange_type_table'", + "Generating data for table 'unique_constraint_test'", + "Generating data for table 'unique_constraint_test'", + "Generating data for table 'unique_constraint_test2'", + "Generating data for table 'unique_constraint_test2'", + "Generating data for table 'test_entity'", + "Generating data for table 'test_entity'", + "Generating data for table 'hospital_visit'", + "Generating data for table 'hospital_visit'", + "Data created in 2 passes.", + f"person: {2*(3+1+2+2)} rows created.", + f"hospital_visit: {2*(2*2+3)} rows created.", + "data_type_test: 2 rows created.", + "no_pk_test: 2 rows created.", + "strange_type_table: 2 rows created.", + "unique_constraint_test: 2 rows created.", + "unique_constraint_test2: 2 rows created.", + "test_entity: 2 rows created.", + "", + ] + ), + sorted(completed_process.stdout.split("\n")), ) completed_process = self.invoke( @@ -406,8 +427,14 @@ def test_workflow_maximal_args(self) -> None: 'Truncating vocabulary table "mitigation_type".', 'Truncating vocabulary table "empty_vocabulary".', "Vocabulary tables truncated.", - "Dropping constraint concept_type_mitigation_type_id_fkey from table concept_type", - "Dropping constraint ref_to_unignorable_table_ref_fkey from table ref_to_unignorable_table", + ( + "Dropping constraint concept_type_mitigation_type_id_fkey " + "from table concept_type" + ), + ( + "Dropping constraint ref_to_unignorable_table_ref_fkey from " + "table ref_to_unignorable_table" + ), "Dropping constraint concept_concept_type_id_fkey from table concept", "Restoring foreign key constraint concept_type_mitigation_type_id_fkey", "Restoring foreign key constraint ref_to_unignorable_table_ref_fkey", @@ -430,7 +457,21 @@ def test_workflow_maximal_args(self) -> None: completed_process.stdout, ) - def invoke(self, *args, expected_error: str=None, env={}): + def invoke( + self, + *args: Any, + expected_error: str | None = None, + env: Mapping[str, str] | None = None, + ) -> Result: + """ + Run datafaker with the given arguments and environment. + + :param args: Arguments to provide to datafaker. + :param expected_error: If None, will assert that the invocation + passes successfully without throwing an exception. Otherwise, + the suggested error must be present in the standard error stream. + :param env: The environment variables to be set during invocation. 
+ """ res = self.runner.invoke(app, args, env=env) if expected_error is None: self.assertNoException(res) @@ -458,11 +499,16 @@ def test_unique_constraint_fail(self) -> None: f"--orm-file={self.alt_orm_file_path}", "--force", ) + self.invoke( + "make-vocab", + f"--orm-file={self.alt_orm_file_path}", + f"--config-file={self.config_file_path}", + "--force", + ) self.invoke( "make-stats", f"--stats-file={self.stats_file_path}", f"--config-file={self.config_file_path}", - f"--orm-file={self.alt_orm_file_path}", "--force", ) self.invoke( @@ -513,12 +559,15 @@ def test_unique_constraint_fail(self) -> None: ) self.assertEqual("", completed_process.stderr) self.assertEqual( - ("Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.full_row_story'\n" - "Generating data for story 'story_generators.long_story'\n" - "Generating data for story 'story_generators.long_story'\n") * 3, + ( + "Generating data for story 'story_generators.short_story'\n" + "Generating data for story 'story_generators.short_story'\n" + "Generating data for story 'story_generators.short_story'\n" + "Generating data for story 'story_generators.full_row_story'\n" + "Generating data for story 'story_generators.long_story'\n" + "Generating data for story 'story_generators.long_story'\n" + ) + * 3, completed_process.stdout, ) @@ -529,7 +578,7 @@ def test_unique_constraint_fail(self) -> None: f"--orm-file={self.alt_orm_file_path}", f"--df-file={self.alt_datafaker_file_path}", "--num-passes=1", - expected_error = ( + expected_error=( "Failed to satisfy unique constraints for table unique_constraint_test" ), ) @@ -538,7 +587,7 @@ def test_unique_constraint_fail(self) -> None: def test_create_schema(self) -> None: """Check that we create a destination schema if it doesn't exist.""" - env = { "dst_schema": "doesntexistyetschema" } + env = {"dst_schema": "doesntexistyetschema"} engine = create_engine(self.env["dst_dsn"]) inspector = inspect(engine) diff --git a/tests/test_interactive.py b/tests/test_interactive.py deleted file mode 100644 index dc5cacc1..00000000 --- a/tests/test_interactive.py +++ /dev/null @@ -1,1657 +0,0 @@ -""" Tests for the base module. 
""" -import copy -from dataclasses import dataclass -import random -import re -from sqlalchemy import insert, select - -from datafaker.interactive import ( - DbCmd, - TableCmd, - GeneratorCmd, - MissingnessCmd, - update_config_generators, -) -from datafaker.generators import ( - NullPartitionedNormalGeneratorFactory, - ChoiceGeneratorFactory, -) - -from tests.utils import RequiresDBTestCase, GeneratesDBTestCase -from unittest.mock import MagicMock, Mock, patch - - -class TestDbCmdMixin(DbCmd): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.reset() - def reset(self): - self.messages: list[tuple[str, list, dict[str, any]]] = [] - self.headings: list[str] = [] - self.rows: list[list[str]] = [] - self.column_items: list[str] = [] - self.columns: dict[str, list[str]] = {} - def print(self, text: str, *args, **kwargs): - self.messages.append((text, args, kwargs)) - def print_table(self, headings: list[str], rows: list[list[str]]): - self.headings = headings - self.rows = rows - def print_table_by_columns(self, columns: dict[str, list[str]]): - self.columns = columns - def columnize(self, items: list[str]): - self.column_items.append(items) - def ask_save(self) -> str: - return "yes" - - -class TestTableCmd(TableCmd, TestDbCmdMixin): - """ TableCmd but mocked """ - - -class ConfigureTablesTests(RequiresDBTestCase): - """Testing configure-tables.""" - def _get_cmd(self, config) -> TestTableCmd: - return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) - - -class ConfigureTablesSrcTests(ConfigureTablesTests): - """Testing configure-tables with src.dump.""" - dump_file_path = "src.dump" - database_name = "src" - schema_name = "public" - - def test_table_name_prompts(self) -> None: - """Test that the prompts follow the names of the tables.""" - config = {} - with self._get_cmd(config) as tc: - table_names = list(self.metadata.tables.keys()) - for t in table_names: - self.assertIn(t, tc.prompt) - tc.do_next("") - self.assertListEqual(tc.messages, [(TableCmd.INFO_NO_MORE_TABLES, (), {})]) - tc.reset() - for t in reversed(table_names): - self.assertIn(t, tc.prompt) - tc.do_previous("") - self.assertListEqual(tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})]) - tc.reset() - bad_table_name = "notarealtable" - tc.do_next(bad_table_name) - self.assertListEqual(tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})]) - tc.reset() - good_table_name = table_names[2] - tc.do_next(good_table_name) - self.assertListEqual(tc.messages, []) - self.assertIn(good_table_name, tc.prompt) - - def test_column_display(self) -> None: - """Test that we can see the names of the columns.""" - config = {} - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_columns("") - self.assertListEqual( - tc.rows, - [ - ["id", "INTEGER", True, False, ""], - ["a", "BOOLEAN", False, False, ""], - ["b", "BOOLEAN", False, False, ""], - ["c", "TEXT", False, False, ""], - ], - ) - - def test_null_configuration(self) -> None: - """A table still works if its configuration is None.""" - config = { - "tables": None, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_private("") - tc.do_quit("") - tables = tc.config["tables"] - self.assertFalse(tables["unique_constraint_test"].get("vocabulary_table", False)) - self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue(tables["unique_constraint_test"].get("primary_private", False)) - - def test_null_table_configuration(self) -> None: - 
"""A table still works if its configuration is None.""" - config = { - "tables": { - "unique_constraint_test": None, - }, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_private("") - tc.do_quit("") - tables = tc.config["tables"] - self.assertFalse(tables["unique_constraint_test"].get("vocabulary_table", False)) - self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue(tables["unique_constraint_test"].get("primary_private", False)) - - def test_configure_tables(self) -> None: - """Test that we can change columns to ignore, vocab or generate.""" - config = { - "tables": { - "unique_constraint_test": { - "vocabulary_table": True, - }, - "no_pk_test": { - "ignore": True, - }, - "hospital_visit": { - "num_passes": 0, - }, - "empty_vocabulary": { - "private": True, - } - }, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_generate("") - tc.do_next("person") - tc.do_vocabulary("") - tc.do_next("mitigation_type") - tc.do_ignore("") - tc.do_next("hospital_visit") - tc.do_private("") - tc.do_quit("") - tc.do_next("empty_vocabulary") - tc.do_empty("") - tc.do_quit("") - tables = tc.config["tables"] - self.assertFalse(tables["unique_constraint_test"].get("vocabulary_table", False)) - self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertFalse(tables["unique_constraint_test"].get("primary_private", False)) - self.assertEqual(tables["unique_constraint_test"].get("num_passes", 1), 1) - self.assertFalse(tables["no_pk_test"].get("vocabulary_table", False)) - self.assertTrue(tables["no_pk_test"].get("ignore", False)) - self.assertFalse(tables["no_pk_test"].get("primary_private", False)) - self.assertEqual(tables["no_pk_test"].get("num_rows_per_pass", 1), 1) - self.assertTrue(tables["person"].get("vocabulary_table", False)) - self.assertFalse(tables["person"].get("ignore", False)) - self.assertFalse(tables["person"].get("primary_private", False)) - self.assertEqual(tables["person"].get("num_rows_per_pass", 1), 1) - self.assertFalse(tables["mitigation_type"].get("vocabulary_table", False)) - self.assertTrue(tables["mitigation_type"].get("ignore", False)) - self.assertFalse(tables["mitigation_type"].get("primary_private", False)) - self.assertEqual(tables["mitigation_type"].get("num_rows_per_pass", 1), 1) - self.assertFalse(tables["hospital_visit"].get("vocabulary_table", False)) - self.assertFalse(tables["hospital_visit"].get("ignore", False)) - self.assertTrue(tables["hospital_visit"].get("primary_private", False)) - self.assertEqual(tables["hospital_visit"].get("num_rows_per_pass", 1), 1) - self.assertFalse(tables["empty_vocabulary"].get("vocabulary_table", False)) - self.assertFalse(tables["empty_vocabulary"].get("ignore", False)) - self.assertFalse(tables["empty_vocabulary"].get("primary_private", False)) - self.assertEqual(tables["empty_vocabulary"].get("num_rows_per_pass", 1), 0) - - def test_print_data(self) -> None: - """Test that we can print random rows from the table and random data from columns.""" - person_table = self.metadata.tables["person"] - with self.engine.connect() as conn: - person_rows = conn.execute(select(person_table)).mappings().fetchall() - person_data = { - row["person_id"]: row - for row in person_rows - } - name_set = {row["name"] for row in person_rows} - person_headings = ["person_id", "name", "research_opt_out", "stored_from"] - with self._get_cmd({}) as tc: - tc.do_next("person") - tc.do_data("") - self.assertListEqual(tc.headings, 
person_headings) - self.assertEqual(len(tc.rows), 10) # default number of rows is 10 - for row in tc.rows: - expected = person_data[row[0]] - self.assertListEqual(row, [expected[h] for h in person_headings]) - tc.reset() - rows_to_get_count = 6 - tc.do_data(str(rows_to_get_count)) - self.assertListEqual(tc.headings, person_headings) - self.assertEqual(len(tc.rows), rows_to_get_count) - for row in tc.rows: - expected = person_data[row[0]] - self.assertListEqual(row, [expected[h] for h in person_headings]) - tc.reset() - to_get_count = 12 - tc.do_data(f"{to_get_count} name") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual(len(tc.column_items[0]), to_get_count) - self.assertLessEqual(set(tc.column_items[0]), name_set) - tc.reset() - tc.do_data(f"{to_get_count} name 12") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual(len(tc.column_items[0]), to_get_count) - tc.reset() - tc.do_data(f"{to_get_count} name 13") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual(set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set))) - tc.reset() - tc.do_data(f"{to_get_count} name 16") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual(set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set))) - - def test_list_tables(self): - """Test that we can list the tables""" - config = { - "tables": { - "unique_constraint_test": { - "vocabulary_table": True, - }, - "no_pk_test": { - "ignore": True, - }, - }, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_ignore("") - tc.do_next("person") - tc.do_vocabulary("") - tc.reset() - tc.do_tables("") - person_listed = False - unique_constraint_test_listed = False - no_pk_test_listed = False - for (text, args, kwargs) in tc.messages: - if args[2] == "person": - self.assertFalse(person_listed) - person_listed = True - self.assertEqual(args[0], "G") - self.assertEqual(args[1], "->V") - elif args[2] == "unique_constraint_test": - self.assertFalse(unique_constraint_test_listed) - unique_constraint_test_listed = True - self.assertEqual(args[0], "V") - self.assertEqual(args[1], "->I") - elif args[2] == "no_pk_test": - self.assertFalse(no_pk_test_listed) - no_pk_test_listed = True - self.assertEqual(args[0], "I") - self.assertEqual(args[1], " ") - else: - self.assertEqual(args[0], "G") - self.assertEqual(args[1], " ") - self.assertTrue(person_listed) - self.assertTrue(unique_constraint_test_listed) - self.assertTrue(no_pk_test_listed) - - -class ConfigureTablesInstrumentsTests(ConfigureTablesTests): - """ Testing configure-tables with the instrument.sql database. 
""" - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def test_sanity_checks_both(self): - config = { - "tables": { - "model": { - "vocabulary_table": True, - }, - "manufacturer": { - "ignore": True, - }, - "player": { - "num_rows_per_pass": 0, - }, - }, - } - with self._get_cmd(config) as tc: - tc.reset() - tc.do_quit("") - self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_NO_CHANGES, (), {})) - self.assertEqual(tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {})) - self.assertEqual(tc.messages[2], (TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, ("model", "manufacturer"), {})) - self.assertEqual(tc.messages[3], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {})) - self.assertEqual(tc.messages[4], (TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, ("signature_model", "player"), {})) - - def test_sanity_checks_warnings_only(self): - config = { - "tables": { - "model": { - "vocabulary_table": True, - }, - "manufacturer": { - "ignore": True, - }, - "player": { - "num_rows_per_pass": 0, - }, - }, - } - with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: - tc.do_next("manufacturer") - tc.do_vocabulary("") - tc.reset() - tc.do_quit("") - self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_CHANGING, ("manufacturer", "ignore", "vocabulary"), {})) - self.assertEqual(tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {})) - self.assertEqual(tc.messages[2], (TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, ("signature_model", "player"), {})) - - def test_sanity_checks_errors_only(self): - config = { - "tables": { - "model": { - "vocabulary_table": True, - }, - "manufacturer": { - "ignore": True, - }, - "player": { - "num_rows_per_pass": 0, - }, - }, - } - with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: - tc.do_next("signature_model") - tc.do_empty("") - tc.reset() - tc.do_quit("") - self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_CHANGING, ("signature_model", "generate", "empty"), {})) - self.assertEqual(tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {})) - self.assertEqual(tc.messages[2], (TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, ("model", "manufacturer"), {})) - - -class TestGeneratorCmd(GeneratorCmd, TestDbCmdMixin): - """ GeneratorCmd but mocked """ - def get_proposals(self) -> dict[str, tuple[int, str, str, list[str]]]: - """ - Returns a dict of generator name to a tuple of (index, fit_string, [list,of,samples])""" - return { - kw["name"]: (kw["index"], kw["fit"], kw["sample"].split("; ")) - for (s, _, kw) in self.messages - if s == self.PROPOSE_GENERATOR_SAMPLE_TEXT - } - - -class ConfigureGeneratorsTests(RequiresDBTestCase): - """ Testing configure-generators. """ - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def _get_cmd(self, config) -> TestGeneratorCmd: - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_null_configuration(self): - """ Test that the tables having null configuration does not break. """ - config = { - "tables": None, - } - with self._get_cmd(config) as gc: - TABLE = "model" - gc.do_next(f"{TABLE}.name") - gc.do_propose("") - gc.do_compare("") - gc.do_set("1") - gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) - - def test_null_table_configuration(self): - """ Test that a table having null configuration does not break. 
""" - config = { - "tables": { - "model": None, - } - } - with self._get_cmd(config) as gc: - TABLE = "model" - gc.do_next(f"{TABLE}.name") - gc.do_propose("") - gc.do_set("1") - gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) - - def test_prompts(self) -> None: - """Test that the prompts follow the names of the columns and assigned generators.""" - config = {} - with self._get_cmd(config) as gc: - for table_name, table_meta in self.metadata.tables.items(): - for column_name, column_meta in table_meta.columns.items(): - self.assertIn(table_name, gc.prompt) - self.assertIn(column_name, gc.prompt) - if column_meta.primary_key: - self.assertIn("[pk]", gc.prompt) - else: - self.assertNotIn("[pk]", gc.prompt) - gc.do_next("") - self.assertListEqual(gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})]) - gc.reset() - for table_name, table_meta in reversed(list(self.metadata.tables.items())): - for column_name, column_meta in reversed(list(table_meta.columns.items())): - self.assertIn(table_name, gc.prompt) - self.assertIn(column_name, gc.prompt) - if column_meta.primary_key: - self.assertIn("[pk]", gc.prompt) - else: - self.assertNotIn("[pk]", gc.prompt) - gc.do_previous("") - self.assertListEqual(gc.messages, [(GeneratorCmd.ERROR_ALREADY_AT_START, (), {})]) - gc.reset() - bad_table_name = "notarealtable" - gc.do_next(bad_table_name) - self.assertListEqual(gc.messages, [( - GeneratorCmd.ERROR_NO_SUCH_TABLE_OR_COLUMN, - (bad_table_name,), - {} - )]) - gc.reset() - - def test_set_generator_mimesis(self): - """ Test that we can set one generator to a mimesis generator. """ - with self._get_cmd({}) as gc: - TABLE = "model" - COLUMN = "name" - GENERATOR = "person.first_name" - gc.do_next(f"{TABLE}.{COLUMN}") - gc.do_propose("") - proposals = gc.get_proposals() - gc.do_set(str(proposals[f"generic.{GENERATOR}"][0])) - gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) - self.assertDictEqual( - gc.config["tables"][TABLE]["row_generators"][0], - {"name": f"generic.{GENERATOR}", "columns_assigned": [COLUMN]}, - ) - - def test_set_generator_distribution(self): - """ Test that we can set one generator to gaussian. """ - with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "frequency" - GENERATOR = "dist_gen.normal" - gc.do_next(f"{TABLE}.{COLUMN}") - gc.do_propose("") - proposals = gc.get_proposals() - gc.do_set(str(proposals[GENERATOR][0])) - gc.do_quit("") - row_gens = gc.config["tables"][TABLE]["row_generators"] - self.assertEqual(len(row_gens), 1) - row_gen = row_gens[0] - self.assertEqual(row_gen["name"], GENERATOR) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) - self.assertDictEqual(row_gen["kwargs"], { - "mean": f'SRC_STATS["auto__{TABLE}"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__{TABLE}"]["results"][0]["stddev__{COLUMN}"]', - }) - self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") - self.assertEqual( - gc.config["src-stats"][0]["query"], - f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", - ) - - def test_set_generator_distribution_directly(self): - """ Test that we can set one generator to gaussian without going through propose. 
""" - with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "frequency" - GENERATOR = "dist_gen.normal" - gc.do_next(f"{TABLE}.{COLUMN}") - gc.reset() - gc.do_set(GENERATOR) - self.assertListEqual(gc.messages, []) - gc.do_quit("") - self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") - self.assertEqual( - gc.config["src-stats"][0]["query"], - f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", - ) - - def test_set_generator_choice(self): - """ Test that we can set one generator to uniform choice. """ - with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "frequency" - GENERATOR = "dist_gen.choice" - gc.do_next(f"{TABLE}.{COLUMN}") - gc.do_propose("") - proposals = gc.get_proposals() - gc.do_set(str(proposals[GENERATOR][0])) - gc.do_quit("") - row_gens = gc.config["tables"][TABLE]["row_generators"] - self.assertEqual(len(row_gens), 1) - row_gen = row_gens[0] - self.assertEqual(row_gen["name"], GENERATOR) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) - self.assertDictEqual(row_gen["kwargs"], { - "a": f'SRC_STATS["auto__{TABLE}__{COLUMN}"]["results"]', - }) - self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}__{COLUMN}") - self.assertEqual( - gc.config["src-stats"][0]["query"], - f"SELECT {COLUMN} AS value FROM {TABLE} WHERE {COLUMN} IS NOT NULL GROUP BY value ORDER BY COUNT({COLUMN}) DESC", - ) - - def test_weighted_choice_generator_generates_choices(self): - """ Test that propose and compare show weighted_choice's values. """ - with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "position" - GENERATOR = "dist_gen.weighted_choice" - VALUES = {1, 2, 3, 4, 5, 6} - gc.do_next(f"{TABLE}.{COLUMN}") - gc.do_propose("") - proposals = gc.get_proposals() - gen_proposal = proposals[GENERATOR] - self.assertSubset(set(gen_proposal[2]), {str(v) for v in VALUES}) - gc.do_compare(str(gen_proposal[0])) - col_heading = f"{gen_proposal[0]}. 
{GENERATOR}" - self.assertIn(col_heading, gc.columns) - self.assertSubset(set(gc.columns[col_heading]), VALUES) - - def test_merge_columns(self): - """ Test that we can merge columns and set a multivariate generator """ - TABLE = "string" - COLUMN_1 = "frequency" - COLUMN_2 = "position" - GENERATOR_TO_DISCARD = "dist_gen.choice" - GENERATOR = "dist_gen.multivariate_normal" - with self._get_cmd({}) as gc: - gc.do_next(f"{TABLE}.{COLUMN_2}") - gc.do_propose("") - proposals = gc.get_proposals() - # set a generator, but this should not exist after merging - gc.do_set(str(proposals[GENERATOR_TO_DISCARD][0])) - gc.do_next(f"{TABLE}.{COLUMN_1}") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertNotIn(COLUMN_2, gc.prompt) - gc.do_propose("") - proposals = gc.get_proposals() - # set a generator, but this should not exist either - gc.do_set(str(proposals[GENERATOR_TO_DISCARD][0])) - gc.do_previous("") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertNotIn(COLUMN_2, gc.prompt) - gc.do_merge(COLUMN_2) - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertIn(COLUMN_2, gc.prompt) - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - gc.do_set(str(proposals[GENERATOR][0])) - gc.do_quit("") - row_gens = gc.config["tables"][TABLE]["row_generators"] - self.assertEqual(len(row_gens), 1) - row_gen = row_gens[0] - self.assertEqual(row_gen["name"], GENERATOR) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN_1, COLUMN_2]) - - def test_unmerge_columns(self): - """ Test that we can unmerge columns and generators are removed """ - TABLE = "string" - COLUMN_1 = "frequency" - COLUMN_2 = "position" - COLUMN_3 = "model_id" - REMAINING_GEN = "gen3" - config = { - "tables": { - TABLE: { - "row_generators": [ - {"name": "gen1", "columns_assigned": [COLUMN_1, COLUMN_2]}, - { "name": REMAINING_GEN, "columns_assigned": [COLUMN_3] }, - ] - } - } - } - with self._get_cmd(config) as gc: - gc.do_next(f"{TABLE}.{COLUMN_2}") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertIn(COLUMN_2, gc.prompt) - gc.do_unmerge(COLUMN_1) - self.assertIn(TABLE, gc.prompt) - self.assertNotIn(COLUMN_1, gc.prompt) - self.assertIn(COLUMN_2, gc.prompt) - # Next generator should be the unmerged one - gc.do_next("") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertNotIn(COLUMN_2, gc.prompt) - gc.do_quit("") - # Both generators should have disappeared - row_gens = gc.config["tables"][TABLE]["row_generators"] - self.assertEqual(len(row_gens), 1) - row_gen = row_gens[0] - self.assertEqual(row_gen["name"], REMAINING_GEN) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN_3]) - - def test_old_generators_remain(self): - """ Test that we can set one generator and keep an old one. 
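The merge tests above establish the semantics: merging makes the current generator span several columns, discarding any generator previously set on either column, while unmerging removes every generator that covered the separated column. The end state asserted by test_merge_columns is a single multi-column entry:

    row_generators = [
        {
            "name": "dist_gen.multivariate_normal",
            "columns_assigned": ["frequency", "position"],
        }
    ]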
""" - config = { - "tables": { - "string": { - "row_generators": [{ - "name": "dist_gen.normal", - "columns_assigned": ["frequency"], - "kwargs": { - "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', - }, - }] - } - }, - "src-stats": [{ - "name": "auto__string", - "query": 'SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string', - }] - } - with self._get_cmd(config) as gc: - TABLE = "model" - COLUMN = "name" - GENERATOR = "person.first_name" - gc.do_next(f"{TABLE}.{COLUMN}") - gc.do_propose("") - proposals = gc.get_proposals() - gc.do_set(str(proposals[f"generic.{GENERATOR}"][0])) - gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) - self.assertDictEqual( - gc.config["tables"][TABLE]["row_generators"][0], - {"name": f"generic.{GENERATOR}", "columns_assigned": [COLUMN]}, - ) - row_gens = gc.config["tables"]["string"]["row_generators"] - self.assertEqual(len(row_gens), 1) - row_gen = row_gens[0] - self.assertEqual(row_gen["name"], "dist_gen.normal") - self.assertListEqual(row_gen["columns_assigned"], ["frequency"]) - self.assertDictEqual(row_gen["kwargs"], { - "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', - }) - self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__string") - self.assertEqual( - gc.config["src-stats"][0]["query"], - "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", - ) - - def test_aggregate_queries_merge(self): - """ - Test that we can set a generator that requires select aggregate clauses - and keep an old one, resulting in a merged query. 
- """ - config = { - "tables": { - "string": { - "row_generators": [{ - "name": "dist_gen.normal", - "columns_assigned": ["frequency"], - "kwargs": { - "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', - }, - }] - } - }, - "src-stats": [{ - "name": "auto__string", - "query": 'SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string', - }] - } - with self._get_cmd(copy.deepcopy(config)) as gc: - COLUMN = "position" - GENERATOR = "dist_gen.uniform_ms" - gc.do_next(f"string.{COLUMN}") - gc.do_propose("") - proposals = gc.get_proposals() - gc.do_set(str(proposals[f"{GENERATOR}"][0])) - gc.do_quit("") - row_gens: list[dict[str,any]] = gc.config["tables"]["string"]["row_generators"] - self.assertEqual(len(row_gens), 2) - if row_gens[0]["name"] == GENERATOR: - row_gen0 = row_gens[0] - row_gen1 = row_gens[1] - else: - row_gen0 = row_gens[1] - row_gen1 = row_gens[0] - self.assertEqual(row_gen0["name"], GENERATOR) - self.assertEqual(row_gen1["name"], "dist_gen.normal") - self.assertListEqual(row_gen0["columns_assigned"], [COLUMN]) - self.assertDictEqual(row_gen0["kwargs"], { - "mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{COLUMN}"]', - }) - self.assertListEqual(row_gen1["columns_assigned"], ["frequency"]) - self.assertDictEqual(row_gen1["kwargs"], { - "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', - }) - self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertEqual(gc.config["src-stats"][0]["name"], "auto__string") - select_match = re.match(r'SELECT (.*) FROM string', gc.config["src-stats"][0]["query"]) - self.assertIsNotNone(select_match, "src_stats[0].query is not an aggregate select") - self.assertSetEqual(set(select_match.group(1).split(", ")), { - "AVG(frequency) AS mean__frequency", - "STDDEV(frequency) AS stddev__frequency", - f"AVG({COLUMN}) AS mean__{COLUMN}", - f"STDDEV({COLUMN}) AS stddev__{COLUMN}", - }) - - def test_next_completion(self): - """ Test tab completion for the next command. """ - with self._get_cmd({}) as gc: - self.assertSetEqual( - set(gc.complete_next("m", "next m", 5, 6)), - {"manufacturer", "model"}, - ) - self.assertSetEqual( - set(gc.complete_next("model", "next model", 5, 10)), - {"model", "model."}, - ) - self.assertSetEqual( - set(gc.complete_next("string.", "next string.", 5, 11)), - {"string.id", "string.model_id", "string.position", "string.frequency"}, - ) - self.assertSetEqual( - set(gc.complete_next("string.p", "next string.p", 5, 12)), - {"string.position"}, - ) - self.assertListEqual(gc.complete_next("string.q", "next string.q", 5, 12), []) - self.assertListEqual(gc.complete_next("ww", "next ww", 5, 7), []) - - def test_compare_reports_privacy(self): - """ - Test that compare reports whether the current table is primary private, - secondary private or not private. 
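The assertions here pass [["model"]] as the message argument for the secondary case, which suggests secondary privacy is derived from foreign keys into primary-private tables. That mechanism is an inference from the test, so the sketch below is only a plausible model of it, with `fk_targets` as an assumed stand-in for schema metadata:

    def privacy_of(table: str, config: dict, fk_targets: dict[str, list[str]]) -> str:
        # fk_targets maps a table to the tables its foreign keys reference
        # (assumed; the real command presumably derives this from the schema).
        tables = config.get("tables") or {}
        if (tables.get(table) or {}).get("primary_private"):
            return "primary"
        if any((tables.get(t) or {}).get("primary_private")
               for t in fk_targets.get(table, [])):
            return "secondary"
        return "not private"

    config = {"tables": {"model": {"primary_private": True}}}
    assert privacy_of("model", config, {}) == "primary"
    assert privacy_of("string", config, {"string": ["model"]}) == "secondary"
    assert privacy_of("manufacturer", config, {}) == "not private"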
- """ - config = { - "tables": { - "model": { - "primary_private": True, - } - }, - } - with self._get_cmd(config) as gc: - gc.do_next("manufacturer") - gc.reset() - gc.do_compare("") - (text, args, _kwargs) = gc.messages[0] - self.assertEqual(text, gc.NOT_PRIVATE_TEXT) - gc.do_next("model") - gc.reset() - gc.do_compare("") - (text, args, _kwargs) = gc.messages[0] - self.assertEqual(text, gc.PRIMARY_PRIVATE_TEXT) - gc.do_next("string") - gc.reset() - gc.do_compare("") - (text, args, _kwargs) = gc.messages[0] - self.assertEqual(text, gc.SECONDARY_PRIVATE_TEXT) - self.assertSequenceEqual(args, [["model"]]) - - def test_existing_configuration_remains(self): - """ - Test setting a generator does not remove other information. - """ - config = { - "tables": { - "string": { - "primary_private": True, - } - }, - "src-stats": [{ - "name": "kraken", - "query": 'SELECT MAX(frequency) AS max_frequency FROM string', - }] - } - with self._get_cmd(config) as gc: - COLUMN = "position" - GENERATOR = "dist_gen.uniform_ms" - gc.do_next(f"string.{COLUMN}") - gc.do_propose("") - proposals = gc.get_proposals() - gc.do_set(str(proposals[f"{GENERATOR}"][0])) - gc.do_quit("") - src_stats = { - stat["name"]: stat["query"] - for stat in gc.config["src-stats"] - } - self.assertEqual(src_stats["kraken"], config["src-stats"][0]["query"]) - self.assertTrue(gc.config["tables"]["string"]["primary_private"]) - - def test_empty_tables_are_not_configured(self): - """ Test that tables marked as empty are not configured. """ - config = { - "tables": { - "string": { - "num_rows_per_pass": 0, - } - }, - } - with self._get_cmd(copy.deepcopy(config)) as gc: - gc.do_tables("") - table_names = { m[1][0] for m in gc.messages } - self.assertIn("model", table_names) - self.assertNotIn("string", table_names) - - -class GeneratorsOutputTests(GeneratesDBTestCase): - """ Testing choice generation. """ - dump_file_path = "choice.sql" - database_name = "numbers" - schema_name = "public" - - def setUp(self) -> None: - super().setUp() - ChoiceGeneratorFactory.SAMPLE_COUNT = 500 - ChoiceGeneratorFactory.SUPPRESS_COUNT = 5 - - def _get_cmd(self, config) -> TestGeneratorCmd: - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_create_with_sampled_choice(self): - """ Test that suppression works for choice and zipf_choice. 
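The [sampled] proposal variants fit against a bounded sample (SAMPLE_COUNT, set to 500 in setUp) and the [sampled and suppressed] variants additionally drop values that occur too rarely in that sample (SUPPRESS_COUNT, 5 here); when suppression would leave nothing usable, the suppressed variant is simply not proposed, which is what the test asserts for number_table.three. A sketch of the filtering rule; the exact comparison against the threshold is an assumption:

    from collections import Counter

    def suppress_rare(sample: list[int], suppress_count: int = 5) -> list[int]:
        # Keep only values seen more than suppress_count times in the sample.
        counts = Counter(sample)
        return [v for v in sample if counts[v] > suppress_count]

    # In test_create_with_sampled_choice, column "one" ends up generating only
    # {1, 4}: every other value was too rare in the sample and is suppressed.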
""" - table_name = "number_table" - with self._get_cmd({}) as gc: - gc.do_next("number_table.one") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.choice", proposals) - self.assertIn("dist_gen.zipf_choice", proposals) - self.assertIn("dist_gen.choice [sampled]", proposals) - self.assertIn("dist_gen.zipf_choice [sampled]", proposals) - self.assertIn("dist_gen.choice [sampled and suppressed]", proposals) - self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) - gc.do_set(str(proposals["dist_gen.choice [sampled and suppressed]"][0])) - gc.do_next("number_table.two") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.choice", proposals) - self.assertIn("dist_gen.zipf_choice", proposals) - self.assertIn("dist_gen.choice [sampled]", proposals) - self.assertIn("dist_gen.zipf_choice [sampled]", proposals) - self.assertIn("dist_gen.choice [sampled and suppressed]", proposals) - self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) - gc.do_set(str(proposals["dist_gen.zipf_choice [sampled and suppressed]"][0])) - gc.do_next("number_table.three") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.choice", proposals) - self.assertIn("dist_gen.zipf_choice", proposals) - self.assertIn("dist_gen.choice [sampled]", proposals) - self.assertIn("dist_gen.zipf_choice [sampled]", proposals) - self.assertNotIn("dist_gen.choice [sampled and suppressed]", proposals) - self.assertNotIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) - gc.do_set(str(proposals["dist_gen.choice [sampled]"][0])) - gc.do_quit("") - self.generate_data(gc.config, num_passes=200) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - ones = set() - twos = set() - threes = set() - for row in rows: - ones.add(row.one) - twos.add(row.two) - threes.add(row.three) - # all generation possibilities should be present - self.assertSetEqual(ones, {1, 4}) - self.assertSetEqual(twos, {2, 3}) - self.assertSetEqual(threes, {1, 2, 3, 4, 5}) - - def test_create_with_choice(self): - """ Smoke test normal choice works. """ - table_name = "number_table" - with self._get_cmd({}) as gc: - gc.do_next("number_table.one") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - gc.do_set(str(proposals["dist_gen.choice"][0])) - gc.do_next("number_table.two") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - gc.do_set(str(proposals["dist_gen.zipf_choice"][0])) - gc.do_quit("") - self.generate_data(gc.config, num_passes=200) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - ones = set() - twos = set() - for row in rows: - ones.add(row.one) - twos.add(row.two) - # all generation possibilities should be present - self.assertSetEqual(ones, {1, 2, 3, 4, 5}) - self.assertSetEqual(twos, {1, 2, 3, 4, 5}) - - def test_create_with_weighted_choice(self): - """ Smoke test weighted choice. 
""" - table_name = "number_table" - with self._get_cmd({}) as gc: - gc.do_next("number_table.one") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.weighted_choice", proposals) - self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn("dist_gen.weighted_choice [sampled and suppressed]", proposals) - prop = proposals["dist_gen.weighted_choice [sampled and suppressed]"] - self.assertSubset(set(prop[2]), {"1", "4"}) - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. dist_gen.weighted_choice [sampled and suppressed]" - self.assertIn(col_heading, set(gc.columns.keys())) - self.assertSubset(set(gc.columns[col_heading]), {1, 4}) - gc.do_set(str(prop[0])) - gc.do_next("number_table.two") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.weighted_choice", proposals) - self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn("dist_gen.weighted_choice [sampled and suppressed]", proposals) - prop = proposals["dist_gen.weighted_choice"] - self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"}) - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. dist_gen.weighted_choice" - self.assertIn(col_heading, set(gc.columns.keys())) - self.assertSubset(set(gc.columns[col_heading]), {1, 2, 3, 4, 5}) - gc.do_set(str(prop[0])) - gc.do_next("number_table.three") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.weighted_choice", proposals) - self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertNotIn("dist_gen.weighted_choice [sampled and suppressed]", proposals) - prop = proposals["dist_gen.weighted_choice [sampled]"] - self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"}) - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. dist_gen.weighted_choice [sampled]" - self.assertIn(col_heading, set(gc.columns.keys())) - self.assertSubset(set(gc.columns[col_heading]), {1, 2, 3, 4, 5}) - gc.do_set(str(prop[0])) - gc.do_quit("") - self.generate_data(gc.config, num_passes=200) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - ones = set() - twos = set() - threes = set() - for row in rows: - ones.add(row.one) - twos.add(row.two) - threes.add(row.three) - # all generation possibilities should be present - self.assertSetEqual(ones, {1, 4}) - self.assertSetEqual(twos, {1, 2, 3, 4, 5}) - self.assertSetEqual(threes, {1, 2, 3, 4, 5}) - - -class TestMissingnessCmd(MissingnessCmd, TestDbCmdMixin): - """ MissingnessCmd but mocked """ - -class ConfigureMissingnessTests(RequiresDBTestCase): - """ Testing configure-missing. """ - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def _get_cmd(self, config) -> TestMissingnessCmd: - return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_set_missingness_to_sampled(self): - """ Test that we can set one table to sampled missingness. 
""" - with self._get_cmd({}) as mc: - TABLE = "signature_model" - mc.do_next(TABLE) - mc.do_counts("") - self.assertListEqual(mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (6,), {})]) - self.assertListEqual(mc.rows, [['player_id', 3], ['based_on', 2]]) - mc.do_sampled("") - mc.do_quit("") - self.assertDictEqual( - mc.config, - { "tables": {TABLE: {"missingness_generators": [{ - "columns": ["player_id", "based_on"], - "kwargs": {"patterns": 'SRC_STATS["missing_auto__signature_model__0"]'}, - "name": "column_presence.sampled", - }]}}, - "src-stats": [{ - "name": "missing_auto__signature_model__0", - "query": ("SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" - " (SELECT player_id IS NULL AS player_id__is_null, based_on IS NULL AS based_on__is_null FROM" - " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null") - }] - } - ) - - -class ConfigureMissingnessTests(GeneratesDBTestCase): - """ Testing configure-missing with generation. """ - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def _get_cmd(self, config) -> TestMissingnessCmd: - return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_create_with_missingness(self): - """ Test that we can sample real missingness and reproduce it. """ - random.seed(45) - # Configure the missingness - table_name = "signature_model" - with self._get_cmd({}) as mc: - mc.do_next(table_name) - mc.do_sampled("") - mc.do_quit("") - config = mc.config - self.generate_data(config, num_passes=100) - # Test that each missingness pattern is present in the database - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).mappings().fetchall() - patterns: set[int] = set() - for row in rows: - p = 0 if row["player_id"] is None else 1 - b = 0 if row["based_on"] is None else 2 - patterns.add(p + b) - # all pattern possibilities should be present - self.assertSetEqual(patterns, {0, 1, 2, 3}) - - -class GeneratorTests(GeneratesDBTestCase): - """ Testing configure-generators with generation. """ - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def _get_cmd(self, config) -> TestGeneratorCmd: - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_set_null(self): - """ Test that we can sample real missingness and reproduce it. 
""" - with self._get_cmd({}) as gc: - gc.do_next("string.position") - gc.do_set("dist_gen.constant") - self.assertListEqual(gc.messages, []) - gc.reset() - gc.do_next("string.frequency") - gc.do_set("dist_gen.constant") - self.assertListEqual(gc.messages, []) - gc.reset() - gc.do_next("signature_model.name") - gc.do_set("dist_gen.constant") - self.assertListEqual(gc.messages, []) - gc.reset() - gc.do_next("signature_model.based_on") - gc.do_set("dist_gen.constant") - # we have got to the end of the columns, but shouldn't have any errors - self.assertListEqual(gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})]) - gc.reset() - gc.do_quit("") - config = gc.config - self.generate_data(config, num_passes=3) - # Test that each missingness pattern is present in the database - with self.engine.connect() as conn: - stmt = select(self.metadata.tables["string"].c["position", "frequency"]) - rows = conn.execute(stmt).fetchall() - count = 0 - for row in rows: - count += 1 - self.assertEqual(row.position, 0) - self.assertEqual(row.frequency, 0.0) - self.assertEqual(count, 3) - stmt = select(self.metadata.tables["signature_model"].c["name", "based_on"]) - rows = conn.execute(stmt).fetchall() - count = 0 - for row in rows: - count += 1 - self.assertEqual(row.name, "") - self.assertIsNone(row.based_on) - self.assertEqual(count, 3) - - def test_dist_gen_sampled_produces_ordered_src_stats(self): - """ Tests that choosing a sampled choice generator produces ordered src stats """ - with self._get_cmd({}) as gc: - gc.do_next("signature_model.player_id") - gc.do_set("dist_gen.zipf_choice [sampled]") - gc.do_next("signature_model.based_on") - gc.do_set("dist_gen.zipf_choice [sampled]") - gc.do_quit("") - config = gc.config - self.set_configuration(config) - src_stats = self.get_src_stats(config) - player_ids = [ - s["value"] - for s in src_stats["auto__signature_model__player_id"]["results"] - ] - self.assertListEqual(player_ids, [2, 3, 1]) - based_ons = [ - s["value"] - for s in src_stats["auto__signature_model__based_on"]["results"] - ] - self.assertListEqual(based_ons, [1, 3, 2]) - - def assertAreTruncatedTo(self, xs, length): - maxlen = 0 - for x in xs: - newlen = len(x.strip("'\"")) - self.assertLessEqual(newlen, length) - maxlen = max(maxlen, newlen) - self.assertEqual(maxlen, length) - - def test_varchar_ns_are_truncated(self): - """ Tests that mimesis generators for VARCHAR(N) truncate to N characters """ - GENERATOR = "generic.text.quote" - TABLE = "signature_model" - COLUMN = "name" - with self._get_cmd({}) as gc: - gc.do_next(f"{TABLE}.{COLUMN}") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - quotes = [k for k in proposals.keys() if k.startswith(GENERATOR)] - self.assertEqual(len(quotes), 1) - prop = proposals[quotes[0]] - self.assertAreTruncatedTo(prop[2], 20) - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. 
{quotes[0]}" - gc.do_set(str(prop[0])) - self.assertIn(col_heading, gc.columns) - self.assertAreTruncatedTo(gc.columns[col_heading], 20) - gc.do_quit("") - config = gc.config - self.generate_data(config, num_passes=15) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[TABLE].c[COLUMN]) - rows = conn.execute(stmt).scalars().fetchall() - self.assertAreTruncatedTo(rows, 20) - - -@dataclass -class Stat: - n: int=0 - x: float=0 - x2: float=0 - - def add(self, x: float) -> None: - self.n += 1 - self.x += x - self.x2 += x * x - - def count(self) -> int: - return self.n - - def x_mean(self) -> float: - return self.x / self.n - - def x_var(self) -> float: - x = self.x - return (self.x2 - x*x/self.n)/(self.n - 1) - - -@dataclass -class Correlation(Stat): - y: float=0 - y2: float=0 - xy: float=0 - - def add(self, x: float, y: float) -> None: - self.n += 1 - self.x += x - self.x2 += x * x - self.y += y - self.y2 += y * y - self.xy += x * y - - def y_mean(self) -> float: - return self.y / self.n - - def y_var(self) -> float: - y = self.y - return (self.y2 - y*y/self.n)/(self.n - 1) - - def covar(self) -> float: - return (self.xy - self.x*self.y/self.n)/(self.n - 1) - - -class NullPartitionedTests(GeneratesDBTestCase): - """ Testing null-partitioned grouped multivariate generation. """ - dump_file_path = "eav.sql" - database_name = "eav" - schema_name = "public" - - def setUp(self) -> None: - super().setUp() - NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 8 - NullPartitionedNormalGeneratorFactory.SUPPRESS_COUNT = 2 - - def _get_cmd(self, config) -> TestGeneratorCmd: - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_create_with_null_partitioned_grouped_multivariate(self): - """ Test EAV for all columns. """ - table_name = "measurement" - generate_count = 800 - with self._get_cmd({}) as gc: - self.merge_columns(gc, table_name, ["type", "first_value", "second_value", "third_value"]) - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) - dist_to_choose = "null-partitioned grouped_multivariate_normal" - self.assertIn(dist_to_choose, proposals) - prop = proposals[dist_to_choose] - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. 
{dist_to_choose}" - self.assertIn(col_heading, set(gc.columns.keys())) - gc.do_set(str(prop[0])) - gc.reset() - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - one_count = 0 - one_yes_count = 0 - two = Correlation() - three = Correlation() - four = Correlation() - fish = Stat() - fowl = Stat() - for row in rows: - if row.type == 1: - # yes or no - self.assertIsNone(row.first_value) - self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {'yes', 'no'}) - one_count += 1 - if row.third_value == 'yes': - one_yes_count += 1 - elif row.type == 2: - # positive correlation around 1.4, 1.8 - self.assertIsNotNone(row.first_value) - self.assertIsNotNone(row.second_value) - self.assertIsNone(row.third_value) - two.add(row.first_value, row.second_value) - elif row.type == 3: - # negative correlation around 11.8, 12.1 - self.assertIsNotNone(row.first_value) - self.assertIsNotNone(row.second_value) - self.assertIsNone(row.third_value) - three.add(row.first_value, row.second_value) - elif row.type == 4: - # positive correlation around 21.4, 23.4 - self.assertIsNotNone(row.first_value) - self.assertIsNotNone(row.second_value) - self.assertIsNone(row.third_value) - four.add(row.first_value, row.second_value) - elif row.type == 5: - self.assertIn(row.third_value, {'fish', 'fowl'}) - self.assertIsNotNone(row.first_value) - self.assertIsNone(row.second_value) - if row.third_value == 'fish': - # mean 8.1 and sd 0.755 - fish.add(row.first_value) - else: - # mean 11.2 and sd 1.114 - fowl.add(row.first_value) - # type 1 - self.assertAlmostEqual(one_count, generate_count * 5 / 20, delta=generate_count * 0.4) - # about 40% are yes - self.assertAlmostEqual(one_yes_count / one_count, 0.4, delta=generate_count * 0.4) - # type 2 - self.assertAlmostEqual(two.count(), generate_count * 3 / 20, delta=generate_count * 0.5) - self.assertAlmostEqual(two.x_mean(), 1.4, delta=0.6) - self.assertAlmostEqual(two.x_var(), 0.14, delta=0.2) - self.assertAlmostEqual(two.y_mean(), 1.8, delta=0.8) - self.assertAlmostEqual(two.y_var(), 0.05, delta=0.1) - self.assertAlmostEqual(two.covar(), 0.07, delta=0.1) - # type 3 - self.assertAlmostEqual(three.count(), generate_count * 3 / 20, delta=generate_count * 0.2) - self.assertAlmostEqual(three.covar(), -1.39, delta=1.4) - # type 4 - self.assertAlmostEqual(four.count(), generate_count * 3 / 20, delta=generate_count * 0.2) - self.assertAlmostEqual(four.covar(), 2.22, delta=2.3) - # type 5/fish - self.assertAlmostEqual(fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2) - self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) - # type 5/fowl - self.assertAlmostEqual(fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2) - self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) - - def populate_measurement_type_vocab(self): - """ Add a vocab table without messing around with files """ - table = self.metadata.tables["measurement_type"] - with self.engine.connect() as conn: - conn.execute(insert(table).values({"id": 1, "name": "agreement"})) - conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) - 
conn.execute(insert(table).values({"id": 3, "name": "velocity"})) - conn.execute(insert(table).values({"id": 4, "name": "position"})) - conn.execute(insert(table).values({"id": 5, "name": "matter"})) - conn.commit() - - def merge_columns(self, gc: TestGeneratorCmd, table: str, columns: list[str]) -> None: - """ Merge columns in a table """ - gc.do_next(f"{table}.{columns[0]}") - for col in columns[1:]: - gc.do_merge(col) - gc.reset() - - def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): - """ Test EAV for all columns with sampled and suppressed generation. """ - table_name = "measurement" - table2_name = "observation" - generate_count = 800 - with self._get_cmd({}) as gc: - self.merge_columns(gc, table_name, ["type", "first_value", "second_value", "third_value"]) - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) - self.assertIn("null-partitioned grouped_multivariate_normal", proposals) - self.assertIn("null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", proposals) - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled and suppressed]" - self.assertIn(dist_to_choose, proposals) - prop = proposals[dist_to_choose] - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. {dist_to_choose}" - self.assertIn(col_heading, set(gc.columns.keys())) - gc.do_set(str(prop[0])) - self.merge_columns(gc, table2_name, ["type", "first_value", "second_value", "third_value"]) - gc.do_propose("") - proposals = gc.get_proposals() - prop = proposals[dist_to_choose] - gc.do_set(str(prop[0])) - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - one_count = 0 - one_yes_count = 0 - fish = Stat() - fowl = Stat() - types: set[int] = set() - for row in rows: - types.add(row.type) - if row.type == 1: - # yes or no - self.assertIsNone(row.first_value) - self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {'yes', 'no'}) - if row.third_value == 'yes': - one_yes_count += 1 - one_count += 1 - elif row.type == 5: - self.assertIn(row.third_value, {'fish', 'fowl'}) - self.assertIsNotNone(row.first_value) - self.assertIsNone(row.second_value) - if row.third_value == 'fish': - # mean 8.1 and sd 0.755 - fish.add(row.first_value) - else: - # mean 11.2 and sd 1.114 - fowl.add(row.first_value) - self.assertSubset(types, {1, 2, 3, 4, 5}) - self.assertEqual(len(types), 4) - self.assertSubset({1, 5}, types) - # type 1 - self.assertAlmostEqual(one_count, generate_count * 5 / 11, delta=generate_count * 0.4) - # about 40% are yes - self.assertAlmostEqual(one_yes_count / one_count, 0.4, delta=generate_count * 0.4) - # type 5/fish - self.assertAlmostEqual(fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2) - self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) - # type 5/fowl - self.assertAlmostEqual(fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2) - self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) - stmt = select(self.metadata.tables[table2_name]) - rows = conn.execute(stmt).fetchall() - firsts = 
Stat()
-            for row in rows:
-                types.add(row.type)
-                self.assertEqual(row.type, 1)
-                self.assertIsNotNone(row.first_value)
-                self.assertIsNone(row.second_value)
-                self.assertIn(row.third_value, {'ham', 'eggs'})
-                firsts.add(row.first_value)
-            self.assertEqual(firsts.count(), 800)
-            self.assertAlmostEqual(firsts.x_mean(), 1.3, delta = generate_count * 0.3)
-
-
-    def test_create_with_null_partitioned_grouped_sampled_only(self):
-        """ Test EAV for all columns with sampled generation but no suppression. """
-        table_name = "measurement"
-        table2_name = "observation"
-        generate_count = 800
-        with self._get_cmd({}) as gc:
-            self.merge_columns(gc, table_name, ["type", "first_value", "second_value", "third_value"])
-            gc.do_propose("")
-            proposals = gc.get_proposals()
-            self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals)
-            self.assertIn("null-partitioned grouped_multivariate_normal", proposals)
-            self.assertIn("null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", proposals)
-            self.assertIn("null-partitioned grouped_multivariate_normal [sampled and suppressed]", proposals)
-            self.assertIn("null-partitioned grouped_multivariate_lognormal [sampled]", proposals)
-            dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]"
-            self.assertIn(dist_to_choose, proposals)
-            prop = proposals[dist_to_choose]
-            gc.reset()
-            gc.do_compare(str(prop[0]))
-            col_heading = f"{prop[0]}. {dist_to_choose}"
-            self.assertIn(col_heading, set(gc.columns.keys()))
-            gc.do_set(str(prop[0]))
-            self.merge_columns(gc, table2_name, ["type", "first_value", "second_value", "third_value"])
-            gc.do_propose("")
-            proposals = gc.get_proposals()
-            prop = proposals[dist_to_choose]
-            gc.do_set(str(prop[0]))
-            gc.do_quit("")
-            self.set_configuration(gc.config)
-            self.get_src_stats(gc.config)
-            self.create_generators(gc.config)
-            self.remove_data(gc.config)
-            self.populate_measurement_type_vocab()
-            self.create_data(gc.config, num_passes=generate_count)
-            with self.engine.connect() as conn:
-                stmt = select(self.metadata.tables[table_name])
-                rows = conn.execute(stmt).fetchall()
-                self.assertSubset({row.type for row in rows}, {1, 2, 3, 4, 5})
-                stmt = select(self.metadata.tables[table2_name])
-                rows = conn.execute(stmt).fetchall()
-                self.assertEqual({row.third_value for row in rows}, {"ham", "eggs", "cheese"})
-
-
-    def test_create_with_null_partitioned_grouped_sampled_tiny(self):
-        """
-        Test EAV for all columns with sampled generation that only gets a tiny sample.
-        """
-        # five will ensure that at least one group will have two elements in it,
-        # but all three cannot. 
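The Stat and Correlation helpers used throughout these assertions accumulate only n, the sum of x and the sum of x squared (plus the y and xy sums for Correlation), so mean, unbiased variance and covariance come out of the shortcut formulas var = (sum(x^2) - sum(x)^2/n)/(n-1) and cov = (sum(xy) - sum(x)*sum(y)/n)/(n-1). A quick self-contained check of that shortcut against the statistics module:

    import statistics

    xs = [1.0, 2.0, 4.0]
    n, sx, sx2 = len(xs), sum(xs), sum(x * x for x in xs)
    assert abs(sx / n - statistics.mean(xs)) < 1e-12
    assert abs((sx2 - sx * sx / n) / (n - 1) - statistics.variance(xs)) < 1e-12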
- NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 5 - table_name = "observation" - generate_count = 100 - with self._get_cmd({}) as gc: - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" - self.merge_columns(gc, table_name, ["type", "first_value", "second_value", "third_value"]) - gc.do_propose("") - proposals = gc.get_proposals() - #breakpoint() - prop = proposals[dist_to_choose] - gc.do_set(str(prop[0])) - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - # we should only have one or two of "ham", "eggs" and "cheese" represented - foods = {row.third_value for row in rows} - self.assertSubset(foods, {"ham", "eggs", "cheese"}) - self.assertLess(len(foods), 3) - - -class NonInteractiveTests(RequiresDBTestCase): - """ - Test the --spec SPEC_FILE option of configure-generators - """ - dump_file_path = "eav.sql" - database_name = "eav" - schema_name = "public" - - @patch("datafaker.interactive.Path") - @patch("datafaker.interactive.csv.reader", return_value=iter([ - ["observation", "type", "dist_gen.weighted_choice [sampled]"], - ["observation", "first_value", "dist_gen.weighted_choice", "dist_gen.constant"], - ["observation", "second_value", "dist_gen.weighted_choice", "dist_gen.weighted_choice [sampled]", "dist_gen.constant"], - ["observation", "third_value", "dist_gen.weighted_choice"], - ])) - def test_non_interactive_configure_generators(self, mock_csv_reader: MagicMock, mock_path: MagicMock): - """ - test that we can set generators from a CSV file - """ - config = {} - spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators(self.dsn, self.schema_name, self.metadata, config, spec_csv) - row_gens = { - f"{table}{sorted(rg['columns_assigned'])}": rg["name"] - for table, tables in config.get("tables", {}).items() - for rg in tables.get("row_generators", []) - } - self.assertEqual(row_gens["observation['type']"], "dist_gen.weighted_choice") - self.assertEqual(row_gens["observation['first_value']"], "dist_gen.weighted_choice") - self.assertEqual(row_gens["observation['second_value']"], "dist_gen.constant") - self.assertEqual(row_gens["observation['third_value']"], "dist_gen.weighted_choice") - - @patch("datafaker.interactive.Path") - @patch("datafaker.interactive.csv.reader", return_value=iter([ - [ - "observation", - "type first_value second_value third_value", - "null-partitioned grouped_multivariate_lognormal", - ], - ])) - def test_non_interactive_configure_null_partitioned(self, mock_csv_reader: MagicMock, mock_path: MagicMock): - """ - test that we can set multi-column generators from a CSV file - """ - config = {} - spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators(self.dsn, self.schema_name, self.metadata, config, spec_csv) - row_gens = { - f"{table}{sorted(rg['columns_assigned'])}": rg - for table, tables in config.get("tables", {}).items() - for rg in tables.get("row_generators", []) - } - self.assertEqual( - row_gens["observation['first_value', 'second_value', 'third_value', 'type']"]["name"], - "dist_gen.alternatives", - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["name"], - 
'"with_constants_at"', - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], - '"grouped_multivariate_lognormal"', - ) - - @patch("datafaker.interactive.Path") - @patch( - "datafaker.interactive.csv.reader", - return_value=iter( - [ - [ - "observation", - "type first_value second_value third_value", - "null-partitioned grouped_multivariate_lognormal", - ], - ] - ), - ) - def test_non_interactive_configure_null_partitioned_where_existing_merges( - self, _mock_csv_reader: MagicMock, _mock_path: MagicMock - ) -> None: - """ - test that we can set multi-column generators from a CSV file, - but where there are already multi-column generators configured - that will have to be unmerged. - """ - config = { - "tables": { - "observation": { - "row_generators": [{ - "name": "arbitrary_gen", - "columns_assigned": [ - "type", - "second_value", - "first_value", - ], - }], - }, - }, - } - spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv - ) - row_gens = { - f"{table}{sorted(rg['columns_assigned'])}": rg - for table, tables in config.get("tables", {}).items() - for rg in tables.get("row_generators", []) - } - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["name"], - "dist_gen.alternatives", - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["name"], - '"with_constants_at"', - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], - '"grouped_multivariate_lognormal"', - ) diff --git a/tests/test_interactive_generators.py b/tests/test_interactive_generators.py new file mode 100644 index 00000000..104f0d3a --- /dev/null +++ b/tests/test_interactive_generators.py @@ -0,0 +1,863 @@ +""" Tests for the configure-generators command. 
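The tests drive the interactive loop programmatically rather than through a terminal: each do_* handler is called with the argument string a user would type, and the TestDbCmdMixin records output as (message, args, kwargs) tuples instead of printing. The recurring shape, schematically (dsn and metadata stand in for the real test fixtures, so this is an outline rather than a runnable snippet):

    with TestGeneratorCmd(dsn, "public", metadata, {}) as gc:
        gc.do_next("model.name")  # focus one column
        gc.do_propose("")         # list candidate generators
        gc.do_set("1")            # choose one by its proposal index
        gc.do_quit("")            # fold the choices back into gc.config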
""" +import copy +import re +from collections.abc import MutableMapping +from typing import Any, Iterable + +from sqlalchemy import Connection, MetaData, select + +from datafaker.generators.choice import ChoiceGeneratorFactory +from datafaker.interactive.generators import GeneratorCmd +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase, TestDbCmdMixin + + +class TestGeneratorCmd(GeneratorCmd, TestDbCmdMixin): + """GeneratorCmd but mocked""" + + def get_proposals(self) -> dict[str, tuple[int, str, list[str]]]: + """ + Returns a dict of generator name to a tuple of (index, fit_string, [list,of,samples]) + """ + return { + kw["name"]: (kw["index"], kw["fit"], kw["sample"].split("; ")) + for (s, _, kw) in self.messages + if s == self.PROPOSE_GENERATOR_SAMPLE_TEXT + } + + +class ConfigureGeneratorsTests(RequiresDBTestCase): + """Testing configure-generators.""" + + dump_file_path = "instrument.sql" + database_name = "instrument" + schema_name = "public" + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """Get the command we are using for this test case.""" + return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + + def test_null_configuration(self) -> None: + """Test that the tables having null configuration does not break.""" + config = { + "tables": None, + } + with self._get_cmd(config) as gc: + table = "model" + gc.do_next(f"{table}.name") + gc.do_propose("") + gc.do_compare("") + gc.do_set("1") + gc.do_quit("") + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) + + def test_null_table_configuration(self) -> None: + """Test that a table having null configuration does not break.""" + config = { + "tables": { + "model": None, + } + } + with self._get_cmd(config) as gc: + table = "model" + gc.do_next(f"{table}.name") + gc.do_propose("") + gc.do_set("1") + gc.do_quit("") + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) + + def test_prompts(self) -> None: + """Test that the prompts follow the names of the columns and assigned generators.""" + config: MutableMapping[str, Any] = {} + with self._get_cmd(config) as gc: + for table_name, table_meta in self.metadata.tables.items(): + for column_name, column_meta in table_meta.columns.items(): + self.assertIn(table_name, gc.prompt) + self.assertIn(column_name, gc.prompt) + if column_meta.primary_key: + self.assertIn("[pk]", gc.prompt) + else: + self.assertNotIn("[pk]", gc.prompt) + gc.do_next("") + self.assertListEqual( + gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})] + ) + gc.reset() + for table_name, table_meta in reversed(list(self.metadata.tables.items())): + for column_name, column_meta in reversed( + list(table_meta.columns.items()) + ): + self.assertIn(table_name, gc.prompt) + self.assertIn(column_name, gc.prompt) + if column_meta.primary_key: + self.assertIn("[pk]", gc.prompt) + else: + self.assertNotIn("[pk]", gc.prompt) + gc.do_previous("") + self.assertListEqual( + gc.messages, [(GeneratorCmd.ERROR_ALREADY_AT_START, (), {})] + ) + gc.reset() + bad_table_name = "notarealtable" + gc.do_next(bad_table_name) + self.assertListEqual( + gc.messages, + [(GeneratorCmd.ERROR_NO_SUCH_TABLE_OR_COLUMN, (bad_table_name,), {})], + ) + gc.reset() + + def test_set_generator_mimesis(self) -> None: + """Test that we can set one generator to a mimesis generator.""" + with self._get_cmd({}) as gc: + table = "model" + column = "name" + generator = "person.first_name" + gc.do_next(f"{table}.{column}") + gc.do_propose("") + proposals = 
gc.get_proposals() + gc.do_set(str(proposals[f"generic.{generator}"][0])) + gc.do_quit("") + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) + self.assertDictEqual( + gc.config["tables"][table]["row_generators"][0], + {"name": f"generic.{generator}", "columns_assigned": [column]}, + ) + + def test_set_generator_distribution(self) -> None: + """Test that we can set one generator to gaussian.""" + with self._get_cmd({}) as gc: + table = "string" + column = "frequency" + generator = "dist_gen.normal" + gc.do_next(f"{table}.{column}") + gc.do_propose("") + proposals = gc.get_proposals() + gc.do_set(str(proposals[generator][0])) + gc.do_quit("") + row_gens = gc.config["tables"][table]["row_generators"] + self.assertEqual(len(row_gens), 1) + row_gen = row_gens[0] + self.assertEqual(row_gen["name"], generator) + self.assertListEqual(row_gen["columns_assigned"], [column]) + self.assertDictEqual( + row_gen["kwargs"], + { + "mean": f'SRC_STATS["auto__{table}"]["results"][0]["mean__{column}"]', + "sd": f'SRC_STATS["auto__{table}"]["results"][0]["stddev__{column}"]', + }, + ) + self.assertEqual(len(gc.config["src-stats"]), 1) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) + self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") + self.assertEqual( + gc.config["src-stats"][0]["query"], + ( + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column})" + f" AS stddev__{column} FROM {table}" + ), + ) + + def test_set_generator_distribution_directly(self) -> None: + """Test that we can set one generator to gaussian without going through propose.""" + with self._get_cmd({}) as gc: + table = "string" + column = "frequency" + generator = "dist_gen.normal" + gc.do_next(f"{table}.{column}") + gc.reset() + gc.do_set(generator) + self.assertListEqual(gc.messages, []) + gc.do_quit("") + self.assertEqual(len(gc.config["src-stats"]), 1) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) + self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") + self.assertEqual( + gc.config["src-stats"][0]["query"], + ( + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column})" + f" AS stddev__{column} FROM {table}" + ), + ) + + def test_set_generator_choice(self) -> None: + """Test that we can set one generator to uniform choice.""" + with self._get_cmd({}) as gc: + table = "string" + column = "frequency" + generator = "dist_gen.choice" + gc.do_next(f"{table}.{column}") + gc.do_propose("") + proposals = gc.get_proposals() + gc.do_set(str(proposals[generator][0])) + gc.do_quit("") + row_gens = gc.config["tables"][table]["row_generators"] + self.assertEqual(len(row_gens), 1) + row_gen = row_gens[0] + self.assertEqual(row_gen["name"], generator) + self.assertListEqual(row_gen["columns_assigned"], [column]) + self.assertDictEqual( + row_gen["kwargs"], + { + "a": f'SRC_STATS["auto__{table}__{column}"]["results"]', + }, + ) + self.assertEqual(len(gc.config["src-stats"]), 1) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) + self.assertEqual( + gc.config["src-stats"][0]["name"], f"auto__{table}__{column}" + ) + self.assertEqual( + gc.config["src-stats"][0]["query"], + ( + f"SELECT {column} AS value FROM {table}" + f" WHERE {column} IS NOT NULL" + f" GROUP BY value ORDER BY COUNT({column}) DESC" + ), + ) + + def test_weighted_choice_generator_generates_choices(self) -> None: + """Test that propose and compare show weighted_choice's 
values.""" + with self._get_cmd({}) as gc: + table = "string" + column = "position" + generator = "dist_gen.weighted_choice" + values = {1, 2, 3, 4, 5, 6} + gc.do_next(f"{table}.{column}") + gc.do_propose("") + proposals = gc.get_proposals() + gen_proposal = proposals[generator] + self.assert_subset(set(gen_proposal[2]), {str(v) for v in values}) + gc.do_compare(str(gen_proposal[0])) + col_heading = f"{gen_proposal[0]}. {generator}" + self.assertIn(col_heading, gc.columns) + self.assert_subset(set(gc.columns[col_heading]), values) + + def test_merge_columns(self) -> None: + """Test that we can merge columns and set a multivariate generator""" + table = "string" + column_1 = "frequency" + column_2 = "position" + generator_to_discard = "dist_gen.choice" + generator = "dist_gen.multivariate_normal" + with self._get_cmd({}) as gc: + gc.do_next(f"{table}.{column_2}") + gc.do_propose("") + proposals = gc.get_proposals() + # set a generator, but this should not exist after merging + gc.do_set(str(proposals[generator_to_discard][0])) + gc.do_next(f"{table}.{column_1}") + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertNotIn(column_2, gc.prompt) + gc.do_propose("") + proposals = gc.get_proposals() + # set a generator, but this should not exist either + gc.do_set(str(proposals[generator_to_discard][0])) + gc.do_previous("") + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertNotIn(column_2, gc.prompt) + gc.do_merge(column_2) + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertIn(column_2, gc.prompt) + gc.reset() + gc.do_propose("") + proposals = gc.get_proposals() + gc.do_set(str(proposals[generator][0])) + gc.do_quit("") + row_gens = gc.config["tables"][table]["row_generators"] + self.assertEqual(len(row_gens), 1) + row_gen = row_gens[0] + self.assertEqual(row_gen["name"], generator) + self.assertListEqual(row_gen["columns_assigned"], [column_1, column_2]) + + def test_unmerge_columns(self) -> None: + """Test that we can unmerge columns and generators are removed""" + table = "string" + column_1 = "frequency" + column_2 = "position" + column_3 = "model_id" + remaining_gen = "gen3" + config = { + "tables": { + table: { + "row_generators": [ + {"name": "gen1", "columns_assigned": [column_1, column_2]}, + {"name": remaining_gen, "columns_assigned": [column_3]}, + ] + } + } + } + with self._get_cmd(config) as gc: + gc.do_next(f"{table}.{column_2}") + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertIn(column_2, gc.prompt) + gc.do_unmerge(column_1) + self.assertIn(table, gc.prompt) + self.assertNotIn(column_1, gc.prompt) + self.assertIn(column_2, gc.prompt) + # Next generator should be the unmerged one + gc.do_next("") + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertNotIn(column_2, gc.prompt) + gc.do_quit("") + # Both generators should have disappeared + row_gens = gc.config["tables"][table]["row_generators"] + self.assertEqual(len(row_gens), 1) + row_gen = row_gens[0] + self.assertEqual(row_gen["name"], remaining_gen) + self.assertListEqual(row_gen["columns_assigned"], [column_3]) + + def test_old_generators_remain(self) -> None: + """Test that we can set one generator and keep an old one.""" + config = { + "tables": { + "string": { + "row_generators": [ + { + "name": "dist_gen.normal", + "columns_assigned": ["frequency"], + "kwargs": { + "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', + "sd": 
'SRC_STATS["auto__string"][0]["stddev__frequency"]', + }, + } + ] + } + }, + "src-stats": [ + { + "name": "auto__string", + "query": ( + "SELECT AVG(frequency) AS mean__frequency," + " STDDEV(frequency) AS stddev__frequency FROM string" + ), + } + ], + } + with self._get_cmd(config) as gc: + table = "model" + column = "name" + generator = "person.first_name" + gc.do_next(f"{table}.{column}") + gc.do_propose("") + proposals = gc.get_proposals() + gc.do_set(str(proposals[f"generic.{generator}"][0])) + gc.do_quit("") + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) + self.assertDictEqual( + gc.config["tables"][table]["row_generators"][0], + {"name": f"generic.{generator}", "columns_assigned": [column]}, + ) + row_gens = gc.config["tables"]["string"]["row_generators"] + self.assertEqual(len(row_gens), 1) + row_gen = row_gens[0] + self.assertEqual(row_gen["name"], "dist_gen.normal") + self.assertListEqual(row_gen["columns_assigned"], ["frequency"]) + self.assertDictEqual( + row_gen["kwargs"], + { + "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', + }, + ) + self.assertEqual(len(gc.config["src-stats"]), 1) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) + self.assertEqual(gc.config["src-stats"][0]["name"], "auto__string") + self.assertEqual( + gc.config["src-stats"][0]["query"], + ( + "SELECT AVG(frequency) AS mean__frequency," + " STDDEV(frequency) AS stddev__frequency FROM string" + ), + ) + + def test_aggregate_queries_merge(self) -> None: + """ + Test that we can set a generator that requires select aggregate clauses + and keep an old one, resulting in a merged query. + """ + rg = { + "name": "dist_gen.normal", + "columns_assigned": ["frequency"], + "kwargs": { + "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', + }, + } + config = { + "tables": {"string": {"row_generators": [rg]}}, + "src-stats": [ + { + "name": "auto__string", + "query": ( + "SELECT AVG(frequency) AS mean__frequency," + " STDDEV(frequency) AS stddev__frequency FROM string" + ), + } + ], + } + with self._get_cmd(copy.deepcopy(config)) as gc: + column = "position" + generator = "dist_gen.uniform_ms" + gc.do_next(f"string.{column}") + gc.do_propose("") + proposals = gc.get_proposals() + gc.do_set(str(proposals[f"{generator}"][0])) + gc.do_quit("") + row_gens: list[dict[str, Any]] = gc.config["tables"]["string"][ + "row_generators" + ] + self.assertEqual(len(row_gens), 2) + if row_gens[0]["name"] == generator: + row_gen0 = row_gens[0] + row_gen1 = row_gens[1] + else: + row_gen0 = row_gens[1] + row_gen1 = row_gens[0] + self.assertEqual(row_gen0["name"], generator) + self.assertEqual(row_gen1["name"], "dist_gen.normal") + self.assertListEqual(row_gen0["columns_assigned"], [column]) + self.assertDictEqual( + row_gen0["kwargs"], + { + "mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{column}"]', + "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{column}"]', + }, + ) + self.assertListEqual(row_gen1["columns_assigned"], ["frequency"]) + self.assertDictEqual( + row_gen1["kwargs"], + { + "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', + }, + ) + self.assertEqual(len(gc.config["src-stats"]), 1) + self.assertEqual(gc.config["src-stats"][0]["name"], "auto__string") + select_match = re.match( + 
r"SELECT (.*) FROM string", gc.config["src-stats"][0]["query"] + ) + assert ( + select_match is not None + ), "src_stats[0].query is not an aggregate select" + self.assertSetEqual( + set(select_match.group(1).split(", ")), + { + "AVG(frequency) AS mean__frequency", + "STDDEV(frequency) AS stddev__frequency", + f"AVG({column}) AS mean__{column}", + f"STDDEV({column}) AS stddev__{column}", + }, + ) + + def test_next_completion(self) -> None: + """Test tab completion for the next command.""" + with self._get_cmd({}) as gc: + self.assertSetEqual( + set(gc.complete_next("m", "next m", 5, 6)), + {"manufacturer", "model"}, + ) + self.assertSetEqual( + set(gc.complete_next("model", "next model", 5, 10)), + {"model", "model."}, + ) + self.assertSetEqual( + set(gc.complete_next("string.", "next string.", 5, 11)), + {"string.id", "string.model_id", "string.position", "string.frequency"}, + ) + self.assertSetEqual( + set(gc.complete_next("string.p", "next string.p", 5, 12)), + {"string.position"}, + ) + self.assertListEqual( + gc.complete_next("string.q", "next string.q", 5, 12), [] + ) + self.assertListEqual(gc.complete_next("ww", "next ww", 5, 7), []) + + def test_compare_reports_privacy(self) -> None: + """ + Test that compare reports whether the current table is primary private, + secondary private or not private. + """ + config = { + "tables": { + "model": { + "primary_private": True, + } + }, + } + with self._get_cmd(config) as gc: + gc.do_next("manufacturer") + gc.reset() + gc.do_compare("") + (text, args, _kwargs) = gc.messages[0] + self.assertEqual(text, gc.NOT_PRIVATE_TEXT) + gc.do_next("model") + gc.reset() + gc.do_compare("") + (text, args, _kwargs) = gc.messages[0] + self.assertEqual(text, gc.PRIMARY_PRIVATE_TEXT) + gc.do_next("string") + gc.reset() + gc.do_compare("") + (text, args, _kwargs) = gc.messages[0] + self.assertEqual(text, gc.SECONDARY_PRIVATE_TEXT) + self.assertSequenceEqual(args, [["model"]]) + + def test_existing_configuration_remains(self) -> None: + """ + Test setting a generator does not remove other information. 
+ """ + config: MutableMapping[str, Any] = { + "tables": { + "string": { + "primary_private": True, + } + }, + "src-stats": [ + { + "name": "kraken", + "query": "SELECT MAX(frequency) AS max_frequency FROM string", + } + ], + } + with self._get_cmd(config) as gc: + column = "position" + generator = "dist_gen.uniform_ms" + gc.do_next(f"string.{column}") + gc.do_propose("") + proposals = gc.get_proposals() + gc.do_set(str(proposals[f"{generator}"][0])) + gc.do_quit("") + src_stats = {stat["name"]: stat["query"] for stat in gc.config["src-stats"]} + self.assertEqual(src_stats["kraken"], config["src-stats"][0]["query"]) + self.assertTrue(gc.config["tables"]["string"]["primary_private"]) + + def test_empty_tables_are_not_configured(self) -> None: + """Test that tables marked as empty are not configured.""" + config = { + "tables": { + "string": { + "num_rows_per_pass": 0, + } + }, + } + with self._get_cmd(copy.deepcopy(config)) as gc: + gc.do_tables("") + table_names = {m[1][0] for m in gc.messages} + self.assertIn("model", table_names) + self.assertNotIn("string", table_names) + + +class ChoiceMeasurementTableStats: + """Measure the data in the ``choice.sql`` schema.""" + + def __init__(self, metadata: MetaData, connection: Connection): + """Get the data and do the analysis.""" + stmt = select(metadata.tables["number_table"]) + rows = connection.execute(stmt).fetchall() + self.ones: set[int] = set() + self.twos: set[int] = set() + self.threes: set[int] = set() + for row in rows: + self.ones.add(row.one) + self.twos.add(row.two) + self.threes.add(row.three) + + +class GeneratorsOutputTests(GeneratesDBTestCase): + """Testing choice generation.""" + + dump_file_path = "choice.sql" + database_name = "numbers" + schema_name = "public" + + def setUp(self) -> None: + super().setUp() + ChoiceGeneratorFactory.SAMPLE_COUNT = 500 + ChoiceGeneratorFactory.SUPPRESS_COUNT = 5 + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + + def _propose(self, gc: TestGeneratorCmd) -> dict[str, tuple[int, str, list[str]]]: + gc.reset() + gc.do_propose("") + return gc.get_proposals() + + def test_create_with_sampled_choice(self) -> None: + """Test that suppression works for choice and zipf_choice.""" + with self._get_cmd({}) as gc: + gc.do_next("number_table.one") + proposals = self._propose(gc) + self.assertIn("dist_gen.choice", proposals) + self.assertIn("dist_gen.zipf_choice", proposals) + self.assertIn("dist_gen.choice [sampled]", proposals) + self.assertIn("dist_gen.zipf_choice [sampled]", proposals) + self.assertIn("dist_gen.choice [sampled and suppressed]", proposals) + self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) + gc.do_set(str(proposals["dist_gen.choice [sampled and suppressed]"][0])) + gc.do_next("number_table.two") + proposals = self._propose(gc) + self.assertIn("dist_gen.choice", proposals) + self.assertIn("dist_gen.zipf_choice", proposals) + self.assertIn("dist_gen.choice [sampled]", proposals) + self.assertIn("dist_gen.zipf_choice [sampled]", proposals) + self.assertIn("dist_gen.choice [sampled and suppressed]", proposals) + self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) + gc.do_set( + str(proposals["dist_gen.zipf_choice [sampled and suppressed]"][0]) + ) + gc.do_next("number_table.three") + proposals = self._propose(gc) + self.assertIn("dist_gen.choice", proposals) + self.assertIn("dist_gen.zipf_choice", proposals) + self.assertIn("dist_gen.choice 
[sampled]", proposals) + self.assertIn("dist_gen.zipf_choice [sampled]", proposals) + self.assertNotIn("dist_gen.choice [sampled and suppressed]", proposals) + self.assertNotIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) + gc.do_set(str(proposals["dist_gen.choice [sampled]"][0])) + gc.do_quit("") + self.generate_data(gc.config, num_passes=200) + # all generation possibilities should be present + with self.sync_engine.connect() as conn: + stats = ChoiceMeasurementTableStats(self.metadata, conn) + self.assertSetEqual(stats.ones, {1, 4}) + self.assertSetEqual(stats.twos, {2, 3}) + self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5}) + + def test_create_with_choice(self) -> None: + """Smoke test normal choice works.""" + table_name = "number_table" + with self._get_cmd({}) as gc: + gc.do_next("number_table.one") + proposals = self._propose(gc) + gc.do_set(str(proposals["dist_gen.choice"][0])) + gc.do_next("number_table.two") + proposals = self._propose(gc) + gc.do_set(str(proposals["dist_gen.zipf_choice"][0])) + gc.do_quit("") + self.generate_data(gc.config, num_passes=200) + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).fetchall() + ones = set() + twos = set() + for row in rows: + ones.add(row.one) + twos.add(row.two) + # all generation possibilities should be present + self.assertSetEqual(ones, {1, 2, 3, 4, 5}) + self.assertSetEqual(twos, {1, 2, 3, 4, 5}) + + def test_create_with_weighted_choice(self) -> None: + """Smoke test weighted choice.""" + with self._get_cmd({}) as gc: + gc.do_next("number_table.one") + proposals = self._propose(gc) + self.assert_subset( + { + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + "dist_gen.weighted_choice [suppressed]", + "dist_gen.weighted_choice [sampled and suppressed]", + }, + set(proposals), + ) + prop = proposals["dist_gen.weighted_choice [sampled and suppressed]"] + self.assert_subset(set(prop[2]), {"1", "4"}) + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = ( + f"{prop[0]}. dist_gen.weighted_choice [sampled and suppressed]" + ) + self.assertIn(col_heading, set(gc.columns.keys())) + col_set: set[int] = set(gc.columns[col_heading]) + self.assert_subset(col_set, {1, 4}) + gc.do_set(str(prop[0])) + gc.do_next("number_table.two") + proposals = self._propose(gc) + self.assert_subset( + { + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + "dist_gen.weighted_choice [suppressed]", + "dist_gen.weighted_choice [sampled and suppressed]", + }, + set(proposals), + ) + prop = proposals["dist_gen.weighted_choice"] + self.assert_subset(set(prop[2]), {"1", "2", "3", "4", "5"}) + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. dist_gen.weighted_choice" + self.assertIn(col_heading, set(gc.columns.keys())) + col_set2: set[int] = set(gc.columns[col_heading]) + self.assert_subset(col_set2, {1, 2, 3, 4, 5}) + gc.do_set(str(prop[0])) + gc.do_next("number_table.three") + proposals = self._propose(gc) + self.assert_subset( + { + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + }, + set(proposals), + ) + self.assertNotIn( + "dist_gen.weighted_choice [sampled and suppressed]", proposals + ) + prop = proposals["dist_gen.weighted_choice [sampled]"] + self.assert_subset(set(prop[2]), {"1", "2", "3", "4", "5"}) + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. 
dist_gen.weighted_choice [sampled]"
+            self.assertIn(col_heading, set(gc.columns.keys()))
+            col_set3: set[int] = set(gc.columns[col_heading])
+            self.assert_subset(col_set3, {1, 2, 3, 4, 5})
+            gc.do_set(str(prop[0]))
+            gc.do_quit("")
+        self.generate_data(gc.config, num_passes=200)
+        with self.sync_engine.connect() as conn:
+            stats = ChoiceMeasurementTableStats(self.metadata, conn)
+            # all generation possibilities should be present
+            self.assertSetEqual(stats.ones, {1, 4})
+            self.assertSetEqual(stats.twos, {1, 2, 3, 4, 5})
+            self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5})
+
+
+class GeneratorTests(GeneratesDBTestCase):
+    """Testing configure-generators with generation."""
+
+    dump_file_path = "instrument.sql"
+    database_name = "instrument"
+    schema_name = "public"
+
+    def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd:
+        """We are using configure-generators."""
+        return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config)
+
+    def test_set_null(self) -> None:
+        """Test that we can set constant generators, yielding zero/empty/null values."""
+        with self._get_cmd({}) as gc:
+            gc.do_next("string.position")
+            gc.do_set("dist_gen.constant")
+            self.assertListEqual(gc.messages, [])
+            gc.reset()
+            gc.do_next("string.frequency")
+            gc.do_set("dist_gen.constant")
+            self.assertListEqual(gc.messages, [])
+            gc.reset()
+            gc.do_next("signature_model.name")
+            gc.do_set("dist_gen.constant")
+            self.assertListEqual(gc.messages, [])
+            gc.reset()
+            gc.do_next("signature_model.based_on")
+            gc.do_set("dist_gen.constant")
+            # we have got to the end of the columns, but shouldn't have any errors
+            self.assertListEqual(
+                gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})]
+            )
+            gc.reset()
+            gc.do_quit("")
+        config = gc.config
+        self.generate_data(config, num_passes=3)
+        # Test that the constant (and null) values are present in the database
+        with self.sync_engine.connect() as conn:
+            # select(self.metadata.tables["string"].c["position", "frequency"]) would be nicer
+            # but mypy doesn't like it
+            stmt = select(
+                self.metadata.tables["string"].c["position"],
+                self.metadata.tables["string"].c["frequency"],
+            )
+            rows = conn.execute(stmt).fetchall()
+            count = 0
+            for row in rows:
+                count += 1
+                self.assertEqual(row.position, 0)
+                self.assertEqual(row.frequency, 0.0)
+            self.assertEqual(count, 3)
+            # select(self.metadata.tables["signature_model"].c["name", "based_on"]) would be nicer
+            # but mypy doesn't like it
+            stmt = select(
+                self.metadata.tables["signature_model"].c["name"],
+                self.metadata.tables["signature_model"].c["based_on"],
+            )
+            rows = conn.execute(stmt).fetchall()
+            count = 0
+            for row in rows:
+                count += 1
+                self.assertEqual(row.name, "")
+                self.assertIsNone(row.based_on)
+            self.assertEqual(count, 3)
+
+    def test_dist_gen_sampled_produces_ordered_src_stats(self) -> None:
+        """Tests that choosing a sampled choice generator produces ordered src stats."""
+        with self._get_cmd({}) as gc:
+            gc.do_next("signature_model.player_id")
+            gc.do_set("dist_gen.zipf_choice [sampled]")
+            gc.do_next("signature_model.based_on")
+            gc.do_set("dist_gen.zipf_choice [sampled]")
+            gc.do_quit("")
+        config = gc.config
+        self.set_configuration(config)
+        src_stats = self.get_src_stats(config)
+        player_ids = [
+            s["value"] for s in src_stats["auto__signature_model__player_id"]["results"]
+        ]
+        self.assertListEqual(player_ids, [2, 3, 1])
+        based_ons = [
+            s["value"] for s in src_stats["auto__signature_model__based_on"]["results"]
+        ]
+        self.assertListEqual(based_ons, [1, 3, 
2]) + + def assert_are_truncated_to(self, xs: Iterable[str], length: int) -> None: + """ + Check that none of the strings are longer than ``length`` (after + removing surrounding quotes). + """ + maxlen = 0 + for x in xs: + newlen = len(x.strip("'\"")) + self.assertLessEqual(newlen, length) + maxlen = max(maxlen, newlen) + self.assertEqual(maxlen, length) + + def test_varchar_ns_are_truncated(self) -> None: + """Tests that mimesis generators for VARCHAR(N) truncate to N characters""" + generator = "generic.text.quote" + table = "signature_model" + column = "name" + with self._get_cmd({}) as gc: + gc.do_next(f"{table}.{column}") + gc.reset() + gc.do_propose("") + proposals = gc.get_proposals() + quotes = [k for k in proposals.keys() if k.startswith(generator)] + self.assertEqual(len(quotes), 1) + prop = proposals[quotes[0]] + self.assert_are_truncated_to(prop[2], 20) + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. {quotes[0]}" + gc.do_set(str(prop[0])) + self.assertIn(col_heading, gc.columns) + self.assert_are_truncated_to(gc.columns[col_heading], 20) + gc.do_quit("") + config = gc.config + self.generate_data(config, num_passes=15) + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table].c[column]) + rows = conn.execute(stmt).scalars().fetchall() + self.assert_are_truncated_to(rows, 20) diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py new file mode 100644 index 00000000..3b21b535 --- /dev/null +++ b/tests/test_interactive_generators_partitioned.py @@ -0,0 +1,419 @@ +"""Tests for null-partitioned generators.""" +from collections.abc import MutableMapping +from dataclasses import dataclass +from typing import Any +from unittest import TestCase + +from sqlalchemy import Connection, MetaData, insert, select + +from datafaker.generators import NullPartitionedNormalGeneratorFactory +from tests.test_interactive_generators import TestGeneratorCmd +from tests.utils import GeneratesDBTestCase + + +@dataclass +class Stat: + """Mean and variance calculator.""" + + n: int = 0 + x: float = 0 + x2: float = 0 + + def add(self, x: float) -> None: + """Add one datum.""" + self.n += 1 + self.x += x + self.x2 += x * x + + def count(self) -> int: + """Get the number of data added.""" + return self.n + + def x_mean(self) -> float: + """Get the mean of the added data.""" + return self.x / self.n + + def x_var(self) -> float: + """Get the variance of the added data.""" + x = self.x + return (self.x2 - x * x / self.n) / (self.n - 1) + + +@dataclass +class Correlation(Stat): + """Mean, variance and covariance.""" + + y: float = 0 + y2: float = 0 + xy: float = 0 + + def add2(self, x: float, y: float) -> None: + """Add a 2D data point.""" + self.n += 1 + self.x += x + self.x2 += x * x + self.y += y + self.y2 += y * y + self.xy += x * y + + def y_mean(self) -> float: + """Get the mean of the second parts of the added points.""" + return self.y / self.n + + def y_var(self) -> float: + """Get the variance of the second parts of the added points.""" + y = self.y + return (self.y2 - y * y / self.n) / (self.n - 1) + + def covar(self) -> float: + """Get the covariance of the two parts of the added points.""" + return (self.xy - self.x * self.y / self.n) / (self.n - 1) + + +# pylint: disable=too-many-instance-attributes +class EavMeasurementTableStats: + """The statistics for the Measurement table of eav.sql.""" + + def __init__(self, conn: Connection, metadata: MetaData, test: TestCase) -> None: + stmt = 
select(metadata.tables["measurement"]) + rows = conn.execute(stmt).fetchall() + self.types: set[int] = set() + self.one_count = 0 + self.one_yes_count = 0 + self.two = Correlation() + self.three = Correlation() + self.four = Correlation() + self.fish = Stat() + self.fowl = Stat() + for row in rows: + self.types.add(row.type) + if row.type == 1: + # yes or no + test.assertIsNone(row.first_value) + test.assertIsNone(row.second_value) + test.assertIn(row.third_value, {"yes", "no"}) + self.one_count += 1 + if row.third_value == "yes": + self.one_yes_count += 1 + elif row.type == 2: + # positive correlation around 1.4, 1.8 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.two.add2(row.first_value, row.second_value) + elif row.type == 3: + # negative correlation around 11.8, 12.1 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.three.add2(row.first_value, row.second_value) + elif row.type == 4: + # positive correlation around 21.4, 23.4 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.four.add2(row.first_value, row.second_value) + elif row.type == 5: + test.assertIn(row.third_value, {"fish", "fowl"}) + test.assertIsNotNone(row.first_value) + test.assertIsNone(row.second_value) + if row.third_value == "fish": + # mean 8.1 and sd 0.755 + self.fish.add(row.first_value) + else: + # mean 11.2 and sd 1.114 + self.fowl.add(row.first_value) + + +class NullPartitionedTests(GeneratesDBTestCase): + """Testing null-partitioned grouped multivariate generation.""" + + dump_file_path = "eav.sql" + database_name = "eav" + schema_name = "public" + + def setUp(self) -> None: + """Set up the test with specific sample and suppress counts.""" + super().setUp() + NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 8 + NullPartitionedNormalGeneratorFactory.SUPPRESS_COUNT = 2 + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """Get the configure-generators object as our command.""" + return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + + def _propose(self, gc: TestGeneratorCmd) -> dict[str, tuple[int, str, list[str]]]: + gc.reset() + gc.do_propose("") + return gc.get_proposals() + + def test_create_with_null_partitioned_grouped_multivariate(self) -> None: + """Test EAV for all columns.""" + generate_count = 800 + with self._get_cmd({}) as gc: + self.merge_columns( + gc, + "measurement", + [ + "type", + "first_value", + "second_value", + "third_value", + ], + ) + proposals = self._propose(gc) + self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) + dist_to_choose = "null-partitioned grouped_multivariate_normal" + self.assertIn(dist_to_choose, proposals) + prop = proposals[dist_to_choose] + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. 
{dist_to_choose}" + self.assertIn(col_heading, set(gc.columns.keys())) + gc.do_set(str(prop[0])) + gc.reset() + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stats = EavMeasurementTableStats(conn, self.metadata, self) + # type 1 + self.assertAlmostEqual( + stats.one_count, generate_count * 5 / 20, delta=generate_count * 0.4 + ) + # about 40% are yes + self.assertAlmostEqual( + stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 + ) + # type 2 + self.assertAlmostEqual( + stats.two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 + ) + self.assertAlmostEqual(stats.two.x_mean(), 1.4, delta=0.4) + self.assertAlmostEqual(stats.two.x_var(), 0.315, delta=0.18) + self.assertAlmostEqual(stats.two.y_mean(), 1.8, delta=0.8) + self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.08) + self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.08) + # type 3 + self.assertAlmostEqual( + stats.three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.three.covar(), -2.085, delta=1.1) + # type 4 + self.assertAlmostEqual( + stats.four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1.3) + # type 5/fish + self.assertAlmostEqual( + stats.fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) + self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.6) + # type 5/fowl + self.assertAlmostEqual( + stats.fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(stats.fowl.x_var(), 1.24, delta=0.6) + + def populate_measurement_type_vocab(self) -> None: + """Add a vocab table without messing around with files""" + table = self.metadata.tables["measurement_type"] + with self.sync_engine.connect() as conn: + conn.execute(insert(table).values({"id": 1, "name": "agreement"})) + conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) + conn.execute(insert(table).values({"id": 3, "name": "velocity"})) + conn.execute(insert(table).values({"id": 4, "name": "position"})) + conn.execute(insert(table).values({"id": 5, "name": "matter"})) + conn.commit() + + def merge_columns( + self, gc: TestGeneratorCmd, table: str, columns: list[str] + ) -> None: + """Merge columns in a table""" + gc.do_next(f"{table}.{columns[0]}") + for col in columns[1:]: + gc.do_merge(col) + gc.reset() + + def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> None: + """Test EAV for all columns with sampled and suppressed generation.""" + generate_count = 800 + with self._get_cmd({}) as gc: + self.merge_columns( + gc, + "measurement", + [ + "type", + "first_value", + "second_value", + "third_value", + ], + ) + proposals = self._propose(gc) + self.assert_subset( + { + "null-partitioned grouped_multivariate_lognormal", + "null-partitioned grouped_multivariate_normal", + "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", + }, + set(proposals), + ) + dist_to_choose = ( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + ) + self.assertIn(dist_to_choose, proposals) + prop = proposals[dist_to_choose] + 
gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. {dist_to_choose}" + self.assertIn(col_heading, set(gc.columns.keys())) + gc.do_set(str(prop[0])) + self.merge_columns( + gc, + "observation", + [ + "type", + "first_value", + "second_value", + "third_value", + ], + ) + proposals = self._propose(gc) + prop = proposals[dist_to_choose] + gc.do_set(str(prop[0])) + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stats = EavMeasurementTableStats(conn, self.metadata, self) + stmt = select(self.metadata.tables["observation"]) + rows = conn.execute(stmt).fetchall() + firsts = Stat() + for row in rows: + stats.types.add(row.type) + self.assertEqual(row.type, 1) + self.assertIsNotNone(row.first_value) + self.assertIsNone(row.second_value) + self.assertIn(row.third_value, {"ham", "eggs"}) + firsts.add(row.first_value) + self.assertEqual(firsts.count(), 800) + self.assertAlmostEqual(firsts.x_mean(), 1.3, delta=generate_count * 0.3) + self.assert_subset(stats.types, {1, 2, 3, 4, 5}) + self.assertEqual(len(stats.types), 4) + self.assert_subset({1, 5}, stats.types) + # type 1 + self.assertAlmostEqual( + stats.one_count, generate_count * 5 / 11, delta=generate_count * 0.4 + ) + # about 40% are yes + self.assertAlmostEqual( + stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 + ) + # type 5/fish + self.assertAlmostEqual( + stats.fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) + self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.5) + # type 5/fowl + self.assertAlmostEqual( + stats.fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(stats.fowl.x_var(), 1.24, delta=0.6) + + def test_create_with_null_partitioned_grouped_sampled_only(self) -> None: + """Test EAV for all columns with sampled generation but no suppression.""" + table_name = "measurement" + table2_name = "observation" + generate_count = 800 + with self._get_cmd({}) as gc: + self.merge_columns( + gc, table_name, ["type", "first_value", "second_value", "third_value"] + ) + proposals = self._propose(gc) + self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) + self.assertIn("null-partitioned grouped_multivariate_normal", proposals) + self.assertIn( + "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", + proposals, + ) + self.assertIn( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]", + proposals, + ) + self.assertIn( + "null-partitioned grouped_multivariate_lognormal [sampled]", proposals + ) + dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" + self.assertIn(dist_to_choose, proposals) + prop = proposals[dist_to_choose] + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. 
{dist_to_choose}" + self.assertIn(col_heading, set(gc.columns.keys())) + gc.do_set(str(prop[0])) + self.merge_columns( + gc, table2_name, ["type", "first_value", "second_value", "third_value"] + ) + proposals = self._propose(gc) + prop = proposals[dist_to_choose] + gc.do_set(str(prop[0])) + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).fetchall() + self.assert_subset({row.type for row in rows}, {1, 2, 3, 4, 5}) + stmt = select(self.metadata.tables[table2_name]) + rows = conn.execute(stmt).fetchall() + self.assertEqual( + {row.third_value for row in rows}, {"ham", "eggs", "cheese"} + ) + + def test_create_with_null_partitioned_grouped_sampled_tiny(self) -> None: + """ + Test EAV for all columns with sampled generation that only gets a tiny sample. + """ + # five will ensure that at least one group will have two elements in it, + # but all three cannot. + NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 5 + table_name = "observation" + generate_count = 100 + with self._get_cmd({}) as gc: + dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" + self.merge_columns( + gc, table_name, ["type", "first_value", "second_value", "third_value"] + ) + proposals = self._propose(gc) + prop = proposals[dist_to_choose] + gc.do_set(str(prop[0])) + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).fetchall() + # we should only have one or two of "ham", "eggs" and "cheese" represented + foods = {row.third_value for row in rows} + self.assert_subset(foods, {"ham", "eggs", "cheese"}) + self.assertLess(len(foods), 3) diff --git a/tests/test_interactive_missingness.py b/tests/test_interactive_missingness.py new file mode 100644 index 00000000..7a63ea52 --- /dev/null +++ b/tests/test_interactive_missingness.py @@ -0,0 +1,100 @@ +""" Tests for the configure-missingness command. 
""" +import random +from collections.abc import MutableMapping +from typing import Any + +from sqlalchemy import select + +from datafaker.interactive import MissingnessCmd +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase, TestDbCmdMixin + + +class TestMissingnessCmd(MissingnessCmd, TestDbCmdMixin): + """MissingnessCmd but mocked""" + + +class ConfigureMissingnessTests(RequiresDBTestCase): + """Testing configure-missing.""" + + dump_file_path = "instrument.sql" + database_name = "instrument" + schema_name = "public" + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: + """We are using configure-missingness.""" + return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) + + def test_set_missingness_to_sampled(self) -> None: + """Test that we can set one table to sampled missingness.""" + with self._get_cmd({}) as mc: + table = "signature_model" + mc.do_next(table) + mc.do_counts("") + self.assertSequenceEqual( + mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (10,), {})] + ) + # Check the counts of NULLs in each column + self.assertSequenceEqual(mc.rows, [["player_id", 4], ["based_on", 3]]) + mc.do_sampled("") + mc.do_quit("") + self.assertListEqual( + mc.config["tables"][table]["missingness_generators"], + [ + { + "columns": ["player_id", "based_on"], + "kwargs": { + "patterns": 'SRC_STATS["missing_auto__signature_model__0"]["results"]' + }, + "name": "column_presence.sampled", + } + ], + ) + self.assertEqual( + mc.config["src-stats"][0]["name"], + "missing_auto__signature_model__0", + ) + self.assertEqual( + mc.config["src-stats"][0]["query"], + ( + "SELECT COUNT(*) AS row_count," + " player_id__is_null, based_on__is_null FROM" + " (SELECT player_id IS NULL AS player_id__is_null," + " based_on IS NULL AS based_on__is_null FROM" + " signature_model ORDER BY RANDOM() LIMIT 1000)" + " AS __t GROUP BY player_id__is_null, based_on__is_null" + ), + ) + + +class ConfigureMissingnessTestsWithGeneration(GeneratesDBTestCase): + """Testing configure-missing with generation.""" + + dump_file_path = "instrument.sql" + database_name = "instrument" + schema_name = "public" + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: + return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) + + def test_create_with_missingness(self) -> None: + """Test that we can sample real missingness and reproduce it.""" + random.seed(45) + # Configure the missingness + table_name = "signature_model" + with self._get_cmd({}) as mc: + mc.do_next(table_name) + mc.do_sampled("") + mc.do_quit("") + config = mc.config + self.generate_data(config, num_passes=100) + # Test that each missingness pattern is present in the database + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).mappings().fetchall() + patterns: set[int] = set() + for row in rows: + p = 0 if row["player_id"] is None else 1 + b = 0 if row["based_on"] is None else 2 + patterns.add(p + b) + # all pattern possibilities should be present + self.assertSetEqual(patterns, {0, 1, 2, 3}) diff --git a/tests/test_interactive_table.py b/tests/test_interactive_table.py new file mode 100644 index 00000000..04b157e7 --- /dev/null +++ b/tests/test_interactive_table.py @@ -0,0 +1,398 @@ +""" Tests for the configure-tables command. 
""" +from collections.abc import MutableMapping +from typing import Any + +from sqlalchemy import select + +from datafaker.interactive import TableCmd +from tests.utils import RequiresDBTestCase, TestDbCmdMixin + + +class TestTableCmd(TableCmd, TestDbCmdMixin): + """TableCmd but mocked""" + + +class ConfigureTablesTests(RequiresDBTestCase): + """Testing configure-tables.""" + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestTableCmd: + return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) + + +class ConfigureTablesSrcTests(ConfigureTablesTests): + """Testing configure-tables with src.dump.""" + + dump_file_path = "src.dump" + database_name = "src" + schema_name = "public" + + def test_table_name_prompts(self) -> None: + """Test that the prompts follow the names of the tables.""" + config: MutableMapping[str, Any] = {} + with self._get_cmd(config) as tc: + table_names = list(self.metadata.tables.keys()) + for t in table_names: + self.assertIn(t, tc.prompt) + tc.do_next("") + self.assertListEqual(tc.messages, [(TableCmd.INFO_NO_MORE_TABLES, (), {})]) + tc.reset() + for t in reversed(table_names): + self.assertIn(t, tc.prompt) + tc.do_previous("") + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})] + ) + tc.reset() + bad_table_name = "notarealtable" + tc.do_next(bad_table_name) + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})] + ) + tc.reset() + good_table_name = table_names[2] + tc.do_next(good_table_name) + self.assertSequenceEqual(tc.messages, []) + self.assertIn(good_table_name, tc.prompt) + + def test_column_display(self) -> None: + """Test that we can see the names of the columns.""" + config: MutableMapping[str, Any] = {} + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_columns("") + self.assertSequenceEqual( + tc.rows, + [ + ["id", "INTEGER", True, False, ""], + ["a", "BOOLEAN", False, False, ""], + ["b", "BOOLEAN", False, False, ""], + ["c", "TEXT", False, False, ""], + ], + ) + + def test_null_configuration(self) -> None: + """A table still works if its configuration is None.""" + config = { + "tables": None, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_private("") + tc.do_quit("") + tables = tc.config["tables"] + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) + self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) + + def test_null_table_configuration(self) -> None: + """A table still works if its configuration is None.""" + config = { + "tables": { + "unique_constraint_test": None, + }, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_private("") + tc.do_quit("") + tables = tc.config["tables"] + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) + self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) + + def test_configure_tables(self) -> None: + """Test that we can change columns to ignore, vocab or generate.""" + config = { + "tables": { + "unique_constraint_test": { + "vocabulary_table": True, + }, + "no_pk_test": { + "ignore": True, + }, + "hospital_visit": { + "num_passes": 0, + }, + "empty_vocabulary": { + "private": True, + }, + }, + } + with self._get_cmd(config) as tc: + 
tc.do_next("unique_constraint_test") + tc.do_generate("") + tc.do_next("person") + tc.do_vocabulary("") + tc.do_next("mitigation_type") + tc.do_ignore("") + tc.do_next("hospital_visit") + tc.do_private("") + tc.do_quit("") + tc.do_next("empty_vocabulary") + tc.do_empty("") + tc.do_quit("") + tables = tc.config["tables"] + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) + self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) + self.assertFalse( + tables["unique_constraint_test"].get("primary_private", False) + ) + self.assertEqual(tables["unique_constraint_test"].get("num_passes", 1), 1) + self.assertFalse(tables["no_pk_test"].get("vocabulary_table", False)) + self.assertTrue(tables["no_pk_test"].get("ignore", False)) + self.assertFalse(tables["no_pk_test"].get("primary_private", False)) + self.assertEqual(tables["no_pk_test"].get("num_rows_per_pass", 1), 1) + self.assertTrue(tables["person"].get("vocabulary_table", False)) + self.assertFalse(tables["person"].get("ignore", False)) + self.assertFalse(tables["person"].get("primary_private", False)) + self.assertEqual(tables["person"].get("num_rows_per_pass", 1), 1) + self.assertFalse(tables["mitigation_type"].get("vocabulary_table", False)) + self.assertTrue(tables["mitigation_type"].get("ignore", False)) + self.assertFalse(tables["mitigation_type"].get("primary_private", False)) + self.assertEqual(tables["mitigation_type"].get("num_rows_per_pass", 1), 1) + self.assertFalse(tables["hospital_visit"].get("vocabulary_table", False)) + self.assertFalse(tables["hospital_visit"].get("ignore", False)) + self.assertTrue(tables["hospital_visit"].get("primary_private", False)) + self.assertEqual(tables["hospital_visit"].get("num_rows_per_pass", 1), 1) + self.assertFalse(tables["empty_vocabulary"].get("vocabulary_table", False)) + self.assertFalse(tables["empty_vocabulary"].get("ignore", False)) + self.assertFalse(tables["empty_vocabulary"].get("primary_private", False)) + self.assertEqual(tables["empty_vocabulary"].get("num_rows_per_pass", 1), 0) + + def test_print_data(self) -> None: + """Test that we can print random rows from the table and random data from columns.""" + person_table = self.metadata.tables["person"] + with self.sync_engine.connect() as conn: + person_rows = conn.execute(select(person_table)).mappings().fetchall() + person_data = {row["person_id"]: row for row in person_rows} + name_set = {row["name"] for row in person_rows} + person_headings = ["person_id", "name", "research_opt_out", "stored_from"] + with self._get_cmd({}) as tc: + tc.do_next("person") + tc.do_data("") + self.assertSequenceEqual(tc.headings, person_headings) + self.assertEqual(len(tc.rows), 10) # default number of rows is 10 + for row in tc.rows: + expected = person_data[row[0]] + self.assertSequenceEqual(row, [expected[h] for h in person_headings]) + tc.reset() + rows_to_get_count = 6 + tc.do_data(str(rows_to_get_count)) + self.assertSequenceEqual(tc.headings, person_headings) + self.assertEqual(len(tc.rows), rows_to_get_count) + for row in tc.rows: + expected = person_data[row[0]] + self.assertSequenceEqual(row, [expected[h] for h in person_headings]) + tc.reset() + to_get_count = 12 + tc.do_data(f"{to_get_count} name") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual(len(tc.column_items[0]), to_get_count) + self.assertLessEqual(set(tc.column_items[0]), name_set) + tc.reset() + tc.do_data(f"{to_get_count} name 12") + self.assertEqual(len(tc.column_items), 1) + 
self.assertEqual(len(tc.column_items[0]), to_get_count) + tc.reset() + tc.do_data(f"{to_get_count} name 13") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set)) + ) + tc.reset() + tc.do_data(f"{to_get_count} name 16") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set)) + ) + + def test_list_tables(self) -> None: + """Test that we can list the tables""" + config = { + "tables": { + "unique_constraint_test": { + "vocabulary_table": True, + }, + "no_pk_test": { + "ignore": True, + }, + }, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_ignore("") + tc.do_next("person") + tc.do_vocabulary("") + tc.reset() + tc.do_tables("") + person_listed = False + unique_constraint_test_listed = False + no_pk_test_listed = False + for _text, args, _kwargs in tc.messages: + if args[2] == "person": + self.assertFalse(person_listed) + person_listed = True + self.assertEqual(args[0], "G") + self.assertEqual(args[1], "->V") + elif args[2] == "unique_constraint_test": + self.assertFalse(unique_constraint_test_listed) + unique_constraint_test_listed = True + self.assertEqual(args[0], "V") + self.assertEqual(args[1], "->I") + elif args[2] == "no_pk_test": + self.assertFalse(no_pk_test_listed) + no_pk_test_listed = True + self.assertEqual(args[0], "I") + self.assertEqual(args[1], " ") + else: + self.assertEqual(args[0], "G") + self.assertEqual(args[1], " ") + self.assertTrue(person_listed) + self.assertTrue(unique_constraint_test_listed) + self.assertTrue(no_pk_test_listed) + + +class ConfigureTablesInstrumentsTests(ConfigureTablesTests): + """Testing configure-tables with the instrument.sql database.""" + + dump_file_path = "instrument.sql" + database_name = "instrument" + schema_name = "public" + + def test_sanity_checks_both(self) -> None: + """ + Test ``configure-tables`` sanity checks. + """ + config = { + "tables": { + "model": { + "vocabulary_table": True, + }, + "manufacturer": { + "ignore": True, + }, + "player": { + "num_rows_per_pass": 0, + }, + }, + } + with self._get_cmd(config) as tc: + tc.reset() + tc.do_quit("") + self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_NO_CHANGES, (), {})) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) + self.assertEqual( + tc.messages[3], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[4], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) + + def test_sanity_checks_warnings_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. 
+ """ + config = { + "tables": { + "model": { + "vocabulary_table": True, + }, + "manufacturer": { + "ignore": True, + }, + "player": { + "num_rows_per_pass": 0, + }, + }, + } + with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: + tc.do_next("manufacturer") + tc.do_vocabulary("") + tc.reset() + tc.do_quit("") + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("manufacturer", "ignore", "vocabulary"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) + + def test_sanity_checks_errors_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. + """ + config = { + "tables": { + "model": { + "vocabulary_table": True, + }, + "manufacturer": { + "ignore": True, + }, + "player": { + "num_rows_per_pass": 0, + }, + }, + } + with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: + tc.do_next("signature_model") + tc.do_empty("") + tc.reset() + tc.do_quit("") + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("signature_model", "generate", "empty"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) diff --git a/tests/test_main.py b/tests/test_main.py index c3652a2d..2167207f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -21,7 +21,13 @@ class TestCLI(DatafakerTestCase): @patch("datafaker.main.dict_to_metadata") @patch("datafaker.main.load_metadata_config") @patch("datafaker.main.create_db_vocab") - def test_create_vocab(self, mock_create: MagicMock, mock_mdict: MagicMock, mock_meta: MagicMock, mock_config: MagicMock) -> None: + def test_create_vocab( + self, + mock_create: MagicMock, + mock_mdict: MagicMock, + mock_meta: MagicMock, + mock_config: MagicMock, + ) -> None: """Test the create-vocab sub-command.""" result = runner.invoke( app, @@ -31,7 +37,9 @@ def test_create_vocab(self, mock_create: MagicMock, mock_mdict: MagicMock, mock_ catch_exceptions=False, ) - mock_create.assert_called_once_with(mock_meta.return_value, mock_mdict.return_value, mock_config.return_value) + mock_create.assert_called_once_with( + mock_meta.return_value, mock_mdict.return_value, mock_config.return_value + ) self.assertSuccess(result) @patch("datafaker.main.read_config_file") @@ -40,6 +48,7 @@ def test_create_vocab(self, mock_create: MagicMock, mock_mdict: MagicMock, mock_ @patch("datafaker.main.Path") @patch("datafaker.main.make_table_generators") @patch("datafaker.main.generators_require_stats") + # pylint: disable=too-many-positional-arguments,too-many-arguments def test_create_generators( self, mock_require_stats: MagicMock, @@ -81,6 +90,7 @@ def test_create_generators( @patch("datafaker.main.Path") @patch("datafaker.main.make_table_generators") @patch("datafaker.main.generators_require_stats") + # pylint: disable=too-many-positional-arguments,too-many-arguments def test_create_generators_uses_default_stats_file_if_necessary( self, mock_require_stats: MagicMock, @@ -143,6 +153,7 @@ def test_create_generators_errors_if_file_exists( @patch("datafaker.main.get_settings") @patch("datafaker.main.Path") @patch("datafaker.main.make_table_generators") + # pylint: disable=too-many-positional-arguments,too-many-arguments def 
test_create_generators_with_force_enabled( self, mock_make: MagicMock, @@ -159,10 +170,13 @@ def test_create_generators_with_force_enabled( for force_option in ["--force", "-f"]: with self.subTest(f"Using option {force_option}"): - result: Result = runner.invoke(app, [ - "create-generators", - force_option, - ]) + result: Result = runner.invoke( + app, + [ + "create-generators", + force_option, + ], + ) mock_make.assert_called_once_with( mock_load_meta.return_value, @@ -183,7 +197,7 @@ def test_create_generators_with_force_enabled( @patch("datafaker.main.read_config_file") @patch("datafaker.main.load_metadata_for_output") def test_create_tables( - self, mock_load_meta: MagicMock, mock_config: MagicMock, mock_create: MagicMock + self, mock_load_meta: MagicMock, _mock_config: MagicMock, mock_create: MagicMock ) -> None: """Test the create-tables sub-command.""" @@ -202,8 +216,11 @@ def test_create_tables( @patch("datafaker.main.logger") @patch("datafaker.main.import_file") @patch("datafaker.main.create_db_data") + @patch("datafaker.main.load_metadata_for_output") + # pylint: disable=too-many-arguments too-many-positional-arguments def test_create_data( self, + mock_load_metadata: MagicMock, mock_create: MagicMock, mock_import: MagicMock, mock_logger: MagicMock, @@ -225,9 +242,9 @@ def test_create_data( mock_create.assert_called_once_with( mock_tables.return_value, - mock_import.return_value.table_generator_dict, - mock_import.return_value.story_generator_list, + mock_import.return_value, 1, + mock_load_metadata.return_value, ) self.assertSuccess(result) @@ -335,11 +352,14 @@ def test_make_tables_with_force_enabled( for force_option in ["--force", "-f"]: with self.subTest(f"Using option {force_option}"): - result: Result = runner.invoke(app, [ - "make-tables", - force_option, - "--orm-file=tests/examples/example_orm.yaml", - ]) + result: Result = runner.invoke( + app, + [ + "make-tables", + force_option, + "--orm-file=tests/examples/example_orm.yaml", + ], + ) mock_make_tables.assert_called_once_with( mock_get_settings.return_value.src_dsn, @@ -357,9 +377,11 @@ def test_make_tables_with_force_enabled( @patch("datafaker.main.Path") @patch("datafaker.main.make_src_stats") @patch("datafaker.main.get_settings") - @patch("datafaker.main.load_metadata", side_effect=["ms"]) def test_make_stats( - self, _lm: MagicMock, mock_get_settings: MagicMock, mock_make: MagicMock, mock_path: MagicMock + self, + mock_get_settings: MagicMock, + mock_make: MagicMock, + mock_path: MagicMock, ) -> None: """Test the make-stats sub-command.""" example_conf_path = "tests/examples/example_config.yaml" @@ -379,7 +401,7 @@ def test_make_stats( self.assertSuccess(result) with open(example_conf_path, "r", encoding="utf8") as f: config = yaml.safe_load(f) - mock_make.assert_called_once_with(get_test_settings().src_dsn, config, "ms", None) + mock_make.assert_called_once_with(get_test_settings().src_dsn, config, None) mock_path.return_value.write_text.assert_called_once_with( "a: 1\n", encoding="utf-8" ) @@ -432,9 +454,11 @@ def test_make_stats_errors_if_no_src_dsn(self, mock_logger: MagicMock) -> None: @patch("datafaker.main.Path") @patch("datafaker.main.make_src_stats") @patch("datafaker.main.get_settings") - @patch("datafaker.main.load_metadata") def test_make_stats_with_force_enabled( - self, mock_meta: MagicMock, mock_get_settings: MagicMock, mock_make: MagicMock, mock_path: MagicMock + self, + mock_get_settings: MagicMock, + mock_make_src_stats: MagicMock, + mock_path: MagicMock, ) -> None: """Tests that the make-stats 
command overwrite files when instructed.""" test_config_file: str = "tests/examples/example_config.yaml" @@ -445,7 +469,7 @@ def test_make_stats_with_force_enabled( test_settings: Settings = get_test_settings() mock_get_settings.return_value = test_settings make_test_output: dict = {"some_stat": 0} - mock_make.return_value = make_test_output + mock_make_src_stats.return_value = make_test_output for force_option in ["--force", "-f"]: with self.subTest(f"Using option {force_option}"): @@ -455,20 +479,21 @@ def test_make_stats_with_force_enabled( "make-stats", "--stats-file=stats_file.yaml", f"--config-file={test_config_file}", - "--orm-file=tests/examples/example_config.yaml", force_option, ], ) - mock_make.assert_called_once_with( - test_settings.src_dsn, config_file_content, mock_meta.return_value, None + mock_make_src_stats.assert_called_once_with( + test_settings.src_dsn, + config_file_content, + test_settings.src_schema, ) mock_path.return_value.write_text.assert_called_once_with( "some_stat: 0\n", encoding="utf-8" ) self.assertSuccess(result) - mock_make.reset_mock() + mock_make_src_stats.reset_mock() mock_path.reset_mock() def test_validate_config(self) -> None: @@ -507,7 +532,9 @@ def test_remove_data( catch_exceptions=False, ) self.assertEqual(0, result.exit_code) - mock_remove.assert_called_once_with(mock_meta.return_value, mock_config.return_value) + mock_remove.assert_called_once_with( + mock_meta.return_value, mock_config.return_value + ) @patch("datafaker.main.read_config_file") @patch("datafaker.main.remove_db_vocab") @@ -528,12 +555,18 @@ def test_remove_vocab( ) self.assertEqual(0, result.exit_code) mock_read_config.assert_called_once_with("config.yaml") - mock_remove.assert_called_once_with(mock_d2m.return_value, mock_load_metadata.return_value, mock_read_config.return_value) + mock_remove.assert_called_once_with( + mock_d2m.return_value, + mock_load_metadata.return_value, + mock_read_config.return_value, + ) @patch("datafaker.main.remove_db_tables") @patch("datafaker.main.load_metadata_for_output") @patch("datafaker.main.read_config_file") - def test_remove_tables(self, _: MagicMock, mock_meta: MagicMock, mock_remove: MagicMock) -> None: + def test_remove_tables( + self, _: MagicMock, mock_meta: MagicMock, mock_remove: MagicMock + ) -> None: """Test the remove-tables command.""" result = runner.invoke( app, diff --git a/tests/test_make.py b/tests/test_make.py index 6100aa77..49bb9e71 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -9,42 +9,43 @@ from sqlalchemy.dialects.mysql.types import INTEGER from sqlalchemy.dialects.postgresql import UUID -from datafaker.make import ( - _get_provider_for_column, - make_src_stats, -) -from tests.utils import RequiresDBTestCase, GeneratesDBTestCase +from datafaker.make import _get_provider_for_column, make_src_stats +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase class TestMakeGenerators(GeneratesDBTestCase): """Test the make_table_generators function.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" def test_make_table_generators(self) -> None: - """ Check that we can make a generators file. 
""" + """Check that we can make a generators file.""" config = { "tables": { "player": { - "row_generators": [{ - "name": "dist_gen.constant", - "kwargs": { - "value": '"Cave"', + "row_generators": [ + { + "name": "dist_gen.constant", + "kwargs": { + "value": '"Cave"', + }, + "columns_assigned": "given_name", }, - "columns_assigned": "given_name", - }, { - "name": "dist_gen.constant", - "kwargs": { - "value": '"Johnson"', + { + "name": "dist_gen.constant", + "kwargs": { + "value": '"Johnson"', + }, + "columns_assigned": "family_name", }, - "columns_assigned": "family_name", - }], + ], }, }, } self.generate_data(config, num_passes=3) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables["player"]) rows = conn.execute(stmt).mappings().fetchall() for row in rows: @@ -96,7 +97,7 @@ def test_get_provider_for_column(self) -> None: ) self.assertEqual( generator_arguments, - { "length": "100" }, + {"length": "100"}, ) # UUID @@ -149,12 +150,15 @@ def check_make_stats_output(self, src_stats: dict) -> None: count_names = src_stats["count_names"]["results"] count_names.sort(key=lambda c: c["name"]) - self.assertListEqual(count_names, [ - {"num": 1, "name": "Miranda Rando-Generata"}, - {"num": 997, "name": "Randy Random"}, - {"num": 1, "name": "Testfried Testermann"}, - {"num": 1, "name": "Veronica Fyre"}, - ]) + self.assertListEqual( + count_names, + [ + {"num": 1, "name": "Miranda Rando-Generata"}, + {"num": 997, "name": "Randy Random"}, + {"num": 1, "name": "Testfried Testermann"}, + {"num": 1, "name": "Veronica Fyre"}, + ], + ) avg_person_id = src_stats["avg_person_id"]["results"] self.assertEqual(len(avg_person_id), 1) @@ -166,14 +170,14 @@ def check_make_stats_output(self, src_stats: dict) -> None: def test_make_stats_no_asyncio_schema(self) -> None: """Test that make_src_stats works when explicitly naming a schema.""" src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(self.dsn, self.config, self.metadata, self.schema_name) + make_src_stats(self.dsn, self.config, self.schema_name) ) self.check_make_stats_output(src_stats) def test_make_stats_no_asyncio(self) -> None: """Test that make_src_stats works using the example configuration.""" src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(self.dsn, self.config, self.metadata, self.schema_name) + make_src_stats(self.dsn, self.config, self.schema_name) ) self.check_make_stats_output(src_stats) @@ -185,7 +189,7 @@ def test_make_stats_asyncio(self) -> None: asyncio.set_event_loop(loop) config_asyncio = {**self.config, "use-asyncio": True} src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(self.dsn, config_asyncio, self.metadata, self.schema_name) + make_src_stats(self.dsn, config_asyncio, self.schema_name) ) self.check_make_stats_output(src_stats) @@ -216,7 +220,7 @@ def test_make_stats_empty_result(self, mock_logger: MagicMock) -> None: ] } src_stats = asyncio.get_event_loop().run_until_complete( - make_src_stats(self.dsn, config, self.metadata, self.schema_name) + make_src_stats(self.dsn, config, self.schema_name) ) self.assertEqual(src_stats[query_name1]["results"], []) self.assertEqual(src_stats[query_name2]["results"], []) diff --git a/tests/test_noninteractive_generators.py b/tests/test_noninteractive_generators.py new file mode 100644 index 00000000..93431147 --- /dev/null +++ b/tests/test_noninteractive_generators.py @@ -0,0 +1,179 @@ +""" Tests for the configure-generators command with the --spec option. 
""" + +from collections.abc import Mapping, MutableMapping +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +from datafaker.interactive import update_config_generators +from tests.utils import RequiresDBTestCase + + +class NonInteractiveTests(RequiresDBTestCase): + """ + Test the --spec SPEC_FILE option of configure-generators + """ + + dump_file_path = "eav.sql" + database_name = "eav" + schema_name = "public" + + @patch("datafaker.interactive.Path") + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + ["observation", "type", "dist_gen.weighted_choice [sampled]"], + [ + "observation", + "first_value", + "dist_gen.weighted_choice", + "dist_gen.constant", + ], + [ + "observation", + "second_value", + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + "dist_gen.constant", + ], + ["observation", "third_value", "dist_gen.weighted_choice"], + ] + ), + ) + def test_non_interactive_configure_generators( + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock + ) -> None: + """ + test that we can set generators from a CSV file + """ + config: MutableMapping[str, Any] = {} + spec_csv = Mock(return_value="mock spec.csv file") + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) + row_gens = { + f"{table}{sorted(rg['columns_assigned'])}": rg["name"] + for table, tables in config.get("tables", {}).items() + for rg in tables.get("row_generators", []) + } + self.assertEqual(row_gens["observation['type']"], "dist_gen.weighted_choice") + self.assertEqual( + row_gens["observation['first_value']"], "dist_gen.weighted_choice" + ) + self.assertEqual(row_gens["observation['second_value']"], "dist_gen.constant") + self.assertEqual( + row_gens["observation['third_value']"], "dist_gen.weighted_choice" + ) + + @patch("datafaker.interactive.Path") + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + [ + "observation", + "type first_value second_value third_value", + "null-partitioned grouped_multivariate_lognormal", + ], + ] + ), + ) + def test_non_interactive_configure_null_partitioned( + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock + ) -> None: + """ + test that we can set multi-column generators from a CSV file + """ + config: MutableMapping[str, Any] = {} + spec_csv = Mock(return_value="mock spec.csv file") + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) + row_gens = { + f"{table}{sorted(rg['columns_assigned'])}": rg + for table, tables in config.get("tables", {}).items() + for rg in tables.get("row_generators", []) + } + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["name"], + "dist_gen.alternatives", + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["name"], + '"with_constants_at"', + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], + '"grouped_multivariate_lognormal"', + ) + + @patch("datafaker.interactive.Path") + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + [ + "observation", + "type first_value second_value third_value", + "null-partitioned grouped_multivariate_lognormal", + ], + ] + ), + ) + def test_non_interactive_configure_null_partitioned_where_existing_merges( + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock 
+    ) -> None:
+        """
+        Test that we can set multi-column generators from a CSV file,
+        but where there are already multi-column generators configured
+        that will have to be unmerged.
+        """
+        config = {
+            "tables": {
+                "observation": {
+                    "row_generators": [
+                        {
+                            "name": "arbitrary_gen",
+                            "columns_assigned": [
+                                "type",
+                                "second_value",
+                                "first_value",
+                            ],
+                        }
+                    ],
+                },
+            },
+        }
+        spec_csv = Mock(return_value="mock spec.csv file")
+        update_config_generators(
+            self.dsn, self.schema_name, self.metadata, config, spec_csv
+        )
+        row_gens: Mapping[str, Any] = {
+            f"{table}{sorted(rg['columns_assigned'])}": rg
+            for table, tables in config.get("tables", {}).items()
+            for rg in tables.get("row_generators", [])
+        }
+        self.assertEqual(
+            row_gens[
+                "observation['first_value', 'second_value', 'third_value', 'type']"
+            ]["name"],
+            "dist_gen.alternatives",
+        )
+        self.assertEqual(
+            row_gens[
+                "observation['first_value', 'second_value', 'third_value', 'type']"
+            ]["kwargs"]["alternative_configs"][0]["name"],
+            '"with_constants_at"',
+        )
+        self.assertEqual(
+            row_gens[
+                "observation['first_value', 'second_value', 'third_value', 'type']"
+            ]["kwargs"]["alternative_configs"][0]["params"]["subgen"],
+            '"grouped_multivariate_lognormal"',
+        )
diff --git a/tests/test_providers.py b/tests/test_providers.py
index 9cc03c5e..cd880072 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -1,13 +1,12 @@
 """Tests for the providers module."""
 import datetime as dt
-from pathlib import Path
 from typing import Any
 
-from sqlalchemy import Column, Integer, Text, create_engine, insert
+from sqlalchemy import Column, Integer, Text, insert
 from sqlalchemy.ext.declarative import declarative_base
 
 from datafaker import providers
-from tests.utils import RequiresDBTestCase, DatafakerTestCase
+from tests.utils import DatafakerTestCase, RequiresDBTestCase
 
 # pylint: disable=invalid-name
 Base = declarative_base()
@@ -37,6 +36,7 @@ def test_bytes(self) -> None:
 
 class ColumnValueProviderTestCase(RequiresDBTestCase):
     """Tests for the ColumnValueProvider class."""
+
     dump_file_path = "providers.dump"
 
     def setUp(self) -> None:
@@ -48,7 +48,7 @@ def test_column_value_present(self) -> None:
         """Test the key method."""
         # pylint: disable=invalid-name
 
-        with self.engine.connect() as conn:
+        with self.sync_engine.connect() as conn:
             stmt = insert(Person).values(sex="M")
             conn.execute(stmt)
 
@@ -60,7 +60,7 @@ def test_column_value_missing(self) -> None:
 
     def test_column_value_missing(self) -> None:
         """Test the generator when there are no values in the source table."""
-        with self.engine.connect() as connection:
+        with self.sync_engine.connect() as connection:
             provider: providers.ColumnValueProvider = providers.ColumnValueProvider()
 
             generated_value: Any = provider.column_value(connection, Person, "sex")
diff --git a/tests/test_remove.py b/tests/test_remove.py
index 660d6cb8..0d466db7 100644
--- a/tests/test_remove.py
+++ b/tests/test_remove.py
@@ -1,82 +1,90 @@
 """Tests for the remove module."""
 from unittest.mock import MagicMock, patch
 
+from sqlalchemy import func, inspect, select
+from sqlalchemy.engine import Connection
+
 from datafaker.remove import remove_db_data, remove_db_tables, remove_db_vocab
 from datafaker.serialize_metadata import metadata_to_dict
 from datafaker.settings import Settings
-from sqlalchemy import func, inspect, select
 from tests.utils import RequiresDBTestCase
 
 
 class RemoveThingsTestCase(RequiresDBTestCase):
-    """ Tests for ``remove-`` commands.
""" + """Tests for ``remove-`` commands.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" - def count_rows(self, connection, table_name: str) -> int | None: - return connection.execute(select( - func.count() - ).select_from( - self.metadata.tables[table_name] - )).scalar() + def count_rows(self, connection: Connection, table_name: str) -> int | None: + """Count the rows in a table.""" + return connection.execute( + # pylint: disable=not-callable. + select(func.count()).select_from(self.metadata.tables[table_name]) + ).scalar() @patch("datafaker.remove.get_settings") - def test_remove_data(self, mock_get_settings: MagicMock): + def test_remove_data(self, mock_get_settings: MagicMock) -> None: + """Test that data can be removed from non-vocabulary tables.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, - _env_file=None, ) - remove_db_data(self.metadata, { - "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, - } - }) - with self.engine.connect() as conn: - self.assertGreater(self.count_rows(conn, "manufacturer"), 0) - self.assertGreater(self.count_rows(conn, "model"), 0) + remove_db_data( + self.metadata, + { + "tables": { + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, + } + }, + ) + with self.sync_engine.connect() as conn: + self.assert_greater_and_not_none(self.count_rows(conn, "manufacturer"), 0) + self.assert_greater_and_not_none(self.count_rows(conn, "model"), 0) self.assertEqual(self.count_rows(conn, "player"), 0) self.assertEqual(self.count_rows(conn, "string"), 0) self.assertEqual(self.count_rows(conn, "signature_model"), 0) @patch("datafaker.remove.get_settings") def test_remove_data_raises(self, mock_get_settings: MagicMock) -> None: - """ Test that remove-data raises if dst DSN is missing. 
""" + """Test that remove-data raises if dst DSN is missing.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, - _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: - remove_db_data(self.metadata, { - "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, - } - }) + remove_db_data( + self.metadata, + { + "tables": { + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, + } + }, + ) self.assertEqual( context_manager.exception.args[0], "Missing destination database settings" ) @patch("datafaker.remove.get_settings") - def test_remove_vocab(self, mock_get_settings: MagicMock): + def test_remove_vocab(self, mock_get_settings: MagicMock) -> None: + """Test that vocabulary tables can be removed.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, - _env_file=None, ) - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) + meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.sync_engine) config = { "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, } } remove_db_data(self.metadata, config) remove_db_vocab(self.metadata, meta_dict, config) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: self.assertEqual(self.count_rows(conn, "manufacturer"), 0) self.assertEqual(self.count_rows(conn, "model"), 0) self.assertEqual(self.count_rows(conn, "player"), 0) @@ -85,46 +93,56 @@ def test_remove_vocab(self, mock_get_settings: MagicMock): @patch("datafaker.remove.get_settings") def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: - """ Test that remove-vocab raises if dst DSN is missing. 
""" + """Test that remove-vocab raises if dst DSN is missing.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, - _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) - remove_db_vocab(self.metadata, meta_dict, { - "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, - } - }) + meta_dict = metadata_to_dict( + self.metadata, self.schema_name, self.sync_engine + ) + remove_db_vocab( + self.metadata, + meta_dict, + { + "tables": { + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, + } + }, + ) self.assertEqual( context_manager.exception.args[0], "Missing destination database settings" ) @patch("datafaker.remove.get_settings") - def test_remove_tables(self, mock_get_settings: MagicMock): + def test_remove_tables(self, mock_get_settings: MagicMock) -> None: + """Test that destination tables can be removed.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, - _env_file=None, ) - self.assertTrue(inspect(self.engine).has_table("player")) + engine_in = inspect(self.engine) + assert engine_in is not None + assert hasattr(engine_in, "has_table") + self.assertTrue(engine_in.has_table("player")) remove_db_tables(self.metadata) - self.assertFalse(inspect(self.engine).has_table("manufacturer")) - self.assertFalse(inspect(self.engine).has_table("model")) - self.assertFalse(inspect(self.engine).has_table("player")) - self.assertFalse(inspect(self.engine).has_table("string")) - self.assertFalse(inspect(self.engine).has_table("signature_model")) + engine_out = inspect(self.engine) + assert engine_out is not None + assert hasattr(engine_out, "has_table") + self.assertFalse(engine_out.has_table("manufacturer")) + self.assertFalse(engine_out.has_table("model")) + self.assertFalse(engine_out.has_table("player")) + self.assertFalse(engine_out.has_table("string")) + self.assertFalse(engine_out.has_table("signature_model")) @patch("datafaker.remove.get_settings") def test_remove_tables_raises(self, mock_get_settings: MagicMock) -> None: - """ Test that remove-vocab raises if dst DSN is missing. 
""" + """Test that remove-vocab raises if dst DSN is missing.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, - _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: remove_db_tables(self.metadata) diff --git a/tests/test_rst.py b/tests/test_rst.py index 090658c1..1a57ed61 100644 --- a/tests/test_rst.py +++ b/tests/test_rst.py @@ -2,11 +2,26 @@ The CLI does not allow errors to be disabled, but we can ignore them here.""" from pathlib import Path +from typing import Any from unittest import TestCase from restructuredtext_lint import lint_file +def _level_to_string(level: int) -> str: + """Get a string description of an error level.""" + return ["Severe", "Error", "Warning"][level] + + +def _error_message(lint_error: Any) -> str: + """Turn a linting error into an error message.""" + source = getattr(lint_error, "source") + line = getattr(lint_error, "line") + level = _level_to_string(getattr(lint_error, "level")) + message = getattr(lint_error, "full_message") + return f"{source}({line}): {level}: {message}" + + class RstTests(TestCase): """Linting for the doc .rst files.""" @@ -40,11 +55,8 @@ def test_dir(self) -> None: for file_error in file_errors # Only worry about ERRORs and WARNINGs if file_error.level <= 2 - if not any(filter(lambda m: m in file_error.full_message, allowed_errors)) + if not any(m in file_error.full_message for m in allowed_errors) ] if filtered_errors: - self.fail(msg="\n".join([ - f"{err.source}({err.line}): {["Severe", "Error", "Warning"][err.level]}: {err.full_message}" - for err in filtered_errors - ])) + self.fail(msg="\n".join(map(_error_message, filtered_errors))) diff --git a/tests/test_unique_generator.py b/tests/test_unique_generator.py index 41a77474..afec078c 100644 --- a/tests/test_unique_generator.py +++ b/tests/test_unique_generator.py @@ -1,16 +1,7 @@ """Tests for the unique_generator module.""" -from pathlib import Path from unittest.mock import MagicMock -from sqlalchemy import ( - Boolean, - Column, - Integer, - Text, - UniqueConstraint, - create_engine, - insert, -) +from sqlalchemy import Boolean, Column, Integer, Text, UniqueConstraint, insert from sqlalchemy.ext.declarative import declarative_base from datafaker.unique_generator import UniqueGenerator @@ -40,6 +31,7 @@ class UniqueGeneratorTestCase(RequiresDBTestCase): and b which are boolean, and c which is a text column. There is a joint unique constraint on a and b, and a separate unique constraint on c. """ + dump_file_path = "unique_generator.dump" def setUp(self) -> None: @@ -54,7 +46,7 @@ def test_unique_generator_empty_table(self) -> None: uniq_ab = UniqueGenerator(["a", "b"], table_name) uniq_c = UniqueGenerator(["c"], table_name, max_tries=10) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: # Find a couple of different values that could be inserted, then try to do # one duplicate. 
test_ab1 = [True, False] @@ -82,7 +74,7 @@ def test_unique_generator_nonempty_table(self) -> None: uniq_ab = UniqueGenerator(["a", "b"], table_name) uniq_c = UniqueGenerator(["c"], table_name, max_tries=10) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: test_ab1 = [True, False] test_ab2 = [False, False] string1 = "String 1" @@ -108,7 +100,7 @@ def test_unique_generator_multivalue_generator(self) -> None: uniq_ab = UniqueGenerator(["a", "b"], table_name) uniq_c = UniqueGenerator(["c"], table_name, max_tries=10) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: test_val1 = (True, False, "String 1") test_val2 = (True, False, "String 2") # Conflicts on (a, b) test_val3 = (False, False, "String 1") # Conflicts on c @@ -142,7 +134,7 @@ def test_unique_generator_max_tries(self) -> None: test_val = (True, False, "String 1") mock_generator.return_value = test_val - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: self.assertEqual(uniq_ab(conn, ["a", "b", "c"], mock_generator), test_val) self.assertRaises( RuntimeError, uniq_ab, conn, ["a", "b", "c"], mock_generator diff --git a/tests/test_utils.py b/tests/test_utils.py index 0eca2b11..ac82d124 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,10 @@ """Tests for the utils module.""" import os import sys +import tempfile +from importlib import resources from pathlib import Path -from unittest.mock import patch, MagicMock, call +from unittest.mock import MagicMock, call, patch from sqlalchemy import Column, Integer, insert from sqlalchemy.orm import declarative_base @@ -13,7 +15,7 @@ import_file, read_config_file, ) -from tests.utils import RequiresDBTestCase, DatafakerTestCase +from tests.utils import DatafakerTestCase, RequiresDBTestCase # pylint: disable=invalid-name Base = declarative_base() @@ -60,7 +62,6 @@ class TestDownload(RequiresDBTestCase): dump_file_path = "providers.dump" mytable_file_path = Path("mytable.yaml") - test_dir = Path("tests/workspace") start_dir = os.getcwd() def setUp(self) -> None: @@ -69,6 +70,7 @@ def setUp(self) -> None: metadata.create_all(self.engine) + self.test_dir = Path(tempfile.mkdtemp(prefix="df-")) os.chdir(self.test_dir) self.mytable_file_path.unlink(missing_ok=True) @@ -81,15 +83,21 @@ def test_download_table(self) -> None: """Test the download_table function.""" # pylint: disable=protected-access - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: conn.execute(insert(MyTable).values({"id": 1})) conn.commit() - download_table(MyTable.__table__, self.engine, self.mytable_file_path, compress=False) + download_table( + MyTable.__table__, self.sync_engine, self.mytable_file_path, compress=False + ) # The .strip() gets rid of any possible empty lines at the end of the file. - with Path("../examples/expected.yaml").open(encoding="utf-8") as yamlfile: - expected = yamlfile.read().strip() + tests_module = sys.modules["tests"] + with resources.as_file( + resources.files(tests_module) / "examples" / "expected.yaml" + ) as yamlpath: + with yamlpath.open(encoding="utf-8") as yamlfile: + expected = yamlfile.read().strip() with self.mytable_file_path.open(encoding="utf-8") as yamlfile: actual = yamlfile.read().strip() @@ -108,124 +116,220 @@ def test_warns_of_invalid_config(self) -> None: "The config file is invalid: %s", "'a' is not of type 'integer'" ) + class TestUtils(DatafakerTestCase): - """ Miscellaneous tests. 
""" + """Miscellaneous tests.""" + def test_generators_require_stats(self) -> None: - """ Test that we can tell if a configuration requires SRC_STATS or not. """ - self.assertTrue(generators_require_stats({ - "object_instantiation": { - "mygen": {"name": "MyGen", "kwargs": {"a": '1 + SRC_STATS["my"]["results"][0]'}} - } - })) - self.assertTrue(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "kwargs": {"a": '[None] + SRC_STATS["my"]["results"]'}, - }] - })) - self.assertTrue(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "args": ['(SRC_STATS["my"]["results"])'], - }] - })) - self.assertTrue(generators_require_stats({ - "tables": { - "things": { - "missingness_generators":[{ - "name": "msg", - "kwargs": {"a": '[SRC_STATS["my"], SRC_STATS["theirs"]]'}, - "columns_assigned": ["a"], - }] + """Test that we can tell if a configuration requires SRC_STATS or not.""" + self.assertTrue( + generators_require_stats( + { + "object_instantiation": { + "mygen": { + "name": "MyGen", + "kwargs": {"a": '1 + SRC_STATS["my"]["results"][0]'}, + } + } } - } - })) - self.assertTrue(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "kwargs": {"a": 'SRC_STATS["ifu"]["results"]'}, - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "kwargs": {"a": '[None] + SRC_STATS["my"]["results"]'}, + } + ] } - } - })) - self.assertTrue(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "args": ['SRC_STATS'], - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "args": ['(SRC_STATS["my"]["results"])'], + } + ] } - } - })) - self.assertFalse(generators_require_stats({ - "object_instantiation": { - "mygen": {"name": "MyGen", "kwargs": {"a": 1}} - } - })) - self.assertFalse(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "kwargs": {"a": '[None]'}, - }] - })) - self.assertFalse(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "args": ['(SRC_STATS_["my"]["results"])'], - }] - })) - self.assertFalse(generators_require_stats({ - "missingness_generators": [{ - "name": "msg", - "kwargs": {"a": '"SRC_STATS"'}, - "columns_assigned": ["a"], - }] - })) - self.assertFalse(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "kwargs": {"a": 'SRC_STAT["ifu"]["results"]'}, - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "tables": { + "things": { + "missingness_generators": [ + { + "name": "msg", + "kwargs": { + "a": '[SRC_STATS["my"], SRC_STATS["theirs"]]' + }, + "columns_assigned": ["a"], + } + ] + } + } } - } - })) - self.assertFalse(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "args": ['SRC_STATSS'], - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "kwargs": {"a": 'SRC_STATS["ifu"]["results"]'}, + "columns_assigned": ["a"], + } + ] + } + } } - } - })) + ) + ) + self.assertTrue( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "args": ["SRC_STATS"], + "columns_assigned": ["a"], + } + ] + } + } + } + ) + ) + self.assertFalse( + generators_require_stats( + { + 
"object_instantiation": { + "mygen": {"name": "MyGen", "kwargs": {"a": 1}} + } + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "kwargs": {"a": "[None]"}, + } + ] + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "args": ['(SRC_STATS_["my"]["results"])'], + } + ] + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "missingness_generators": [ + { + "name": "msg", + "kwargs": {"a": '"SRC_STATS"'}, + "columns_assigned": ["a"], + } + ] + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "kwargs": {"a": 'SRC_STAT["ifu"]["results"]'}, + "columns_assigned": ["a"], + } + ] + } + } + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "args": ["SRC_STATSS"], + "columns_assigned": ["a"], + } + ] + } + } + } + ) + ) @patch("datafaker.utils.logger") - def test_testing_generators_finds_syntax_errors(self, logger: MagicMock): - generators_require_stats({ - "story_generators": [ - {"name": "my_story_gen", "kwargs": {"b": "'unclosed"}} - ], - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "args": ['1 2'], - "columns_assigned": ["a"], - }] - } + def test_testing_generators_finds_syntax_errors(self, logger: MagicMock) -> None: + """Test that looking for ``SRC_STATS`` references finds Python syntax errors.""" + generators_require_stats( + { + "story_generators": [ + {"name": "my_story_gen", "kwargs": {"b": "'unclosed"}} + ], + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "args": ["1 2"], + "columns_assigned": ["a"], + } + ] + } + }, } - }) - logger.error.assert_has_calls([ - call("Syntax error in argument %s of %s: %s\n%s\n%s", "b", "story_generators[0]", "unterminated string literal (detected at line 1)", "'unclosed", " ^"), - call("Syntax error in argument %d of %s: %s\n%s\n%s", 1, "tables.things.row_generators[0]", "invalid syntax", "1 2", " ^"), - ]) + ) + logger.error.assert_has_calls( + [ + call( + "Syntax error in argument %s of %s: %s\n%s%s", + "b", + "story_generators[0]", + "unterminated string literal (detected at line 1)", + "'unclosed", + "\n ^", + ), + call( + "Syntax error in argument %d of %s: %s\n%s%s", + 1, + "tables.things.row_generators[0]", + "invalid syntax", + "1 2", + "\n ^", + ), + ] + ) diff --git a/tests/utils.py b/tests/utils.py index 08850f85..2810c41c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,25 +1,33 @@ """Utilities for testing.""" import asyncio -from functools import lru_cache import os -from pathlib import Path import shutil -from sqlalchemy.schema import MetaData -from subprocess import run -import testing.postgresql import traceback -from typing import Any +from collections.abc import MutableSequence, Sequence +from functools import lru_cache +from pathlib import Path +from subprocess import run +from tempfile import mkstemp +from typing import Any, Mapping from unittest import TestCase, skipUnless -import yaml -from sqlalchemy import MetaData -from tempfile import mkstemp +import testing.postgresql +import yaml +from sqlalchemy.schema import MetaData from datafaker import settings from datafaker.create import create_db_data_into -from datafaker.make import make_tables_file, make_src_stats, make_table_generators +from datafaker.interactive.base import DbCmd +from datafaker.make import make_src_stats, 
make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from -from datafaker.utils import import_file, sorted_non_vocabulary_tables, create_db_engine +from datafaker.utils import ( + T, + create_db_engine, + get_sync_engine, + import_file, + sorted_non_vocabulary_tables, +) + class SysExit(Exception): """To force the function to exit as sys.exit() would.""" @@ -46,7 +54,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.maxDiff = None # pylint: disable=invalid-name super().__init__(*args, **kwargs) - def setUp(self): + def setUp(self) -> None: settings.get_settings.cache_clear() def assertReturnCode( # pylint: disable=invalid-name @@ -68,38 +76,47 @@ def assertFailure(self, result: Any) -> None: # pylint: disable=invalid-name self.assertReturnCode(result, 1) def assertNoException(self, result: Any) -> None: # pylint: disable=invalid-name - """ Assert that the result has no exception. """ + """Assert that the result has no exception.""" if result.exception is None: return - self.fail(''.join(traceback.format_exception(result.exception))) + self.fail("".join(traceback.format_exception(result.exception))) + + def assert_greater_and_not_none(self, left: float | None, right: float) -> None: + """ + Assert left is not None and greater than right + """ + if left is None: + self.fail("first argument is None") + else: + self.assertGreater(left, right) - def assertSubset(self, set1, set2, msg=None): + def assert_subset(self, set1: set[T], set2: set[T], msg: str | None = None) -> None: """Assert a set is a (non-strict) subset. - Args: - set1: The asserted subset. - set2: The asserted superset. - msg: Optional message to use on failure instead of a list of - differences. + :param set1: The asserted subset. + :param set2: The asserted superset. + :param msg: Optional message to use on failure instead of a list of + differences. """ try: difference = set1.difference(set2) except TypeError as e: - self.fail('invalid type when attempting set difference: %s' % e) + self.fail(f"invalid type when attempting set difference: {e}") except AttributeError as e: - self.fail('first argument does not support set difference: %s' % e) + self.fail(f"first argument does not support set difference: {e}") if not difference: return lines = [] if difference: - lines.append('Items in the first set but not the second:') + lines.append("Items in the first set but not the second:") for item in difference: lines.append(repr(item)) - standardMsg = '\n'.join(lines) - self.fail(self._formatMessage(msg, standardMsg)) + standard_msg = "\n".join(lines) + self.fail(self._formatMessage(msg, standard_msg)) + @skipUnless(shutil.which("psql"), "need to find 'psql': install PostgreSQL to enable") class RequiresDBTestCase(DatafakerTestCase): @@ -112,24 +129,27 @@ class RequiresDBTestCase(DatafakerTestCase): to get an engine to access the database and self.metadata to get metadata reflected from that engine. 
""" - schema_name = None + + schema_name: str | None = None use_asyncio = False - examples_dir = "tests/examples" - dump_file_path = None - database_name = None + examples_dir = Path("tests/examples") + dump_file_path: str | None = None + database_name: str | None = None Postgresql = None @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.Postgresql = testing.postgresql.PostgresqlFactory(cache_initialized_db=True) @classmethod - def tearDownClass(cls): - cls.Postgresql.clear_cache() + def tearDownClass(cls) -> None: + if cls.Postgresql is not None: + cls.Postgresql.clear_cache() def setUp(self) -> None: super().setUp() - self.postgresql = self.Postgresql() + assert self.Postgresql is not None + self.postgresql = self.Postgresql() # pylint: disable=not-callable if self.dump_file_path is not None: self.run_psql(Path(self.examples_dir) / Path(self.dump_file_path)) self.engine = create_db_engine( @@ -137,18 +157,23 @@ def setUp(self) -> None: schema_name=self.schema_name, use_asyncio=self.use_asyncio, ) + self.sync_engine = get_sync_engine(self.engine) self.metadata = MetaData() - self.metadata.reflect(self.engine) + self.metadata.reflect(self.sync_engine) def tearDown(self) -> None: self.postgresql.stop() super().tearDown() @property - def dsn(self): + def dsn(self) -> str: + """Get the database connection string.""" if self.database_name: - return self.postgresql.url(database=self.database_name) - return self.postgresql.url() + url = self.postgresql.url(database=self.database_name) + else: + url = self.postgresql.url() + assert isinstance(url, str) + return url def run_psql(self, dump_file: Path) -> None: """Run psql and pass dump_file_name as the --file option.""" @@ -176,38 +201,51 @@ def run_psql(self, dump_file: Path) -> None: class GeneratesDBTestCase(RequiresDBTestCase): + """A test case for which a database is generated.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialise a GeneratedDB test case.""" + super().__init__(*args, **kwargs) + self.generators_file_path = "" + self.stats_fd = 0 + self.stats_file_path = "" + self.config_file_path = "" + self.config_fd = 0 + def setUp(self) -> None: + """Set up the test case with an actual orm.yaml file.""" super().setUp() # Generate the `orm.yaml` from the database (self.orm_fd, self.orm_file_path) = mkstemp(".yaml", "orm_", text=True) with os.fdopen(self.orm_fd, "w", encoding="utf-8") as orm_fh: orm_fh.write(make_tables_file(self.dsn, self.schema_name, {})) - def set_configuration(self, config) -> None: - """ - Accepts a configuration file, writes it out. - """ + def set_configuration(self, config: Mapping[str, Any]) -> None: + """Accepts a configuration file, writes it out.""" (self.config_fd, self.config_file_path) = mkstemp(".yaml", "config_", text=True) with os.fdopen(self.config_fd, "w", encoding="utf-8") as config_fh: config_fh.write(yaml.dump(config)) - def get_src_stats(self, config) -> dict[str, any]: + def get_src_stats(self, config: Mapping[str, Any]) -> dict[str, Any]: """ - Runs `make-stats` producing `src-stats.yaml` + Runs `make-stats` producing `src-stats.yaml`. 
+ :return: Python dictionary representation of the contents of the src-stats file """ loop = asyncio.new_event_loop() src_stats = loop.run_until_complete( - make_src_stats(self.dsn, config, self.metadata, self.schema_name) + make_src_stats(self.dsn, config, self.schema_name) ) loop.close() - (self.stats_fd, self.stats_file_path) = mkstemp(".yaml", "src_stats_", text=True) + (self.stats_fd, self.stats_file_path) = mkstemp( + ".yaml", "src_stats_", text=True + ) with os.fdopen(self.stats_fd, "w", encoding="utf-8") as stats_fh: stats_fh.write(yaml.dump(src_stats)) return src_stats - def create_generators(self, config) -> None: - """ ``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py`` """ + def create_generators(self, config: Mapping[str, Any]) -> None: + """``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py``""" datafaker_content = make_table_generators( self.metadata, config, @@ -219,27 +257,27 @@ def create_generators(self, config) -> None: with os.fdopen(generators_fd, "w", encoding="utf-8") as datafaker_fh: datafaker_fh.write(datafaker_content) - def remove_data(self, config): - """ Remove source data from the DB. """ + def remove_data(self, config: Mapping[str, Any]) -> None: + """Remove source data from the DB.""" # `remove-data` so we don't have to use a separate database for the destination remove_db_data_from(self.metadata, config, self.dsn, self.schema_name) - def create_data(self, config, num_passes=1): - """ Create fake data in the DB. """ + def create_data(self, config: Mapping[str, Any], num_passes: int = 1) -> None: + """Create fake data in the DB.""" # `create-data` with all this stuff datafaker_module = import_file(self.generators_file_path) - table_generator_dict = datafaker_module.table_generator_dict - story_generator_list = datafaker_module.story_generator_list create_db_data_into( sorted_non_vocabulary_tables(self.metadata, config), - table_generator_dict, - story_generator_list, + datafaker_module, num_passes, self.dsn, self.schema_name, + self.metadata, ) - def generate_data(self, config, num_passes=1): + def generate_data( + self, config: Mapping[str, Any], num_passes: int = 1 + ) -> Mapping[str, Any]: """ Replaces the DB's source data with generated data. :return: A Python dictionary representation of the src-stats.yaml file, for what it's worth. 
@@ -250,3 +288,45 @@ def generate_data(self, config, num_passes=1): self.remove_data(config) self.create_data(config, num_passes) return src_stats + + +class TestDbCmdMixin(DbCmd): + """A mixin for capturing output from interactive commands.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize a TestDbCmdMixin""" + super().__init__(*args, **kwargs) + self.reset() + + def reset(self) -> None: + """Reset all the debug messages collected so far.""" + self.messages: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = [] + self.headings: Sequence[str] = [] + self.rows: Sequence[Sequence[str]] = [] + self.column_items: MutableSequence[Sequence[str]] = [] + self.columns: Mapping[str, Sequence[Any]] = {} + + def print(self, text: str, *args: Any, **kwargs: Any) -> None: + """Capture the printed message.""" + self.messages.append((text, args, kwargs)) + + def print_table( + self, headings: Sequence[str], rows: Sequence[Sequence[str]] + ) -> None: + """Capture the printed table.""" + self.headings = headings + self.rows = rows + + def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None: + """Capture the printed table.""" + self.columns = columns + + # pylint: disable=arguments-renamed + def columnize(self, items: Sequence[str] | None, _displaywidth: int = 80) -> None: + """Capture the printed table.""" + if items is not None: + self.column_items.append(items) + + def ask_save(self) -> str: + """Quitting always works without needing to ask the user.""" + return "yes" diff --git a/tests/workspace/.gitignore b/tests/workspace/.gitignore deleted file mode 100644 index 72e8ffc0..00000000 --- a/tests/workspace/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/tests/workspace/README.md b/tests/workspace/README.md deleted file mode 100644 index 8165a69a..00000000 --- a/tests/workspace/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Test Workspace - -A workspace for the functional tests to run in.