From 35da209a3c899f0ed962c8b8d59b1d400cd9a359 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 3 Oct 2025 22:06:07 +0100 Subject: [PATCH 01/35] Automatic pre-commit fixes --- datafaker/base.py | 54 +- datafaker/create.py | 30 +- datafaker/dump.py | 21 +- datafaker/generators.py | 601 ++++-- datafaker/interactive.py | 559 +++-- datafaker/main.py | 110 +- datafaker/make.py | 110 +- datafaker/providers.py | 8 +- datafaker/remove.py | 21 +- datafaker/serialize_metadata.py | 73 +- datafaker/utils.py | 109 +- docs/source/_static/config_schema.html | 2760 +++++++++++++++++++++++- docs/source/custom_generators.rst | 2 +- docs/source/introduction.rst | 2 +- docs/source/quickstart.rst | 6 +- tests/test_create.py | 19 +- tests/test_dump.py | 23 +- tests/test_functional.py | 28 +- tests/test_interactive.py | 627 ++++-- tests/test_main.py | 71 +- tests/test_make.py | 52 +- tests/test_providers.py | 3 +- tests/test_remove.py | 70 +- tests/test_unique_generator.py | 1 + tests/test_utils.py | 323 ++- tests/utils.py | 43 +- 26 files changed, 4627 insertions(+), 1099 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 17f471b..56315a4 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -1,35 +1,34 @@ """Base table generator classes.""" -from abc import ABC, abstractmethod -from collections.abc import Callable -from dataclasses import dataclass import functools +import gzip import math -import numpy as np import os -from pathlib import Path import random +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path from typing import Any +import numpy as np import yaml -import gzip from sqlalchemy import Connection, insert from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.schema import Table from datafaker.utils import ( + MAKE_VOCAB_PROGRESS_REPORT_EVERY, logger, stream_yaml, - MAKE_VOCAB_PROGRESS_REPORT_EVERY, table_row_count, ) + @functools.cache def zipf_weights(size): - total = sum(map(lambda n: 1/n, range(1, size + 1))) - return [ - 1 / (n * total) - for n in range(1, size + 1) - ] + total = sum(map(lambda n: 1 / n, range(1, size + 1))) + return [1 / (n * total) for n in range(1, size + 1)] + def merge_with_constants(xs: list, constants_at: dict[int, any]): """ @@ -118,10 +117,7 @@ def multivariate_normal_np(self, cov): rank = int(cov["rank"]) if rank == 0: return np.empty(shape=(0,)) - mean = [ - float(cov[f"m{i}"]) - for i in range(rank) - ] + mean = [float(cov[f"m{i}"]) for i in range(rank)] covs = [ [ float(cov[f"c{i}_{j}"] if i <= j else cov[f"c{j}_{i}"]) @@ -138,7 +134,9 @@ def _select_group(self, alts: list[dict[str, any]]): total = 0 for alt in alts: if alt["count"] < 0: - logger.warning("Alternative count is %d, but should not be negative", alt["count"]) + logger.warning( + "Alternative count is %d, but should not be negative", alt["count"] + ) else: total += alt["count"] if total == 0: @@ -218,7 +216,9 @@ def _check_generator_name(self, name: str) -> None: if name not in self.PERMITTED_SUBGENS: raise Exception("%s is not a permitted generator", name) - def alternatives(self, alternative_configs: list[dict[str, any]], counts: list[int] | None): + def alternatives( + self, alternative_configs: list[dict[str, any]], counts: list[int] | None + ): """ A generator that picks between other generators. 
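Note: zipf_weights(size) in the hunk above normalises the truncated harmonic
series 1, 1/2, ..., 1/size, so the weights always sum to 1, and the
functools.cache decorator means each size is computed only once. A minimal
sketch of that invariant (plain Python, assuming datafaker.base imports as in
this patch):

    from datafaker.base import zipf_weights

    w = zipf_weights(4)
    # Exact proportions are 12/25, 6/25, 4/25, 3/25:
    assert all(abs(a - b) < 1e-9 for a, b in zip(w, [0.48, 0.24, 0.16, 0.12]))
    assert abs(sum(w) - 1.0) < 1e-9
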
@@ -245,7 +245,9 @@ def alternatives(self, alternative_configs: list[dict[str, any]], counts: list[i self._check_generator_name(name) return getattr(self, name)(**alt["params"]) - def with_constants_at(self, constants_at: list[int], subgen: str, params: dict[str, any]): + def with_constants_at( + self, constants_at: list[int], subgen: str, params: dict[str, any] + ): if subgen not in self.PERMITTED_SUBGENS: logger.error( "subgenerator %s is not a valid name. Valid names are %s.", @@ -257,7 +259,7 @@ def with_constants_at(self, constants_at: list[int], subgen: str, params: dict[s return list(merge_with_constants(subout, constants_at)) def truncated_string(self, subgen_fn, params, length): - """ Calls ``subgen_fn(**params)`` and truncates the results to ``length``. """ + """Calls ``subgen_fn(**params)`` and truncates the results to ``length``.""" result = subgen_fn(**params) if result is None: return None @@ -288,7 +290,9 @@ class FileUploader: table: Table - def _load_existing_file(self, connection: Connection, file_size: int, opener: Callable[[], Any]) -> None: + def _load_existing_file( + self, connection: Connection, file_size: int, opener: Callable[[], Any] + ) -> None: count = 0 with opener() as fh: rows = stream_yaml(fh) @@ -305,7 +309,7 @@ def _load_existing_file(self, connection: Connection, file_size: int, opener: Ca 100 * fh.tell() / file_size, ) - def load(self, connection: Connection, base_path: Path=Path(".")) -> None: + def load(self, connection: Connection, base_path: Path = Path(".")) -> None: """Load the data from file.""" yaml_file = base_path / Path(self.table.fullname + ".yaml") if yaml_file.exists(): @@ -318,7 +322,10 @@ def load(self, connection: Connection, base_path: Path=Path(".")) -> None: logger.warning("File %s not found. Skipping...", yaml_file) return if 0 < table_row_count(self.table, connection): - logger.warning("Table %s already contains data (consider running 'datafaker remove-vocab'), skipping...", self.table.name) + logger.warning( + "Table %s already contains data (consider running 'datafaker remove-vocab'), skipping...", + self.table.name, + ) return try: file_size = os.path.getsize(yaml_file) @@ -331,6 +338,7 @@ def load(self, connection: Connection, base_path: Path=Path(".")) -> None: "Error inserting rows into table %s: %s", self.table.fullname, e ) + class ColumnPresence: def sampled(self, patterns): total = 0 diff --git a/datafaker/create.py b/datafaker/create.py index d84eadb..f522876 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -1,17 +1,14 @@ """Functions and classes to create and populate the target database.""" -from collections import Counter import pathlib import random +from collections import Counter from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple from sqlalchemy import Connection, insert, inspect from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session -from sqlalchemy.schema import ( - CreateSchema, - MetaData, - Table, -) +from sqlalchemy.schema import CreateSchema, MetaData, Table + from datafaker.base import FileUploader, TableGenerator from datafaker.settings import get_settings from datafaker.utils import ( @@ -56,11 +53,11 @@ def create_db_vocab( metadata: MetaData, meta_dict: dict[str, Any], config: Mapping, - base_path: pathlib.Path | None=pathlib.Path(".") + base_path: pathlib.Path | None = pathlib.Path("."), ) -> int: """ Load vocabulary tables from files. 
- + arguments: metadata: The schema of the database meta_dict: The simple description of the schema from --orm-file @@ -85,7 +82,7 @@ def create_db_vocab( uploader = FileUploader(table=vocab_table) with Session(dst_engine) as session: session.begin() - uploader.load(session.connection(), base_path = base_path) + uploader.load(session.connection(), base_path=base_path) session.commit() tables_loaded.append(vocab_table_name) except IntegrityError: @@ -128,9 +125,7 @@ def create_db_data_into( db_dsn: str, schema_name: str | None, ) -> RowCounts: - dst_engine = get_sync_engine( - create_db_engine(db_dsn, schema_name=schema_name) - ) + dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) row_counts: Counter[str] = Counter() with dst_engine.connect() as dst_conn: @@ -145,7 +140,8 @@ def create_db_data_into( class StoryIterator: - def __init__(self, + def __init__( + self, stories: Iterable[tuple[str, Story]], table_dict: Mapping[str, Table], table_generator_dict: Mapping[str, TableGenerator], @@ -219,7 +215,9 @@ def next(self) -> None: self._table_name, self._provided_values = next(self._story) return else: - self._table_name, self._provided_values = self._story.send(self._final_values) + self._table_name, self._provided_values = self._story.send( + self._final_values + ) return except StopIteration: try: @@ -274,7 +272,9 @@ def populate( try: with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): - stmt = insert(table).values(table_generator(dst_conn, random.random)) + stmt = insert(table).values( + table_generator(dst_conn, random.random) + ) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 dst_conn.commit() diff --git a/datafaker/dump.py b/datafaker/dump.py index c4d2b24..36ca046 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,14 +1,11 @@ import csv import io + import sqlalchemy from sqlalchemy.schema import MetaData from datafaker.settings import get_settings -from datafaker.utils import ( - create_db_engine, - get_sync_engine, - logger, -) +from datafaker.utils import create_db_engine, get_sync_engine, logger def _make_csv_writer(file): @@ -16,14 +13,14 @@ def _make_csv_writer(file): def dump_db_tables( - metadata: MetaData, - dsn: str, - schema: str | None, - table_name: str, - file: io.TextIOBase + metadata: MetaData, + dsn: str, + schema: str | None, + table_name: str, + file: io.TextIOBase, ) -> None: - """ Output the table as CSV. """ - if table_name not in metadata.tables: + """Output the table as CSV.""" + if table_name not in metadata.tables: logger.error("%s is not a table described in the ORM file", table_name) return table = metadata.tables[table_name] diff --git a/datafaker/generators.py b/datafaker/generators.py index dd6e818..1ea0a8f 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -2,20 +2,21 @@ Generator factories for making generators for single columns. 
""" +import decimal +import math +import re from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass -import decimal from functools import lru_cache from itertools import chain, combinations -import math +from typing import Callable, Iterable, TypeVar + import mimesis import mimesis.locales -import re import sqlalchemy -from sqlalchemy import Column, Engine, text, Connection, RowMapping, Sequence +from sqlalchemy import Column, Connection, Engine, RowMapping, Sequence, text from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time -from typing import Callable, Iterable, TypeVar from datafaker.base import DistributionGenerator from datafaker.utils import logger @@ -27,6 +28,7 @@ dist_gen = DistributionGenerator() generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) + class Generator(ABC): """ Random data generator. @@ -40,9 +42,10 @@ class Generator(ABC): It also knows these summary statistics for the column it was instantiated on, and therefore knows how to generate fake data for that column. """ + @abstractmethod def function_name(self) -> str: - """ The name of the generator function to put into df.py. """ + """The name of the generator function to put into df.py.""" def name(self) -> str: """ @@ -60,7 +63,7 @@ def nominal_kwargs(self) -> dict[str, str]: The values will tend to be references to something in the src-stats.yaml file. For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will - provide the value stored in src-stats.yaml as + provide the value stored in src-stats.yaml as SRC_STATS["auto__patient"]["results"][0]["age_mean"] as the "avg_age" argument to the generator function. """ @@ -130,6 +133,7 @@ class PredefinedGenerator(Generator): """ Generator built from an existing config.yaml. """ + SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*') @@ -152,13 +156,22 @@ def _get_src_stats_mentioned(self, val) -> set[str]: return set.union(*(self._get_src_stats_mentioned(v) for v in val.values())) return set() - def __init__(self, table_name: str, generator_object: Mapping[str, any], config: Mapping[str, any]): + def __init__( + self, + table_name: str, + generator_object: Mapping[str, any], + config: Mapping[str, any], + ): """ Initialise a generator from a config.yaml. :param config: The entire configuration. :param generator_object: The part of the configuration at tables.*.row_generators """ - logger.debug("Creating a PredefinedGenerator %s from table %s", generator_object["name"], table_name) + logger.debug( + "Creating a PredefinedGenerator %s from table %s", + generator_object["name"], + table_name, + ) self._table_name = table_name self._name: str = generator_object["name"] self._kwn: dict[str, str] = generator_object.get("kwargs", {}) @@ -170,7 +183,9 @@ def __init__(self, table_name: str, generator_object: Mapping[str, any], config: for sstat in config.get("src-stats", []): name: str = sstat["name"] dpq = sstat.get("dp-query", None) - query = sstat.get("query", dpq) #... should we really be combining query and dp-query? + query = sstat.get( + "query", dpq + ) # ... should we really be combining query and dp-query? 
comments = sstat.get("comments", []) if name in self._src_stats_mentioned: logger.debug("Found a src-stats entry for %s", name) @@ -181,7 +196,7 @@ def __init__(self, table_name: str, generator_object: Mapping[str, any], config: # name is auto__{table_name}, so it's a select_aggregate, so we split up its clauses sacs = [ self.AS_CLAUSE_RE.match(clause) - for clause in sam.group(1).split(',') + for clause in sam.group(1).split(",") ] # Work out what select_aggregate_clauses this represents for sac in sacs: @@ -213,13 +228,13 @@ def custom_queries(self) -> dict[str, dict[str, str]]: def actual_kwargs(self) -> dict[str, any]: # Run the queries from nominal_kwargs - #... + # ... logger.error("PredefinedGenerator.actual_kwargs not implemented yet") return {} def generate_data(self, count) -> list[any]: # Call the function if we can. This could be tricky... - #... + # ... logger.error("PredefinedGenerator.generate_data not implemented yet") return [] @@ -227,7 +242,8 @@ def generate_data(self, count) -> list[any]: class GeneratorFactory(ABC): """ A factory for making generators appropriate for a database column. - """ + """ + @abstractmethod def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: """ @@ -240,13 +256,27 @@ class Buckets: Finds the real distribution of continuous data so that we can measure the fit of generators against it. """ - def __init__(self, engine: Engine, table_name: str, column_name: str, mean:float, stddev: float, count: int): + + def __init__( + self, + engine: Engine, + table_name: str, + column_name: str, + mean: float, + stddev: float, + count: int, + ): with engine.connect() as connection: - raw_buckets = connection.execute(text( - "SELECT COUNT({column}) AS f, FLOOR(({column} - {x})/{w}) AS b FROM {table} GROUP BY b".format( - column=column_name, table=table_name, x=mean - 2 * stddev, w = stddev / 2 + raw_buckets = connection.execute( + text( + "SELECT COUNT({column}) AS f, FLOOR(({column} - {x})/{w}) AS b FROM {table} GROUP BY b".format( + column=column_name, + table=table_name, + x=mean - 2 * stddev, + w=stddev / 2, + ) ) - )) + ) self.buckets = [0] * 10 for rb in raw_buckets: if rb.b is not None: @@ -259,7 +289,7 @@ def __init__(self, engine: Engine, table_name: str, column_name: str, mean:float def make_buckets(_cls, engine: Engine, table_name: str, column_name: str): """ Construct a Buckets object. - + Calculates the mean and standard deviation of the values in the column specified and makes ten buckets, centered on the mean and each half a standard deviation wide (except for the end two that extend to @@ -268,10 +298,12 @@ def make_buckets(_cls, engine: Engine, table_name: str, column_name: str): """ with engine.connect() as connection: result = connection.execute( - text("SELECT AVG({column}) AS mean, STDDEV({column}) AS stddev, COUNT({column}) AS count FROM {table}".format( - table=table_name, - column=column_name, - )) + text( + "SELECT AVG({column}) AS mean, STDDEV({column}) AS stddev, COUNT({column}) AS count FROM {table}".format( + table=table_name, + column=column_name, + ) + ) ).first() if result is None or result.stddev is None or result.count < 2: return None @@ -303,13 +335,14 @@ def fit_from_values(self, values: list[float]) -> float: x = self.mean - 2 * self.stddev w = self.stddev / 2 for v in values: - b = min(9, max(0, int((v - x)/w))) + b = min(9, max(0, int((v - x) / w))) buckets[b] += 1 return self.fit_from_counts(buckets) class MultiGeneratorFactory(GeneratorFactory): - """ A composite factory. 
""" + """A composite factory.""" + def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories @@ -336,27 +369,30 @@ def __init__( f = generic for part in function_name.split("."): if not hasattr(f, part): - raise Exception(f"Mimesis does not have a function {function_name}: {part} not found") + raise Exception( + f"Mimesis does not have a function {function_name}: {part} not found" + ) f = getattr(f, part) if not callable(f): - raise Exception(f"Mimesis object {function_name} is not a callable, so cannot be used as a generator") + raise Exception( + f"Mimesis object {function_name} is not a callable, so cannot be used as a generator" + ) self._name = "generic." + function_name self._generator_function = f + def function_name(self): return self._name + def generate_data(self, count): - return [ - self._generator_function() - for _ in range(count) - ] + return [self._generator_function() for _ in range(count)] class MimesisGenerator(MimesisGeneratorBase): def __init__( self, function_name: str, - value_fn: Callable[[any], float] | None=None, - buckets: Buckets | None=None, + value_fn: Callable[[any], float] | None = None, + buckets: Buckets | None = None, ): """ Generator from Mimesis. @@ -373,17 +409,18 @@ def __init__( return samples = self.generate_data(400) if value_fn: - samples = [ - value_fn(s) - for s in samples - ] + samples = [value_fn(s) for s in samples] self._fit = buckets.fit_from_values(samples) + def function_name(self): return self._name + def nominal_kwargs(self): return {} + def actual_kwargs(self): return {} + def fit(self, default=None): return default if self._fit is None else self._fit @@ -393,36 +430,46 @@ def __init__( self, function_name: str, length: int, - value_fn: Callable[[any], float] | None=None, - buckets: Buckets | None=None, + value_fn: Callable[[any], float] | None = None, + buckets: Buckets | None = None, ): self._length = length super().__init__(function_name, value_fn, buckets) + def function_name(self): return "dist_gen.truncated_string" + def name(self): return f"{self._name} [truncated to {self._length}]" + def nominal_kwargs(self): return { "subgen_fn": self._name, "params": {}, "length": self._length, } + def actual_kwargs(self): return { "subgen_fn": self._name, "params": {}, "length": self._length, } + def generate_data(self, count): - return [ - self._generator_function()[:self._length] - for _ in range(count) - ] + return [self._generator_function()[: self._length] for _ in range(count)] class MimesisDateTimeGenerator(MimesisGeneratorBase): - def __init__(self, column: Column, function_name: str, min_year: str, max_year: str, start: int, end: int): + def __init__( + self, + column: Column, + function_name: str, + min_year: str, + max_year: str, + start: int, + end: int, + ): """ :param column: The column to generate into :param function_name: The name of the mimesis function @@ -445,28 +492,35 @@ def make_singleton(_cls, column: Column, engine: Engine, function_name: str): min_year = f"MIN({extract_year})" with engine.connect() as connection: result = connection.execute( - text(f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}") + text( + f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}" + ) ).first() if result is None or result.start is None or result.end is None: return [] - return [MimesisDateTimeGenerator( - column, - function_name, - min_year, - max_year, - int(result.start), - int(result.end), - )] + return [ + MimesisDateTimeGenerator( + column, + 
function_name, + min_year, + max_year, + int(result.start), + int(result.end), + ) + ] + def nominal_kwargs(self): return { "start": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__start"]', "end": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__end"]', } + def actual_kwargs(self): return { "start": self._start, "end": self._end, } + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: return { f"{self._column.name}__start": { @@ -478,6 +532,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: "comment": f"Latest year found for column {self._column.name} in table {self._column.table.name}", }, } + def generate_data(self, count): return [ self._generator_function(start=self._start, end=self._end) @@ -496,6 +551,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return strings. """ + GENERATOR_NAMES = [ "address.calling_code", "address.city", @@ -529,6 +585,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.text", "text.word", ] + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -551,35 +608,48 @@ def get_generators(self, columns: list[Column], engine: Engine): fitness_fn = None length = column_type.length if length: - return list(map( - lambda gen: MimesisGeneratorTruncated(gen, length, fitness_fn, buckets), + return list( + map( + lambda gen: MimesisGeneratorTruncated( + gen, length, fitness_fn, buckets + ), + self.GENERATOR_NAMES, + ) + ) + return list( + map( + lambda gen: MimesisGenerator(gen, fitness_fn, buckets), self.GENERATOR_NAMES, - )) - return list(map( - lambda gen: MimesisGenerator(gen, fitness_fn, buckets), - self.GENERATOR_NAMES, - )) + ) + ) class MimesisFloatGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return floating point numbers. """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] column = columns[0] if not isinstance(get_column_type(column), Numeric): return [] - return list(map(MimesisGenerator, [ - "person.height", - ])) + return list( + map( + MimesisGenerator, + [ + "person.height", + ], + ) + ) class MimesisDateGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return dates. """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -594,6 +664,7 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return datetimes. """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -601,13 +672,16 @@ def get_generators(self, columns: list[Column], engine: Engine): ct = get_column_type(column) if not isinstance(ct, DateTime): return [] - return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.datetime") + return MimesisDateTimeGenerator.make_singleton( + column, engine, "datetime.datetime" + ) class MimesisTimeGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return times. """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -622,6 +696,7 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): """ All Mimesis generators that return integers. 
""" + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -633,7 +708,7 @@ def get_generators(self, columns: list[Column], engine: Engine): def fit_from_buckets(xs: list[float], ys: list[float]): - sum_diff_squared = sum(map(lambda t, a: (t - a)*(t - a), xs, ys)) + sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) return sum_diff_squared / (count * count) @@ -644,11 +719,13 @@ def __init__(self, table_name: str, column_name: str, buckets: Buckets): self.table_name = table_name self.column_name = column_name self.buckets = buckets + def nominal_kwargs(self): return { "mean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["mean__{self.column_name}"]', "sd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["stddev__{self.column_name}"]', } + def actual_kwargs(self): if self.buckets is None: return {} @@ -656,6 +733,7 @@ def actual_kwargs(self): "mean": self.buckets.mean, "sd": self.buckets.stddev, } + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: clauses = super().select_aggregate_clauses() return { @@ -669,6 +747,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: "comment": f"Standard deviation of {self.column_name} from table {self.table_name}", }, } + def fit(self, default=None): if self.buckets is None: return default @@ -676,9 +755,22 @@ def fit(self, default=None): class GaussianGenerator(ContinuousDistributionGenerator): - expected_buckets = [0.0227, 0.0441, 0.0918, 0.1499, 0.1915, 0.1915, 0.1499, 0.0918, 0.0441, 0.0227] + expected_buckets = [ + 0.0227, + 0.0441, + 0.0918, + 0.1499, + 0.1915, + 0.1915, + 0.1499, + 0.0918, + 0.0441, + 0.0227, + ] + def function_name(self): return "dist_gen.normal" + def generate_data(self, count): return [ dist_gen.normal(self.buckets.mean, self.buckets.stddev) @@ -687,9 +779,22 @@ def generate_data(self, count): class UniformGenerator(ContinuousDistributionGenerator): - expected_buckets = [0, 0.06698, 0.14434, 0.14434, 0.14434, 0.14434, 0.14434, 0.14434, 0.06698, 0] + expected_buckets = [ + 0, + 0.06698, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.06698, + 0, + ] + def function_name(self): return "dist_gen.uniform_ms" + def generate_data(self, count): return [ dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) @@ -701,6 +806,7 @@ class ContinuousDistributionGeneratorFactory(GeneratorFactory): """ All generators that want an average and standard deviation. 
""" + def _get_generators_from_buckets( self, _engine: Engine, @@ -725,36 +831,59 @@ def get_generators(self, columns: list[Column], engine: Engine) -> list[Generato buckets = Buckets.make_buckets(engine, table_name, column_name) if buckets is None: return [] - return self._get_generators_from_buckets(engine, table_name, column_name, buckets) + return self._get_generators_from_buckets( + engine, table_name, column_name, buckets + ) class LogNormalGenerator(Generator): - #TODO: figure out the real buckets here (this was from a random sample in R) - expected_buckets = [0, 0, 0, 0.28627, 0.40607, 0.14937, 0.06735, 0.03492, 0.01918, 0.03684] - def __init__(self, table_name: str, column_name: str, buckets: Buckets, logmean: float, logstddev: float): + # TODO: figure out the real buckets here (this was from a random sample in R) + expected_buckets = [ + 0, + 0, + 0, + 0.28627, + 0.40607, + 0.14937, + 0.06735, + 0.03492, + 0.01918, + 0.03684, + ] + + def __init__( + self, + table_name: str, + column_name: str, + buckets: Buckets, + logmean: float, + logstddev: float, + ): super().__init__() self.table_name = table_name self.column_name = column_name self.buckets = buckets self.logmean = logmean self.logstddev = logstddev + def function_name(self): return "dist_gen.lognormal" + def generate_data(self, count): - return [ - dist_gen.lognormal(self.logmean, self.logstddev) - for _ in range(count) - ] + return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] + def nominal_kwargs(self): return { "logmean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logmean__{self.column_name}"]', "logsd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logstddev__{self.column_name}"]', } + def actual_kwargs(self): return { "logmean": self.logmean, "logsd": self.logstddev, } + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: clauses = super().select_aggregate_clauses() return { @@ -768,6 +897,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: "comment": f"Standard deviation of logs of {self.column_name} from table {self.table_name}", }, } + def fit(self, default=None): if self.buckets is None: return default @@ -778,6 +908,7 @@ class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorF """ All generators that want an average and standard deviation of log data. 
""" + def _get_generators_from_buckets( self, engine: Engine, @@ -787,10 +918,12 @@ def _get_generators_from_buckets( ) -> list[Generator]: with engine.connect() as connection: result = connection.execute( - text("SELECT AVG(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logmean, STDDEV(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logstddev FROM {table}".format( - table=table_name, - column=column_name, - )) + text( + "SELECT AVG(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logmean, STDDEV(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logstddev FROM {table}".format( + table=table_name, + column=column_name, + ) + ) ).first() if result is None or result.logstddev is None: return [] @@ -806,7 +939,7 @@ def _get_generators_from_buckets( def zipf_distribution(total, bins): - basic_dist = list(map(lambda n: 1/n, range(1, bins + 1))) + basic_dist = list(map(lambda n: 1 / n, range(1, bins + 1))) bd_remaining = sum(basic_dist) for b in basic_dist: # yield b/bd_remaining of the `total` remaining @@ -821,14 +954,15 @@ def zipf_distribution(total, bins): class ChoiceGenerator(Generator): STORE_COUNTS = False + def __init__( self, table_name, column_name, values, counts, - sample_count = None, - suppress_count = 0, + sample_count=None, + suppress_count=0, ): super().__init__() self.table_name = table_name @@ -868,19 +1002,23 @@ def get_estimated_counts(counts): """ The counts that we would expect if this distribution was the correct one. """ + def nominal_kwargs(self): return { "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', } + def name(self): n = super().name() if self._annotation is None: return n return f"{n} [{self._annotation}]" + def actual_kwargs(self): return { "a": self.values, } + def custom_queries(self) -> dict[str, dict[str, str]]: qs = super().custom_queries() return { @@ -888,20 +1026,23 @@ def custom_queries(self) -> dict[str, dict[str, str]]: f"auto__{self.table_name}__{self.column_name}": { "query": self._query, "comment": self._comment, - } + }, } + def fit(self, default=None): return default if self._fit is None else self._fit + class ZipfChoiceGenerator(ChoiceGenerator): def get_estimated_counts(self, counts): return list(zipf_distribution(sum(counts), len(counts))) + def function_name(self): return "dist_gen.zipf_choice" + def generate_data(self, count): return [ - dist_gen.zipf_choice(self.values, len(self.values)) - for _ in range(count) + dist_gen.zipf_choice(self.values, len(self.values)) for _ in range(count) ] @@ -917,34 +1058,35 @@ def uniform_distribution(total, bins): class UniformChoiceGenerator(ChoiceGenerator): def get_estimated_counts(self, counts): return list(uniform_distribution(sum(counts), len(counts))) + def function_name(self): return "dist_gen.choice" + def generate_data(self, count): - return [ - dist_gen.choice(self.values) - for _ in range(count) - ] + return [dist_gen.choice(self.values) for _ in range(count)] class WeightedChoiceGenerator(ChoiceGenerator): STORE_COUNTS = True + def get_estimated_counts(self, counts): return counts + def function_name(self): return "dist_gen.weighted_choice" + def generate_data(self, count): - return [ - dist_gen.weighted_choice(self.values) - for _ in range(count) - ] + return [dist_gen.weighted_choice(self.values) for _ in range(count)] class ChoiceGeneratorFactory(GeneratorFactory): """ All generators that want an average and standard deviation. 
""" + SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -954,16 +1096,20 @@ def get_generators(self, columns: list[Column], engine: Engine): generators = [] with engine.connect() as connection: results = connection.execute( - text("SELECT {column} AS v, COUNT({column}) AS f FROM {table} GROUP BY v ORDER BY f DESC LIMIT {limit}".format( - table=table_name, - column=column_name, - limit=MAXIMUM_CHOICES+1, - )) + text( + "SELECT {column} AS v, COUNT({column}) AS f FROM {table} GROUP BY v ORDER BY f DESC LIMIT {limit}".format( + table=table_name, + column=column_name, + limit=MAXIMUM_CHOICES + 1, + ) + ) ) if results is not None and results.rowcount <= MAXIMUM_CHOICES: values = [] # The values found counts = [] # The number or each value - cvs: list[dict[str, any]] = [] # list of dicts with keys "v" and "count" + cvs: list[ + dict[str, any] + ] = [] # list of dicts with keys "v" and "count" for result in results: c = result.f if c != 0: @@ -980,19 +1126,27 @@ def get_generators(self, columns: list[Column], engine: Engine): WeightedChoiceGenerator(table_name, column_name, cvs, counts), ] results = connection.execute( - text("SELECT v, COUNT(v) AS f FROM (SELECT {column} as v FROM {table} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY v ORDER BY f DESC".format( - table=table_name, - column=column_name, - sample_count=self.SAMPLE_COUNT, - )) + text( + "SELECT v, COUNT(v) AS f FROM (SELECT {column} as v FROM {table} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY v ORDER BY f DESC".format( + table=table_name, + column=column_name, + sample_count=self.SAMPLE_COUNT, + ) + ) ) if results is not None: values = [] # All values found counts = [] # The number or each value - cvs: list[dict[str, any]] = [] # list of dicts with keys "v" and "count" - values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times + cvs: list[ + dict[str, any] + ] = [] # list of dicts with keys "v" and "count" + values_not_suppressed = ( + [] + ) # All values found more than SUPPRESS_COUNT times counts_not_suppressed = [] # The number for each value not suppressed - cvs_not_suppressed: list[dict[str, any]] = [] # list of dicts with keys "v" and "count" + cvs_not_suppressed: list[ + dict[str, any] + ] = [] # list of dicts with keys "v" and "count" for result in results: c = result.f if c != 0: @@ -1011,9 +1165,27 @@ def get_generators(self, columns: list[Column], engine: Engine): cvs_not_suppressed.append({"value": v, "count": c}) if counts: generators += [ - ZipfChoiceGenerator(table_name, column_name, values, counts, sample_count=self.SAMPLE_COUNT), - UniformChoiceGenerator(table_name, column_name, values, counts, sample_count=self.SAMPLE_COUNT), - WeightedChoiceGenerator(table_name, column_name, cvs, counts, sample_count=self.SAMPLE_COUNT), + ZipfChoiceGenerator( + table_name, + column_name, + values, + counts, + sample_count=self.SAMPLE_COUNT, + ), + UniformChoiceGenerator( + table_name, + column_name, + values, + counts, + sample_count=self.SAMPLE_COUNT, + ), + WeightedChoiceGenerator( + table_name, + column_name, + cvs, + counts, + sample_count=self.SAMPLE_COUNT, + ), ] if counts_not_suppressed: generators += [ @@ -1050,12 +1222,16 @@ def __init__(self, value): super().__init__() self.value = value self.repr = repr(value) + def function_name(self) -> str: return "dist_gen.constant" + def nominal_kwargs(self) -> dict[str, str]: return {"value": self.repr} + def actual_kwargs(self) -> 
dict[str, any]: return {"value": self.value} + def generate_data(self, count) -> list[any]: return [self.value for _ in range(count)] @@ -1064,6 +1240,7 @@ class ConstantGeneratorFactory(GeneratorFactory): """ Just the null generator """ + def get_generators(self, columns: list[Column], engine: Engine): if len(columns) != 1: return [] @@ -1116,7 +1293,7 @@ def actual_kwargs(self) -> dict[str, any]: """ The kwargs (summary statistics) this generator is instantiated with. """ - return { "cov": self._covariates } + return {"cov": self._covariates} def generate_data(self, count) -> list[any]: """ @@ -1145,12 +1322,12 @@ def query( self, table: str, columns: list[Column], - predicates: list[str]=[], - group_by_clause: str="", - constant_clauses: str="", - constants: str="", - suppress_count: int=1, - sample_count: int | None=None, + predicates: list[str] = [], + group_by_clause: str = "", + constant_clauses: str = "", + constants: str = "", + suppress_count: int = 1, + sample_count: int | None = None, ) -> str: """ Gets a query for the basics for multivariate normal/lognormal parameters. @@ -1175,15 +1352,13 @@ def query( multiples = "".join( f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" for iy, coly in enumerate(columns) - for ix, colx in enumerate(columns[:iy+1]) - ) - means = "".join( - f", _q.m{i}" for i in range(len(columns)) + for ix, colx in enumerate(columns[: iy + 1]) ) + means = "".join(f", _q.m{i}" for i in range(len(columns))) covs = "".join( f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" for iy in range(len(columns)) - for ix in range(iy+1) + for ix in range(iy + 1) ) if sample_count is None: subquery = table + where @@ -1211,21 +1386,21 @@ def get_generators(self, columns: list[Column], engine: Engine): query = self.query(table, columns) with engine.connect() as connection: try: - covariates = connection.execute(text( - query - )).mappings().first() + covariates = connection.execute(text(query)).mappings().first() except Exception as e: logger.debug("SQL query %s failed with error %s", query, e) return [] if not covariates or covariates["c0_0"] is None: return [] - return [MultivariateNormalGenerator( - table, - column_names, - query, - covariates, - self.function_name(), - )] + return [ + MultivariateNormalGenerator( + table, + column_names, + query, + covariates, + self.function_name(), + ) + ] class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): @@ -1283,7 +1458,9 @@ def comment(self) -> str: else: where = f" where {text_list(self.excluded_columns.values())}" if len(self.included_numeric) == 1: - return f"Mean and variance for column {self.included_numeric[0].name}{where}." + return ( + f"Mean and variance for column {self.included_numeric[0].name}{where}." + ) return ( "Means and covariate matrix for the columns " f"{text_list(col.name for col in self.included_numeric)}{where}{caveat} so that we can" @@ -1298,7 +1475,7 @@ class NullPartitionedNormalGenerator(Generator): Generates data that matches the source data in missingness, choice of non-numeric data and numeric data. - + For the numeric data to be generated, samples of rows for each combination of non-numeric values and missingness. If any such combination has only one line in the source data (or sample of @@ -1307,15 +1484,16 @@ class NullPartitionedNormalGenerator(Generator): (although if the data is all non-numeric values and nulls, single rows are used because no covariate matrix is required for this). 
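    For example, with two nullable columns ``height`` and ``weight``
    (hypothetical names), there are four partitions: both null (partition
    index 0), only ``height`` non-null (index 1), only ``weight`` non-null
    (index 2), and both non-null (index 3); the index is the bitmask of
    the columns that are non-null in that partition.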
""" + def __init__( self, query_name: str, partitions: dict[int, RowPartition], - function_name: str="grouped_multivariate_lognormal", - name_suffix: str | None=None, - partition_count_query: str | None=None, - partition_counts: Sequence[RowMapping] | None=None, - partition_count_comment: str | None=None, + function_name: str = "grouped_multivariate_lognormal", + name_suffix: str | None = None, + partition_count_query: str | None = None, + partition_counts: Sequence[RowMapping] | None = None, + partition_count_comment: str | None = None, ): self._query_name = query_name self._partitions = partitions @@ -1358,7 +1536,7 @@ def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition) "constants_at": partition.constant_outputs, "subgen": f'"{self._function_name}"', "params": covariates, - } + }, } def _count_query_name(self): @@ -1407,7 +1585,7 @@ def _actual_kwargs_with_combinations(self, partition: RowPartition): "name": self._function_name, "params": { "covs": partition.covariates, - } + }, } return { "count": count, @@ -1418,7 +1596,7 @@ def _actual_kwargs_with_combinations(self, partition: RowPartition): "params": { "covs": partition.covariates, }, - } + }, } def actual_kwargs(self) -> dict[str, any]: @@ -1438,10 +1616,7 @@ def generate_data(self, count) -> list[any]: Generate 'count' random data points for this column. """ kwargs = self.actual_kwargs() - return [ - dist_gen.alternatives(**kwargs) - for _ in range(count) - ] + return [dist_gen.alternatives(**kwargs) for _ in range(count)] def fit(self, default=None) -> float | None: return default @@ -1449,11 +1624,11 @@ def fit(self, default=None) -> float | None: def is_numeric(col: Column) -> bool: ct = get_column_type(col) - return ( - isinstance(ct, Numeric) or isinstance(ct, Integer) - ) and not col.foreign_keys + return (isinstance(ct, Numeric) or isinstance(ct, Integer)) and not col.foreign_keys + + +T = TypeVar("T") -T = TypeVar('T') def powerset(input: Iterable[T]) -> Iterable[Iterable[T]]: """Returns a list of all sublists of""" @@ -1465,6 +1640,7 @@ class NullableColumn: """ A reference to a nullable column whose nullability is part of a partitioning. """ + column: Column # The bit (power of two) of the number of the partition in the partition sizes list bitmask: int @@ -1474,13 +1650,12 @@ class NullPatternPartition: """ The definition of a partition (in other words, what makes it not another partition) """ + def __init__( - self, - columns: Iterable[Column], - partition_nonnulls: Iterable[NullableColumn] + self, columns: Iterable[Column], partition_nonnulls: Iterable[NullableColumn] ): self.index = sum(nc.bitmask for nc in partition_nonnulls) - nonnull_columns = { nc.column.name for nc in partition_nonnulls } + nonnull_columns = {nc.column.name for nc in partition_nonnulls} self.included_numeric: list[Column] = [] self.included_choice: dict[int, str] = {} self.group_by_clause = "" @@ -1535,17 +1710,21 @@ def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]: out: list[NullableColumn] = [] for col in columns: if col.nullable: - out.append(NullableColumn( - column=col, - bitmask=2 ** len(out), - )) + out.append( + NullableColumn( + column=col, + bitmask=2 ** len(out), + ) + ) return out - def get_partition_count_query(self, ncs: list[NullableColumn], table: str, where: str | None=None) -> str: + def get_partition_count_query( + self, ncs: list[NullableColumn], table: str, where: str | None = None + ) -> str: """ Returns a SQL expression returning columns ``count`` and ``index``. 
- Each row returned represents one of the null pattern partitions. + Each row returned represents one of the null pattern partitions. ``index`` is the bitmask of all those nullable columns that are not null for this partition, and ``count`` is the total number of rows in this partition. """ @@ -1576,7 +1755,7 @@ def get_generators(self, columns: list[Column], engine: Engine): columns=partition_def.included_numeric, predicates=partition_def.predicates, group_by_clause=partition_def.group_by_clause, - constants = partition_def.constants, + constants=partition_def.constants, constant_clauses=partition_def.constant_clauses, ) row_partitions_maximal[partition_def.index] = RowPartition( @@ -1592,7 +1771,7 @@ def get_generators(self, columns: list[Column], engine: Engine): columns=partition_def.included_numeric, predicates=partition_def.predicates, group_by_clause=partition_def.group_by_clause, - constants = partition_def.constants, + constants=partition_def.constants, constant_clauses=partition_def.constant_clauses, suppress_count=self.SUPPRESS_COUNT, sample_count=self.SAMPLE_COUNT, @@ -1608,43 +1787,49 @@ def get_generators(self, columns: list[Column], engine: Engine): gens = [] try: with engine.connect() as connection: - partition_query_max = self.get_partition_count_query(nullable_columns, table) - partition_count_max_results = connection.execute( - text(partition_query_max) - ).mappings().fetchall() + partition_query_max = self.get_partition_count_query( + nullable_columns, table + ) + partition_count_max_results = ( + connection.execute(text(partition_query_max)).mappings().fetchall() + ) count_comment = f"Number of rows for each combination of the columns { {nc.column.name for nc in nullable_columns} } of the table {table} being null" if self._execute_partition_queries(connection, row_partitions_maximal): - gens.append(NullPartitionedNormalGenerator( - query_name, - row_partitions_maximal, - self.function_name(), - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - )) + gens.append( + NullPartitionedNormalGenerator( + query_name, + row_partitions_maximal, + self.function_name(), + partition_count_query=partition_query_max, + partition_counts=partition_count_max_results, + partition_count_comment=count_comment, + ) + ) partition_query_ss = self.get_partition_count_query( nullable_columns, table, - where=f"WHERE {self.SUPPRESS_COUNT} < count" + where=f"WHERE {self.SUPPRESS_COUNT} < count", + ) + partition_count_ss_results = ( + connection.execute(text(partition_query_ss)).mappings().fetchall() ) - partition_count_ss_results = connection.execute( - text(partition_query_ss) - ).mappings().fetchall() if self._execute_partition_queries(connection, row_partitions_ss): - gens.append(NullPartitionedNormalGenerator( - query_name, - row_partitions_ss, - self.function_name(), - name_suffix="sampled and suppressed", - partition_count_query=partition_query_ss, - partition_counts=partition_count_ss_results, - partition_count_comment=count_comment, - )) + gens.append( + NullPartitionedNormalGenerator( + query_name, + row_partitions_ss, + self.function_name(), + name_suffix="sampled and suppressed", + partition_count_query=partition_query_ss, + partition_counts=partition_count_ss_results, + partition_count_comment=count_comment, + ) + ) except sqlalchemy.exc.DatabaseError as exc: logger.debug("SQL query failed with error %s [%s]", exc, exc.statement) return [] return gens - + def _execute_partition_queries( self, connection: 
Connection, @@ -1656,9 +1841,7 @@ def _execute_partition_queries( """ found_nonzero = False for rp in partitions.values(): - rp.covariates = connection.execute(text( - rp.query - )).mappings().fetchall() + rp.covariates = connection.execute(text(rp.query)).mappings().fetchall() if not rp.covariates or rp.covariates[0]["count"] is None: rp.covariates = [{"count": 0}] else: @@ -1682,19 +1865,21 @@ def query_var(self, column: str) -> str: @lru_cache(1) def everything_factory(): - return MultiGeneratorFactory([ - MimesisStringGeneratorFactory(), - MimesisIntegerGeneratorFactory(), - MimesisFloatGeneratorFactory(), - MimesisDateGeneratorFactory(), - MimesisDateTimeGeneratorFactory(), - MimesisTimeGeneratorFactory(), - ContinuousDistributionGeneratorFactory(), - ContinuousLogDistributionGeneratorFactory(), - ChoiceGeneratorFactory(), - ConstantGeneratorFactory(), - MultivariateNormalGeneratorFactory(), - MultivariateLogNormalGeneratorFactory(), - NullPartitionedNormalGeneratorFactory(), - NullPartitionedLogNormalGeneratorFactory(), - ]) + return MultiGeneratorFactory( + [ + MimesisStringGeneratorFactory(), + MimesisIntegerGeneratorFactory(), + MimesisFloatGeneratorFactory(), + MimesisDateGeneratorFactory(), + MimesisDateTimeGeneratorFactory(), + MimesisTimeGeneratorFactory(), + ContinuousDistributionGeneratorFactory(), + ContinuousLogDistributionGeneratorFactory(), + ChoiceGeneratorFactory(), + ConstantGeneratorFactory(), + MultivariateNormalGeneratorFactory(), + MultivariateLogNormalGeneratorFactory(), + NullPartitionedNormalGeneratorFactory(), + NullPartitionedLogNormalGeneratorFactory(), + ] + ) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 5e3fb89..4f5e8c5 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -10,7 +10,7 @@ import sqlalchemy from prettytable import PrettyTable -from sqlalchemy import Column, MetaData, Table, text, ForeignKey +from sqlalchemy import Column, ForeignKey, MetaData, Table, text from datafaker.generators import Generator, PredefinedGenerator, everything_factory from datafaker.utils import ( @@ -26,15 +26,18 @@ # See https://github.com/pyreadline3/pyreadline3/issues/37 try: import readline + if not hasattr(readline, "backend"): readline.backend = "readline" except: pass + def or_default(v, d): - """ Returns v if it isn't None, otherwise d. """ + """Returns v if it isn't None, otherwise d.""" return d if v is None else v + class TableType(Enum): GENERATE = "generate" IGNORE = "ignore" @@ -42,6 +45,7 @@ class TableType(Enum): PRIVATE = "private" EMPTY = "empty" + TYPE_LETTER = { TableType.GENERATE: "G", TableType.IGNORE: "I", @@ -58,6 +62,7 @@ class TableType(Enum): TableType.EMPTY: "(table: {} (empty))", } + @dataclass class TableEntry: name: str # name of the table @@ -67,15 +72,19 @@ class AskSaveCmd(cmd.Cmd): intro = "Do you want to save this configuration?" prompt = "(yes/no/cancel) " file = None + def __init__(self): super().__init__() self.result = "" + def do_yes(self, _arg): self.result = "yes" return True + def do_no(self, _arg): self.result = "no" return True + def do_cancel(self, _arg): self.result = "cancel" return True @@ -98,7 +107,9 @@ class DbCmd(ABC, cmd.Cmd): def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry: ... 
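    # How the AskSaveCmd prompt above gets driven (see ask_save further down):
    #     ask = AskSaveCmd()
    #     ask.cmdloop()       # returns once the user enters yes, no or cancel
    #     reply = ask.result  # one of "yes", "no" or "cancel"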
- def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): + def __init__( + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + ): super().__init__() self.config = config self.metadata = metadata @@ -115,30 +126,33 @@ def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Ma self.table_entries.append(entry) self.table_index = 0 self.engine = create_db_engine(src_dsn, schema_name=src_schema) + def __enter__(self): return self + def __exit__(self, exc_type, exc_val, exc_tb): self.engine.dispose() def print(self, text: str, *args, **kwargs): print(text.format(*args, **kwargs)) + def print_table(self, headings: list[str], rows: list[list[str]]): output = PrettyTable() output.field_names = headings for row in rows: output.add_row(row) print(output) + def print_table_by_columns(self, columns: dict[str, list[str]]): output = PrettyTable() row_count = max([len(col) for col in columns.values()]) for field_name, data in columns.items(): output.add_column(field_name, data + [None] * (row_count - len(data))) print(output) + def print_results(self, result): - self.print_table( - list(result.keys()), - [list(row) for row in result.all()] - ) + self.print_table(list(result.keys()), [list(row) for row in result.all()]) + def ask_save(self): ask = AskSaveCmd() ask.cmdloop() @@ -150,39 +164,51 @@ def set_table_index(self, index) -> bool: self.set_prompt() return True return False + def next_table(self, report="No more tables"): if not self.set_table_index(self.table_index + 1): self.print(report) return False return True + def table_name(self): return self.table_entries[self.table_index].name + def table_metadata(self) -> Table: return self.metadata.tables[self.table_name()] + def get_column_names(self) -> list[str]: - return [ - col.name - for col in self.table_metadata().columns - ] + return [col.name for col in self.table_metadata().columns] + def report_columns(self): - self.print_table(["name", "type", "primary", "nullable", "foreign key"], [ - [name, str(col.type), col.primary_key, col.nullable, ", ".join( - [fk_column_name(fk) for fk in col.foreign_keys] - )] - for name, col in self.table_metadata().columns.items() - ]) + self.print_table( + ["name", "type", "primary", "nullable", "foreign key"], + [ + [ + name, + str(col.type), + col.primary_key, + col.nullable, + ", ".join([fk_column_name(fk) for fk in col.foreign_keys]), + ] + for name, col in self.table_metadata().columns.items() + ], + ) + def get_table_config(self, table_name: str) -> dict[str, any]: ts = self.config.get("tables", None) if type(ts) is not dict: return {} t = ts.get(table_name) return t if type(t) is dict else {} + def set_table_config(self, table_name: str, config: dict[str, any]): ts = self.config.get("tables", None) if type(ts) is not dict: self.config["tables"] = {table_name: config} return ts[table_name] = config + def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, any]]: src_stats = self.config.get("src-stats", []) new_src_stats = [] @@ -191,6 +217,7 @@ def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, any]]: new_src_stats.append(stat) self.config["src-stats"] = new_src_stats return new_src_stats + def get_nonnull_columns(self, table_name: str): metadata_table = self.metadata.tables[table_name] return [ @@ -198,52 +225,59 @@ def get_nonnull_columns(self, table_name: str): for name, column in metadata_table.columns.items() if column.nullable ] + def find_entry_index_by_table_name(self, table_name) -> int | None: 
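        # Index of the first entry whose name matches, or None when the
        # table is unknown.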
return next( - (i for i,entry in enumerate(self.table_entries) if entry.name == table_name), + ( + i + for i, entry in enumerate(self.table_entries) + if entry.name == table_name + ), None, ) + def find_entry_by_table_name(self, table_name) -> TableEntry | None: for e in self.table_entries: if e.name == table_name: return e return None + def do_counts(self, _arg): "Report the column names with the counts of nulls in them" if len(self.table_entries) <= self.table_index: return table_name = self.table_name() nonnull_columns = self.get_nonnull_columns(table_name) - colcounts = [ - ", COUNT({0}) AS {0}".format(nnc) - for nnc in nonnull_columns - ] + colcounts = [", COUNT({0}) AS {0}".format(nnc) for nnc in nonnull_columns] with self.engine.connect() as connection: result = connection.execute( - text("SELECT COUNT(*) AS row_count{colcounts} FROM {table}".format( - table=table_name, - colcounts="".join(colcounts), - )) + text( + "SELECT COUNT(*) AS row_count{colcounts} FROM {table}".format( + table=table_name, + colcounts="".join(colcounts), + ) + ) ).first() if result is None: self.print("Could not count rows in table {0}", table_name) return row_count = result.row_count self.print(self.ROW_COUNT_MSG, row_count) - self.print_table(["Column", "NULL count"], [ - [name, row_count - count] - for name, count in result._mapping.items() - if name != "row_count" - ]) + self.print_table( + ["Column", "NULL count"], + [ + [name, row_count - count] + for name, count in result._mapping.items() + if name != "row_count" + ], + ) def do_select(self, arg): "Run a select query over the database and show the first 50 results" MAX_SELECT_ROWS = 50 with self.engine.connect() as connection: try: - result = connection.execute( - text("SELECT " + arg) - ) + result = connection.execute(text("SELECT " + arg)) except sqlalchemy.exc.DatabaseError as exc: self.print("Failed to execute: {}", exc) return @@ -252,10 +286,7 @@ def do_select(self, arg): if 50 < row_count: self.print("Showing the first {} rows", MAX_SELECT_ROWS) fields = list(result.keys()) - rows = [ - row._tuple() - for row in result.fetchmany(MAX_SELECT_ROWS) - ] + rows = [row._tuple() for row in result.fetchmany(MAX_SELECT_ROWS)] self.print_table(fields, rows) def do_peek(self, arg: str): @@ -274,30 +305,25 @@ def do_peek(self, arg: str): nonnulls = [cn + " IS NOT NULL" for cn in col_names] with self.engine.connect() as connection: query = "SELECT {cols} FROM {table} {where} {nonnull} ORDER BY RANDOM() LIMIT {max}".format( - cols=",".join(col_names), - table=table_name, - where="WHERE" if nonnulls else "", - nonnull=" OR ".join(nonnulls), - max=MAX_PEEK_ROWS, - ) + cols=",".join(col_names), + table=table_name, + where="WHERE" if nonnulls else "", + nonnull=" OR ".join(nonnulls), + max=MAX_PEEK_ROWS, + ) try: result = connection.execute(text(query)) except Exception as exc: self.print(f'SQL query "{query}" caused exception {exc}') return - rows = [ - row._tuple() - for row in result.fetchmany(MAX_PEEK_ROWS) - ] + rows = [row._tuple() for row in result.fetchmany(MAX_PEEK_ROWS)] self.print_table(list(result.keys()), rows) def complete_peek(self, text: str, _line: str, _begidx: int, _endidx: int): if len(self.table_entries) <= self.table_index: return [] return [ - col - for col in self.table_metadata().columns.keys() - if col.startswith(text) + col for col in self.table_metadata().columns.keys() if col.startswith(text) ] @@ -306,6 +332,7 @@ class TableCmdTableEntry(TableEntry): old_type: TableType new_type: TableType + class TableCmd(DbCmd): intro = 
"Interactive table configuration (ignore, vocabulary, private, generate or empty). Type ? for help.\n" doc_leader = """Use the commands 'ignore', 'vocabulary', @@ -316,10 +343,16 @@ class TableCmd(DbCmd): to exit this program.""" prompt = "(tableconf) " file = None - WARNING_TEXT_VOCAB_TO_NON_VOCAB = "Vocabulary table {0} references non-vocabulary table {1}" - WARNING_TEXT_NON_EMPTY_TO_EMPTY = "Empty table {1} referenced from non-empty table {0}. {1} will need stories." + WARNING_TEXT_VOCAB_TO_NON_VOCAB = ( + "Vocabulary table {0} references non-vocabulary table {1}" + ) + WARNING_TEXT_NON_EMPTY_TO_EMPTY = ( + "Empty table {1} referenced from non-empty table {0}. {1} will need stories." + ) WARNING_TEXT_PROBLEMS_EXIST = "WARNING: The following table types have problems:" - WARNING_TEXT_POTENTIAL_PROBLEMS = "NOTE: The following table types might cause problems later:" + WARNING_TEXT_POTENTIAL_PROBLEMS = ( + "NOTE: The following table types might cause problems later:" + ) NOTE_TEXT_NO_CHANGES = "You have made no changes." NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" @@ -334,7 +367,9 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry: return TableCmdTableEntry(name, TableType.EMPTY, TableType.EMPTY) return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE) - def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): + def __init__( + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + ): super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @@ -344,16 +379,21 @@ def set_prompt(self): self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name) else: self.prompt = "(table) " + def set_type(self, t_type: TableType): if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type + def _copy_entries(self) -> None: for entry in self.table_entries: entry: TableCmdTableEntry if entry.old_type != entry.new_type: table = self.get_table_config(entry.name) - if entry.old_type == TableType.EMPTY and table.get("num_rows_per_pass", 1) == 0: + if ( + entry.old_type == TableType.EMPTY + and table.get("num_rows_per_pass", 1) == 0 + ): table["num_rows_per_pass"] = 1 if entry.new_type == TableType.IGNORE: table["ignore"] = True @@ -381,13 +421,11 @@ def _copy_entries(self) -> None: def _get_referenced_tables(self, from_table_name: str) -> set[str]: from_meta = self.metadata.tables[from_table_name] return { - fk.column.table.name - for col in from_meta.columns - for fk in col.foreign_keys + fk.column.table.name for col in from_meta.columns for fk in col.foreign_keys } def _sanity_check_failures(self) -> list[tuple[str, str, str]]: - """ Find tables that reference each other that should not given their types. 
""" + """Find tables that reference each other that should not given their types.""" failures = [] for from_entry in self.table_entries: from_entry: TableCmdTableEntry @@ -396,16 +434,21 @@ def _sanity_check_failures(self) -> list[tuple[str, str, str]]: referenced = self._get_referenced_tables(from_entry.name) for ref in referenced: to_entry = self.find_entry_by_table_name(ref) - if to_entry is not None and to_entry.new_type != TableType.VOCABULARY: - failures.append(( - self.WARNING_TEXT_VOCAB_TO_NON_VOCAB, - from_entry.name, - to_entry.name, - )) + if ( + to_entry is not None + and to_entry.new_type != TableType.VOCABULARY + ): + failures.append( + ( + self.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + from_entry.name, + to_entry.name, + ) + ) return failures def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: - """ Find tables that reference each other that might cause problems given their types. """ + """Find tables that reference each other that might cause problems given their types.""" warnings = [] for from_entry in self.table_entries: from_entry: TableCmdTableEntry @@ -414,15 +457,19 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: referenced = self._get_referenced_tables(from_entry.name) for ref in referenced: to_entry = self.find_entry_by_table_name(ref) - if to_entry is not None and to_entry.new_type in {TableType.EMPTY, TableType.IGNORE}: - warnings.append(( - self.WARNING_TEXT_NON_EMPTY_TO_EMPTY, - from_entry.name, - to_entry.name, - )) + if to_entry is not None and to_entry.new_type in { + TableType.EMPTY, + TableType.IGNORE, + }: + warnings.append( + ( + self.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + from_entry.name, + to_entry.name, + ) + ) return warnings - def do_quit(self, _arg): "Check the updates, save them if desired and quit the configurer." 
count = 0 @@ -440,12 +487,12 @@ def do_quit(self, _arg): failures = self._sanity_check_failures() if failures: self.print(self.WARNING_TEXT_PROBLEMS_EXIST) - for (text, from_t, to_t) in failures: + for text, from_t, to_t in failures: self.print(text, from_t, to_t) warnings = self._sanity_check_warnings() if warnings: self.print(self.WARNING_TEXT_POTENTIAL_PROBLEMS) - for (text, from_t, to_t) in warnings: + for text, from_t, to_t in warnings: self.print(text, from_t, to_t) reply = self.ask_save() if reply == "yes": @@ -454,6 +501,7 @@ def do_quit(self, _arg): if reply == "no": return True return False + def do_tables(self, _arg): "list the tables with their types" for entry in self.table_entries: @@ -461,6 +509,7 @@ def do_tables(self, _arg): new = entry.new_type becomes = " " if old == new else "->" + TYPE_LETTER[new] self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) + def do_next(self, arg): "'next' = go to the next table, 'next tablename' = go to table 'tablename'" if arg: @@ -472,44 +521,51 @@ def do_next(self, arg): self.set_table_index(index) return self.next_table(self.INFO_NO_MORE_TABLES) + def complete_next(self, text, line, begidx, endidx): return [ - entry.name - for entry in self.table_entries - if entry.name.startswith(text) + entry.name for entry in self.table_entries if entry.name.startswith(text) ] + def do_previous(self, _arg): "Go to the previous table" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) + def do_ignore(self, _arg): "Set the current table as ignored, and go to the next table" self.set_type(TableType.IGNORE) self.print("Table {} set as ignored", self.table_name()) self.next_table() + def do_vocabulary(self, _arg): "Set the current table as a vocabulary table, and go to the next table" self.set_type(TableType.VOCABULARY) self.print("Table {} set to be a vocabulary table", self.table_name()) self.next_table() + def do_private(self, _arg): "Set the current table as a primary private table (such as the table of patients)" self.set_type(TableType.PRIVATE) self.print("Table {} set to be a primary private table", self.table_name()) self.next_table() + def do_generate(self, _arg): "Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table" self.set_type(TableType.GENERATE) self.print("Table {} generate", self.table_name()) self.next_table() + def do_empty(self, _arg): "Set the current table as empty; no generators will be run for it" self.set_type(TableType.EMPTY) self.print("Table {} empty", self.table_name()) self.next_table() + def do_columns(self, _arg): "Report the column names and metadata" self.report_columns() + def do_data(self, arg: str): """ Report some data. 
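
A standalone sketch (not from the datafaker codebase; table and column names are invented) of the foreign-key walk underlying the sanity checks that do_quit reports above: in SQLAlchemy every Column carries a set of ForeignKey objects, and fk.column is the referenced column, so the set of referenced table names falls out of one comprehension.

    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table

    metadata = MetaData()
    concept = Table("concept", metadata, Column("id", Integer, primary_key=True))
    person = Table(
        "person",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("gender_concept_id", Integer, ForeignKey("concept.id")),
    )

    # Collect the names of all tables that `person` references.
    referenced = {
        fk.column.table.name for col in person.columns for fk in col.foreign_keys
    }
    assert referenced == {"concept"}
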
@@ -549,15 +605,13 @@ def do_data(self, arg: str): if number is None: number = 48 self.print_column_data(column, number, min_length) + def complete_data(self, text, line, begidx, endidx): - previous_parts = line[:begidx - 1].split() + previous_parts = line[: begidx - 1].split() if len(previous_parts) != 2: return [] table_metadata = self.table_metadata() - return [ - k for k in table_metadata.columns.keys() - if k.startswith(text) - ] + return [k for k in table_metadata.columns.keys() if k.startswith(text)] def print_column_data(self, column: str, count: int, min_length: int): where = f"WHERE {column} IS NOT NULL" @@ -568,22 +622,26 @@ def print_column_data(self, column: str, count: int, min_length: int): ) with self.engine.connect() as connection: result = connection.execute( - text("SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format( - table=self.table_name(), - column=column, - count=count, - where=where, - )) + text( + "SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format( + table=self.table_name(), + column=column, + count=count, + where=where, + ) + ) ) self.columnize([str(x[0]) for x in result.all()]) def print_row_data(self, count: int): with self.engine.connect() as connection: result = connection.execute( - text("SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format( - table=self.table_name(), - count=count, - )) + text( + "SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format( + table=self.table_name(), + count=count, + ) + ) ) if result is None: self.print("No rows in this table!") @@ -591,7 +649,9 @@ def print_row_data(self, count: int): self.print_results(result) -def update_config_tables(src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): +def update_config_tables( + src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping +): with TableCmd(src_dsn, src_schema, metadata, config) as tc: tc.cmdloop() return tc.config @@ -599,8 +659,8 @@ def update_config_tables(src_dsn: str, src_schema: str, metadata: MetaData, conf @dataclass class MissingnessType: - SAMPLED="column_presence.sampled" - SAMPLED_QUERY=( + SAMPLED = "column_presence.sampled" + SAMPLED_QUERY = ( "SELECT COUNT(*) AS row_count, {result_names} FROM " "(SELECT {column_is_nulls} FROM {table} ORDER BY RANDOM() LIMIT {count})" " AS __t GROUP BY {result_names}" @@ -609,16 +669,13 @@ class MissingnessType: query: str comment: str columns: list[str] + @classmethod def sampled_query(cls, table, count, column_names) -> str: - result_names = ", ".join([ - "{0}__is_null".format(c) - for c in column_names - ]) - column_is_nulls = ", ".join([ - "{0} IS NULL AS {0}__is_null".format(c) - for c in column_names - ]) + result_names = ", ".join(["{0}__is_null".format(c) for c in column_names]) + column_is_nulls = ", ".join( + ["{0} IS NULL AS {0}__is_null".format(c) for c in column_names] + ) return cls.SAMPLED_QUERY.format( result_names=result_names, column_is_nulls=column_is_nulls, @@ -644,8 +701,10 @@ class MissingnessCmd(DbCmd): file = None PATTERN_RE = re.compile(r'SRC_STATS\["([^"]*)"\]') - def find_missingness_query(self, missingness_generator: Mapping) -> tuple[str | None, str | None] | None: - """ Find query and comment from src-stats for the passed missingness generator. 
""" + def find_missingness_query( + self, missingness_generator: Mapping + ) -> tuple[str | None, str | None] | None: + """Find query and comment from src-stats for the passed missingness generator.""" kwargs = missingness_generator.get("kwargs", {}) patterns = kwargs.get("patterns", "") pattern_match = self.PATTERN_RE.match(patterns) @@ -655,6 +714,7 @@ def find_missingness_query(self, missingness_generator: Mapping) -> tuple[str | if src_stat.get("name") == key: return (src_stat.get("query", None), src_stat.get("comment", None)) return None + def make_table_entry(self, name: str, table: Mapping) -> TableEntry: if table.get("ignore", False): return None @@ -695,7 +755,9 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry: new_type=old, ) - def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): + def __init__( + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + ): super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @@ -709,10 +771,12 @@ def set_prompt(self): self.prompt = "(missingness for {0}: {1}) ".format(entry.name, nt.name) else: self.prompt = "(missingness) " + def set_type(self, t_type: TableType): if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type + def _copy_entries(self) -> None: src_stats = self._remove_prefix_src_stats("missing_auto__") for entry in self.table_entries: @@ -722,16 +786,26 @@ def _copy_entries(self) -> None: table.pop("missingness_generators", None) else: src_stat_key = "missing_auto__{0}__0".format(entry.name) - table["missingness_generators"] = [{ - "name": entry.new_type.name, - "kwargs": {"patterns": 'SRC_STATS["{0}"]["results"]'.format(src_stat_key)}, - "columns": entry.new_type.columns, - }] - src_stats.append({ - "name": src_stat_key, - "query": entry.new_type.query, - "comments": [] if entry.new_type.comment is None else [entry.new_type.comment], - }) + table["missingness_generators"] = [ + { + "name": entry.new_type.name, + "kwargs": { + "patterns": 'SRC_STATS["{0}"]["results"]'.format( + src_stat_key + ) + }, + "columns": entry.new_type.columns, + } + ] + src_stats.append( + { + "name": src_stat_key, + "query": entry.new_type.query, + "comments": [] + if entry.new_type.comment is None + else [entry.new_type.comment], + } + ) self.set_table_config(entry.name, table) def do_quit(self, _arg): @@ -741,9 +815,17 @@ def do_quit(self, _arg): if entry.old_type != entry.new_type: count += 1 if entry.old_type is None: - self.print("Putting generator {0} on table {1}", entry.name, entry.new_type.name) + self.print( + "Putting generator {0} on table {1}", + entry.name, + entry.new_type.name, + ) elif entry.new_type is None: - self.print("Deleting generator {1} from table {0}", entry.name, entry.old_type.name) + self.print( + "Deleting generator {1} from table {0}", + entry.name, + entry.old_type.name, + ) else: self.print( "Changing {0} from {1} to {2}", @@ -760,6 +842,7 @@ def do_quit(self, _arg): if reply == "no": return True return False + def do_tables(self, arg): "list the tables with their types" for entry in self.table_entries: @@ -767,27 +850,32 @@ def do_tables(self, arg): new = "-" if entry.new_type is None else entry.new_type.name desc = new if old == new else "{0}->{1}".format(old, new) self.print("{0} {1}", entry.name, desc) + def do_next(self, arg): "'next' = go to the next table, 'next tablename' = go to table 'tablename'" if arg: # Find the index of the table called _arg, if any - index = 
next((i for i,entry in enumerate(self.table_entries) if entry.name == arg), None) + index = next( + (i for i, entry in enumerate(self.table_entries) if entry.name == arg), + None, + ) if index is None: self.print(self.ERROR_NO_SUCH_TABLE, arg) return self.set_table_index(index) return self.next_table(self.INFO_NO_MORE_TABLES) + def complete_next(self, text, line, begidx, endidx): return [ - entry.name - for entry in self.table_entries - if entry.name.startswith(text) + entry.name for entry in self.table_entries if entry.name.startswith(text) ] + def do_previous(self, _arg): "Go to the previous table" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) + def _set_type(self, name, query, comment): if len(self.table_entries) <= self.table_index: return @@ -798,11 +886,13 @@ def _set_type(self, name, query, comment): comment=comment, columns=self.get_nonnull_columns(entry.name), ) + def _set_none(self): if len(self.table_entries) <= self.table_index: return entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] entry.new_type = None + def do_sampled(self, arg: str): """ Set the current table missingness as 'sampled', and go to the next table. @@ -819,7 +909,10 @@ def do_sampled(self, arg: str): elif arg.isdecimal(): count = int(arg) else: - self.print("Error: sampled can be used alone or with an integer argument. {0} is not permitted", arg) + self.print( + "Error: sampled can be used alone or with an integer argument. {0} is not permitted", + arg, + ) return self._set_type( MissingnessType.SAMPLED, @@ -828,10 +921,11 @@ def do_sampled(self, arg: str): count, self.get_nonnull_columns(entry.name), ), - f"The missingness patterns and how often they appear in a sample of {count} from table {entry.name}" + f"The missingness patterns and how often they appear in a sample of {count} from table {entry.name}", ) self.print("Table {} set to sampled missingness", self.table_name()) self.next_table() + def do_none(self, _arg): "Set the current table to have no missingness, and go to the next table" self._set_none() @@ -839,7 +933,9 @@ def do_none(self, _arg): self.next_table() -def update_missingness(src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): +def update_missingness( + src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping +): with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() return mc.config @@ -850,11 +946,13 @@ class GeneratorInfo: columns: list[str] gen: Generator | None + @dataclass class GeneratorCmdTableEntry(TableEntry): old_generators: list[GeneratorInfo] new_generators: list[GeneratorInfo] + class GeneratorCmd(DbCmd): intro = "Interactive generator configuration. Type ? for help.\n" doc_leader = """Use command 'propose' for a list of generators applicable to the @@ -883,7 +981,9 @@ class GeneratorCmd(DbCmd): ERROR_CANNOT_UNMERGE_ALL = "You cannot unmerge all the generator's columns" PROPOSE_NOTHING = "No proposed generators, sorry." - SRC_STAT_RE = re.compile(r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?') + SRC_STAT_RE = re.compile( + r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' 
+ ) def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None: if table.get("ignore", False): @@ -902,35 +1002,47 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None gen_name = rg.get("name", None) if gen_name: ca = rg.get("columns_assigned", []) - collist: list[str] = [ca] if isinstance(ca, str) else [str(c) for c in ca] + collist: list[str] = ( + [ca] if isinstance(ca, str) else [str(c) for c in ca] + ) colset: set[str] = set(collist) for unknown in colset - column_set: logger.warning( "table '%s' has '%s' assigned to column '%s' which is not in this table", - table_name, gen_name, unknown + table_name, + gen_name, + unknown, ) for mult in columns_assigned_so_far & colset: logger.warning( - "table '%s' has column '%s' assigned to multiple times", table_name, mult + "table '%s' has column '%s' assigned to multiple times", + table_name, + mult, ) actual_collist = [c for c in collist if c in columns] if actual_collist: gen = PredefinedGenerator(table, rg, self.config) - new_generator_infos.append(GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - )) - old_generator_infos.append(GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - )) + new_generator_infos.append( + GeneratorInfo( + columns=actual_collist.copy(), + gen=gen, + ) + ) + old_generator_infos.append( + GeneratorInfo( + columns=actual_collist.copy(), + gen=gen, + ) + ) columns_assigned_so_far |= colset for colname in columns: if colname not in columns_assigned_so_far: - new_generator_infos.append(GeneratorInfo( - columns=[colname], - gen=None, - )) + new_generator_infos.append( + GeneratorInfo( + columns=[colname], + gen=None, + ) + ) if len(new_generator_infos) == 0: return None return GeneratorCmdTableEntry( @@ -939,7 +1051,9 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None new_generators=new_generator_infos, ) - def __init__(self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping): + def __init__( + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + ): super().__init__(src_dsn, src_schema, metadata, config) self.generator_index = 0 self.generators_valid_columns = None @@ -957,7 +1071,10 @@ def previous_table(self): if ret: table = self.get_table() if table is None: - self.print("Internal error! table {0} does not have any generators!", self.table_index) + self.print( + "Internal error! 
table {0} does not have any generators!", + self.table_index, + ) return False self.generator_index = len(table.new_generators) - 1 else: @@ -985,10 +1102,7 @@ def column_metadata(self) -> list[Column]: table = self.table_metadata() if table is None: return [] - return [ - table.columns[name] - for name in self.get_column_names() - ] + return [table.columns[name] for name in self.get_column_names()] def set_prompt(self): (table_name, gen_info) = self.get_table_and_generator() @@ -1000,8 +1114,7 @@ def set_prompt(self): return table = self.table_metadata() columns = [ - c + "[pk]" if table.columns[c].primary_key else c - for c in gen_info.columns + c + "[pk]" if table.columns[c].primary_key else c for c in gen_info.columns ] gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" self.prompt = f"({table_name}.{','.join(columns)}{gen}) " @@ -1020,11 +1133,15 @@ def _copy_entries(self) -> None: new_gens.append(generator.gen) cqs = generator.gen.custom_queries() for cq_key, cq in cqs.items(): - src_stats.append({ - "name": cq_key, - "query": cq["query"], - "comments": [cq["comment"]] if "comment" in cq and cq["comment"] else [], - }) + src_stats.append( + { + "name": cq_key, + "query": cq["query"], + "comments": [cq["comment"]] + if "comment" in cq and cq["comment"] + else [], + } + ) rg = { "name": generator.gen.function_name(), "columns_assigned": generator.columns, @@ -1035,16 +1152,18 @@ def _copy_entries(self) -> None: rgs.append(rg) aq = self._get_aggregate_query(new_gens, entry.name) if aq: - src_stats.append({ - "name": f"auto__{entry.name}", - "query": aq, - "comments": [ - q["comment"] - for gen in new_gens - for q in gen.select_aggregate_clauses().values() - if "comment" in q and q["comment"] is not None - ], - }) + src_stats.append( + { + "name": f"auto__{entry.name}", + "query": aq, + "comments": [ + q["comment"] + for gen in new_gens + for q in gen.select_aggregate_clauses().values() + if "comment" in q and q["comment"] is not None + ], + } + ) table_config = self.get_table_config(entry.name) if rgs: table_config["row_generators"] = rgs @@ -1053,8 +1172,10 @@ def _copy_entries(self) -> None: self.set_table_config(entry.name, table_config) self.config["src-stats"] = src_stats - def _find_old_generator(self, entry: GeneratorCmdTableEntry, columns) -> Generator | None: - """ Find any generator that previously assigned to these exact same columns. 
""" + def _find_old_generator( + self, entry: GeneratorCmdTableEntry, columns + ) -> Generator | None: + """Find any generator that previously assigned to these exact same columns.""" fc = frozenset(columns) for gen in entry.old_generators: if frozenset(gen.columns) == fc: @@ -1139,12 +1260,18 @@ def do_info(self, _arg): "nullable" if cm.nullable else "not nullable", ) if cm.primary_key: - self.print("It is a primary key, which usually does not need a generator (it will auto-increment)") + self.print( + "It is a primary key, which usually does not need a generator (it will auto-increment)" + ) if cm.foreign_keys: fk_names = [fk_column_name(fk) for fk in cm.foreign_keys] - self.print("It is a foreign key referencing column {0}", ", ".join(fk_names)) + self.print( + "It is a foreign key referencing column {0}", ", ".join(fk_names) + ) if len(fk_names) == 1 and not cm.primary_key: - self.print("You do not need a generator if you just want a uniform choice over the referenced table's rows") + self.print( + "You do not need a generator if you just want a uniform choice over the referenced table's rows" + ) def _get_table_index(self, table_name: str) -> int | None: for n, entry in enumerate(self.table_entries): @@ -1196,7 +1323,7 @@ def do_next(self, arg): self._go_next() def do_n(self, arg): - """ Synonym for next """ + """Synonym for next""" self.do_next(arg) def complete_n(self, text: str, line: str, begidx: int, endidx: int): @@ -1248,7 +1375,7 @@ def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int): return table_names + column_names def do_previous(self, _arg): - """ Go to the previous generator """ + """Go to the previous generator""" if self.generator_index == 0: self.previous_table() else: @@ -1256,11 +1383,14 @@ def do_previous(self, _arg): self.set_prompt() def do_b(self, arg): - """ Synonym for previous """ + """Synonym for previous""" self.do_previous(arg) def _generators_valid(self) -> bool: - return self.generators_valid_columns == (self.table_index, self.get_column_names()) + return self.generators_valid_columns == ( + self.table_index, + self.get_column_names(), + ) def _get_generator_proposals(self) -> list[Generator]: if not self._generators_valid(): @@ -1270,7 +1400,10 @@ def _get_generator_proposals(self) -> list[Generator]: gens = everything_factory().get_generators(columns, self.engine) gens.sort(key=lambda g: g.fit(9999)) self.generators = gens - self.generators_valid_columns = (self.table_index, self.get_column_names().copy()) + self.generators_valid_columns = ( + self.table_index, + self.get_column_names().copy(), + ) return self.generators def _print_privacy(self): @@ -1316,7 +1449,7 @@ def do_compare(self, arg: str): self.print_table_by_columns(comparison) def do_c(self, arg): - """ Synonym for compare. 
""" + """Synonym for compare.""" self.do_compare(arg) def _print_values_queried(self, table_name: str, n: int, gen: Generator): @@ -1354,7 +1487,11 @@ def _print_custom_queries(self, gen: Generator) -> None: actual, ) for cq_key, cq in cqs.items(): - self.print("{0}; providing the following values: {1}", cq["query"], cq_key2args[cq_key]) + self.print( + "{0}; providing the following values: {1}", + cq["query"], + cq_key2args[cq_key], + ) def _get_custom_queries_from(self, out, nominal, actual): if type(nominal) is str: @@ -1375,7 +1512,9 @@ def _get_custom_queries_from(self, out, nominal, actual): if k in actual: self._get_custom_queries_from(out, v, actual[k]) - def _get_aggregate_query(self, gens: list[Generator], table_name: str) -> str | None: + def _get_aggregate_query( + self, gens: list[Generator], table_name: str + ) -> str | None: clauses = [ f'{q["clause"]} AS {n}' for gen in gens @@ -1394,7 +1533,7 @@ def _print_select_aggregate_query(self, table_name, gen: Generator) -> None: return kwa = gen.actual_kwargs() vals = [] - src_stat2kwarg = { v: k for k, v in gen.nominal_kwargs().items() } + src_stat2kwarg = {v: k for k, v in gen.nominal_kwargs().items()} for n in sacs.keys(): src_stat = f'SRC_STATS["auto__{table_name}"]["results"][0]["{n}"]' if src_stat in src_stat2kwarg: @@ -1402,9 +1541,16 @@ def _print_select_aggregate_query(self, table_name, gen: Generator) -> None: if ak in kwa: vals.append(kwa[ak]) else: - logger.warning("actual_kwargs for %s does not report %s", gen.name(), ak) + logger.warning( + "actual_kwargs for %s does not report %s", gen.name(), ak + ) else: - logger.warning('nominal_kwargs for %s does not have a value SRC_STATS["auto__%s"]["results"][0]["%s"]', gen.name(), table_name, n) + logger.warning( + 'nominal_kwargs for %s does not have a value SRC_STATS["auto__%s"]["results"][0]["%s"]', + gen.name(), + table_name, + n, + ) select_q = self._get_aggregate_query([gen], table_name) self.print("{0}; providing the following values: {1}", select_q, vals) @@ -1414,12 +1560,11 @@ def _get_column_data(self, count: int, to_str=repr): pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) with self.engine.connect() as connection: result = connection.execute( - text(f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}") + text( + f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}" + ) ) - return [ - [to_str(x) for x in xs] - for xs in result.all() - ] + return [[to_str(x) for x in xs] for xs in result.all()] def do_propose(self, _arg): """ @@ -1433,10 +1578,7 @@ def do_propose(self, _arg): gens = self._get_generator_proposals() sample = self._get_column_data(limit) if sample: - rep = [ - x[0] if len(x) == 1 else ",".join(x) - for x in sample - ] + rep = [x[0] if len(x) == 1 else ",".join(x) for x in sample] self.print(self.PROPOSE_SOURCE_SAMPLE_TEXT, "; ".join(rep)) else: self.print(self.PROPOSE_SOURCE_EMPTY_TEXT) @@ -1455,11 +1597,11 @@ def do_propose(self, _arg): index=index + 1, name=gen.name(), fit=fit_s, - sample="; ".join(map(repr, gen.generate_data(limit))) + sample="; ".join(map(repr, gen.generate_data(limit))), ) def do_p(self, arg): - """ Synonym for propose """ + """Synonym for propose""" self.do_propose(arg) def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: @@ -1508,7 +1650,7 @@ def set_generator(self, gen: Generator): gen_info.gen = gen def do_s(self, arg): - """ Synonym for set """ + """Synonym for set""" self.do_set(arg) def 
do_unset(self, _arg): @@ -1519,7 +1661,7 @@ def do_unset(self, _arg): self._go_next() def do_merge(self, arg: str): - """ Add this column(s) to the specified column(s), so one generator covers them all. """ + """Add this column(s) to the specified column(s), so one generator covers them all.""" cols = arg.split() if not cols: self.print("Error: merge requires a column argument") @@ -1527,10 +1669,10 @@ def do_merge(self, arg: str): if table_entry is None: self.print(self.ERROR_NO_SUCH_TABLE) return - cols_available = functools.reduce(lambda x, y: x | y, [ - frozenset(gen.columns) - for gen in table_entry.new_generators - ]) + cols_available = functools.reduce( + lambda x, y: x | y, + [frozenset(gen.columns) for gen in table_entry.new_generators], + ) cols_to_merge = frozenset(cols) unknown_cols = cols_to_merge - cols_available if unknown_cols: @@ -1585,7 +1727,7 @@ def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int): ] def do_unmerge(self, arg: str): - """ Remove this column(s) from this generator, make them a separate generator. """ + """Remove this column(s) from this generator, make them a separate generator.""" cols = arg.split() if not cols: self.print("Error: merge requires a column argument") @@ -1615,10 +1757,13 @@ def do_unmerge(self, arg: str): # The existing generator will not work gen_info.gen = None # And put them into a new (empty) generator - table_entry.new_generators.insert(self.generator_index + 1, GeneratorInfo( - columns=cols, - gen=None, - )) + table_entry.new_generators.insert( + self.generator_index + 1, + GeneratorInfo( + columns=cols, + gen=None, + ), + ) self.set_prompt() def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int): @@ -1650,7 +1795,11 @@ def update_config_generators( line_no += 1 if line: if len(line) != 3: - logger.error("line {0} of file {1} does not have three values", line_no, spec_path) + logger.error( + "line {0} of file {1} does not have three values", + line_no, + spec_path, + ) if gc.go_to(f"{line[0]}.{line[1]}"): gc.do_set(line[2]) gc.do_quit("yes") diff --git a/datafaker/main.py b/datafaker/main.py index 5b79831..c22f097 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -1,8 +1,8 @@ """Entrypoint for the datafaker package.""" import asyncio -from enum import Enum import json import sys +from enum import Enum from importlib import metadata from pathlib import Path from typing import Final, Optional @@ -15,8 +15,8 @@ from datafaker.create import create_db_data, create_db_tables, create_db_vocab from datafaker.dump import dump_db_tables from datafaker.interactive import ( - update_config_tables, update_config_generators, + update_config_tables, update_missingness, ) from datafaker.make import ( @@ -68,24 +68,24 @@ def _require_src_db_dsn(settings: Settings) -> str: return src_dsn -def load_metadata_config(orm_file_name, config: dict | None=None): +def load_metadata_config(orm_file_name, config: dict | None = None): with open(orm_file_name) as orm_fh: meta_dict = yaml.load(orm_fh, yaml.Loader) tables_dict = meta_dict.get("tables", {}) if config is not None and "tables" in config: # Remove ignored tables - for (name, table_config) in config.get("tables", {}).items(): + for name, table_config in config.get("tables", {}).items(): if get_flag(table_config, "ignore"): tables_dict.pop(name, None) return meta_dict -def load_metadata(orm_file_name, config: dict | None=None): +def load_metadata(orm_file_name, config: dict | None = None): meta_dict = load_metadata_config(orm_file_name, config) return 
dict_to_metadata(meta_dict, None) -def load_metadata_for_output(orm_file_name, config: dict | None=None): +def load_metadata_for_output(orm_file_name, config: dict | None = None): """ Load metadata excluding any foreign keys pointing to ignored tables. """ @@ -94,12 +94,9 @@ def load_metadata_for_output(orm_file_name, config: dict | None=None): @app.callback() -def main(verbose: bool = Option( - False, - "--verbose", - "-v", - help="Print more information." -)): +def main( + verbose: bool = Option(False, "--verbose", "-v", help="Print more information.") +): conf_logger(verbose) @@ -108,7 +105,7 @@ def create_data( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), df_file: str = Option( DF_FILENAME, - help="The name of the generators file. Must be in the current working directory." + help="The name of the generators file. Must be in the current working directory.", ), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), num_passes: int = Option(1, help="Number of passes (rows or stories) to make"), @@ -145,7 +142,9 @@ def create_data( num_passes, ) logger.debug( - "Data created in %s %s.", num_passes, "pass" if num_passes == 1 else "passes" + "Data created in %s %s.", + num_passes, + "pass" if num_passes == 1 else "passes", ) for table_name, row_count in row_counts.items(): logger.debug( @@ -210,9 +209,11 @@ def create_generators( "Statistics file (output of make-stats); default is src-stats.yaml if the " "config file references SRC_STATS, or None otherwise." ), - show_default=False + show_default=False, + ), + force: bool = Option( + False, "--force", "-f", help="Overwrite any existing Python generators file." ), - force: bool = Option(False, "--force", "-f", help="Overwrite any existing Python generators file."), ) -> None: """Make a datafaker file of generator classes. @@ -249,7 +250,12 @@ def create_generators( def make_vocab( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - force: bool = Option(False, "--force/--no-force", "-f/+f", help="Overwrite any existing vocabulary file."), + force: bool = Option( + False, + "--force/--no-force", + "-f/+f", + help="Overwrite any existing vocabulary file.", + ), compress: bool = Option(False, help="Compress file to .gz"), only: list[str] = Option([], help="Only download this table."), ) -> None: @@ -279,7 +285,9 @@ def make_stats( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), stats_file: str = Option(STATS_FILENAME), - force: bool = Option(False, "--force", "-f", help="Overwrite any existing vocabulary file."), + force: bool = Option( + False, "--force", "-f", help="Overwrite any existing vocabulary file." + ), ) -> None: """Compute summary statistics from the source database. 
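
The commands above all declare their flags through the same Typer idiom: a boolean option with a long and a short name. A minimal self-contained sketch of that pattern (the command name and help text are invented for illustration):

    import typer

    app = typer.Typer()

    @app.command()
    def demo(
        force: bool = typer.Option(
            False, "--force", "-f", help="Overwrite any existing output file."
        ),
    ) -> None:
        """Print whether --force was given."""
        typer.echo(f"force={force}")

    if __name__ == "__main__":
        app()
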
@@ -309,9 +317,14 @@ def make_stats( @app.command() def make_tables( - config_file: Optional[str] = Option(None, help="The configuration file, used if you want an orm.yaml lacking data for the ignored tables"), + config_file: Optional[str] = Option( + None, + help="The configuration file, used if you want an orm.yaml lacking data for the ignored tables", + ), orm_file: str = Option(ORM_FILENAME, help="Path to write the ORM yaml file to"), - force: bool = Option(False, "--force", "-f", help="Overwrite any existing orm yaml file."), + force: bool = Option( + False, "--force", "-f", help="Overwrite any existing orm yaml file." + ), ) -> None: """Make a YAML file representing the tables in the schema. @@ -335,7 +348,9 @@ def make_tables( @app.command() def configure_tables( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path to write the configuration file to"), + config_file: Optional[str] = Option( + CONFIG_FILENAME, help="Path to write the configuration file to" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ): """ @@ -347,10 +362,14 @@ def configure_tables( config_file_path = Path(config_file) config = {} if config_file_path.exists(): - config = yaml.load(config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader) + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) # we don't pass config here so that no tables are ignored metadata = load_metadata(orm_file) - config_updated = update_config_tables(src_dsn, settings.src_schema, metadata, config) + config_updated = update_config_tables( + src_dsn, settings.src_schema, metadata, config + ) if config_updated is None: logger.debug("Cancelled") return @@ -361,7 +380,9 @@ def configure_tables( @app.command() def configure_missing( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path to write the configuration file to"), + config_file: Optional[str] = Option( + CONFIG_FILENAME, help="Path to write the configuration file to" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ): """ @@ -373,7 +394,9 @@ def configure_missing( config_file_path = Path(config_file) config = {} if config_file_path.exists(): - config = yaml.load(config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader) + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) metadata = load_metadata(orm_file, config) config_updated = update_missingness(src_dsn, settings.src_schema, metadata, config) if config_updated is None: @@ -386,9 +409,14 @@ def configure_missing( @app.command() def configure_generators( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path of the configuration file to alter"), + config_file: Optional[str] = Option( + CONFIG_FILENAME, help="Path of the configuration file to alter" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - spec: Path = Option(None, help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively") + spec: Path = Option( + None, + help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively", + ), ): """ Interactively set generators for column data. 
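
Both configuration commands above read any existing config with yaml.SafeLoader, which refuses arbitrary Python object tags and so is the safe choice for a hand-edited file. A minimal sketch of the read-modify-write cycle, with a hypothetical file name and table entry:

    from pathlib import Path

    import yaml

    config_path = Path("config.yaml")
    config = {}
    if config_path.exists():
        # SafeLoader only constructs plain YAML types (dicts, lists, scalars).
        config = yaml.load(
            config_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader
        )
    config.setdefault("tables", {})["person"] = {"ignore": True}
    config_path.write_text(yaml.dump(config), encoding="UTF-8")
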
@@ -399,9 +427,13 @@ def configure_generators( config_file_path = Path(config_file) config = {} if config_file_path.exists(): - config = yaml.load(config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader) + config = yaml.load( + config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader + ) metadata = load_metadata(orm_file, config) - config_updated = update_config_generators(src_dsn, settings.src_schema, metadata, config, spec_path=spec) + config_updated = update_config_generators( + src_dsn, settings.src_schema, metadata, config, spec_path=spec + ) if config_updated is None: logger.debug("Cancelled") return @@ -412,12 +444,14 @@ def configure_generators( @app.command() def dump_data( - config_file: Optional[str] = Option(CONFIG_FILENAME, help="Path of the configuration file to alter"), + config_file: Optional[str] = Option( + CONFIG_FILENAME, help="Path of the configuration file to alter" + ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), table: str = Argument(help="The table to dump"), output: str | None = Option(None, help="output CSV file name"), ): - """ Dump a whole table as a CSV file (or to the console) from the destination database. """ + """Dump a whole table as a CSV file (or to the console) from the destination database.""" settings = get_settings() dst_dsn: str = settings.dst_dsn or "" assert dst_dsn != "", "Missing DST_DSN setting." @@ -427,7 +461,7 @@ def dump_data( if output == None: dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) return - with open(output, 'wt', newline='') as out: + with open(output, "wt", newline="") as out: dump_db_tables(metadata, dst_dsn, schema_name, table, out) @@ -452,7 +486,9 @@ def validate_config( def remove_data( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - yes: bool = Option(False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"), + yes: bool = Option( + False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" + ), ) -> None: """Truncate non-vocabulary tables in the destination schema.""" if yes: @@ -469,7 +505,9 @@ def remove_data( def remove_vocab( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - yes: bool = Option(False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"), + yes: bool = Option( + False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" + ), ) -> None: """Truncate vocabulary tables in the destination schema.""" if yes: @@ -487,7 +525,9 @@ def remove_vocab( def remove_tables( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - yes: bool = Option(False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first"), + yes: bool = Option( + False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" + ), ) -> None: """Drop all tables in the destination schema. 
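
dump_data above opens its output with newline="" because Python's csv module performs its own line-ending translation; per the csv documentation, opening the file any other way can yield blank rows on Windows. A standalone sketch (file name and rows invented):

    import csv

    rows = [("id", "name"), (1, "alice"), (2, "bob")]
    # newline="" hands line-ending control to csv.writer.
    with open("dump.csv", "wt", newline="") as out:
        writer = csv.writer(out)
        writer.writerows(rows)
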
diff --git a/datafaker/make.py b/datafaker/make.py index 1284672..17e0d3b 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -5,13 +5,11 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import ( - Any, Callable, Final, Mapping, Optional, Sequence, Tuple -) -import yaml +from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple import pandas as pd import snsql +import yaml from black import FileMode, format_str from jinja2 import Environment, FileSystemLoader, Template from mimesis.providers.base import BaseProvider @@ -26,8 +24,8 @@ from datafaker.utils import ( create_db_engine, download_table, - get_property, get_flag, + get_property, get_related_table_names, get_sync_engine, get_vocabulary_table_names, @@ -73,7 +71,8 @@ class RowGeneratorInfo: @dataclass class ColumnChoice: - """ Chooses columns based on a random number in [0,1) """ + """Chooses columns based on a random number in [0,1)""" + function_name: str argument_values: list[str] @@ -84,10 +83,7 @@ def make_column_choices( return [ ColumnChoice( function_name=mg["name"], - argument_values=[ - f"{k}={v}" - for k, v in mg.get("kwargs", {}).items() - ] + argument_values=[f"{k}={v}" for k, v in mg.get("kwargs", {}).items()], ) for mg in table_config.get("missingness_generators", []) if "name" in mg @@ -122,7 +118,9 @@ def _render_value(v) -> str: if type(v) is set: return "{" + ", ".join(_render_value(x) for x in v) + "}" if type(v) is dict: - return "{" + ", ".join(f"{repr(k)}:{_render_value(x)}" for k, x in v.items()) + "}" + return ( + "{" + ", ".join(f"{repr(k)}:{_render_value(x)}" for k, x in v.items()) + "}" + ) if type(v) is str: return v return str(v) @@ -181,9 +179,7 @@ def _get_row_generator( return row_gen_info, columns_covered -def _get_default_generator( - column: Column -) -> RowGeneratorInfo: +def _get_default_generator(column: Column) -> RowGeneratorInfo: """Get default generator information, for the given column.""" # If it's a primary key column, we presume that primary keys are populated # automatically. 
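
The _render_value helper in the hunks above emits Python source fragments rather than data: strings pass through unquoted because they are already pieces of generated code. A standalone restatement with example outputs (not an import from datafaker):

    def render_value(v) -> str:
        # Sets and dicts recurse over their elements; strings pass through
        # unquoted because they are already code; everything else uses str().
        if type(v) is set:
            return "{" + ", ".join(render_value(x) for x in v) + "}"
        if type(v) is dict:
            return (
                "{"
                + ", ".join(f"{repr(k)}:{render_value(x)}" for k, x in v.items())
                + "}"
            )
        if type(v) is str:
            return v
        return str(v)

    assert render_value(10) == "10"
    assert render_value({"start": 0}) == "{'start':0}"
    assert render_value("dst_db_conn") == "dst_db_conn"
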
@@ -215,7 +211,8 @@ def _get_default_generator( primary_key=column.primary_key, variable_names=variable_names, function_call=_get_function_call( - function_name=generator_function, positional_arguments=generator_arguments + function_name=generator_function, + positional_arguments=generator_arguments, ), ) @@ -243,10 +240,13 @@ def _numeric_generator(column: Column) -> tuple[str, dict[str, str]]: column_type = column.type if column_type.scale is None: return ("generic.numeric.float_number", {}) - return ("generic.numeric.float_number", { - "start": 0, - "end": 10 ** column_type.scale - 1, - }) + return ( + "generic.numeric.float_number", + { + "start": 0, + "end": 10**column_type.scale - 1, + }, + ) def _string_generator(column: Column) -> tuple[str, dict[str, str]]: @@ -257,7 +257,8 @@ def _string_generator(column: Column) -> tuple[str, dict[str, str]]: column_size: Optional[int] = getattr(column.type, "length", None) if column_size is None: return ("generic.text.color", {}) - return ("generic.person.password", { "length": str(column_size) }) + return ("generic.person.password", {"length": str(column_size)}) + def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: """ @@ -265,10 +266,13 @@ def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: """ if not column.primary_key: return ("generic.numeric.integer_number", {}) - return ("generic.column_value_provider.increment", { - "db_connection": "dst_db_conn", - "column": f'metadata.tables["{column.table.name}"].columns["{column.name}"]', - }) + return ( + "generic.column_value_provider.increment", + { + "db_connection": "dst_db_conn", + "column": f'metadata.tables["{column.table.name}"].columns["{column.name}"]', + }, + ) _YEAR_SUMMARY_QUERY = ( @@ -316,12 +320,12 @@ def get_result_mappings(info: GeneratorInfo, results) -> dict[str, Any]: sqltypes.Date: GeneratorInfo( generator="generic.datetime.date", summary_query=_YEAR_SUMMARY_QUERY, - arg_types={ "start": int, "end": int } + arg_types={"start": int, "end": int}, ), sqltypes.DateTime: GeneratorInfo( generator="generic.datetime.datetime", summary_query=_YEAR_SUMMARY_QUERY, - arg_types={ "start": int, "end": int } + arg_types={"start": int, "end": int}, ), sqltypes.Integer: GeneratorInfo( # must be before Numeric generator=_integer_generator, @@ -345,7 +349,7 @@ def get_result_mappings(info: GeneratorInfo, results) -> dict[str, Any]: sqltypes.String: GeneratorInfo( generator=_string_generator, choice=True, - ) + ), } @@ -358,7 +362,7 @@ def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: callable, dict of keyword arguments to pass to the callable). """ if column_t in _COLUMN_TYPE_TO_GENERATOR_INFO: - return _COLUMN_TYPE_TO_GENERATOR_INFO[column_t] + return _COLUMN_TYPE_TO_GENERATOR_INFO[column_t] # Search exhaustively for a superclass to the columns actual type for key, value in _COLUMN_TYPE_TO_GENERATOR_INFO.items(): @@ -368,8 +372,9 @@ def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: return None -def _get_generator_for_column(column_t: type) -> str | Callable[ - [type_api.TypeEngine], tuple[str, dict[str, str]]]: +def _get_generator_for_column( + column_t: type, +) -> str | Callable[[type_api.TypeEngine], tuple[str, dict[str, str]]]: """ Gets a generator from a column type. 
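
_get_info_for_column_type above tries an exact type lookup first and then scans the registry for a superclass of the column's actual type; dict order matters because the first issubclass match wins (the source comments note Integer must come before Numeric). A standalone sketch of that lookup, with placeholder strings standing in for GeneratorInfo:

    from sqlalchemy.sql import sqltypes

    registry = {
        sqltypes.Integer: "integer generator",  # registered before Numeric
        sqltypes.Numeric: "numeric generator",
    }

    def lookup(column_t: type) -> str | None:
        if column_t in registry:
            return registry[column_t]
        # Fall back to any registered superclass of the actual column type.
        for key, value in registry.items():
            if issubclass(column_t, key):
                return value
        return None

    assert lookup(sqltypes.Integer) == "integer generator"
    assert lookup(sqltypes.BigInteger) == "integer generator"  # subclass of Integer
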
@@ -392,7 +397,7 @@ def _get_generator_and_arguments(column: Column) -> tuple[str, dict[str, str]]: generator_arguments: dict[str, str] = {} if callable(generator_function): (generator_function, generator_arguments) = generator_function(column) - return generator_function,generator_arguments + return generator_function, generator_arguments def _get_provider_for_column(column: Column) -> Tuple[list[str], str, dict[str, str]]: @@ -443,6 +448,7 @@ class _PrimaryConstraint: columns in a table comprise the primary key. Not a real constraint, but enough to write df.py. """ + def __init__(self, *columns: Column, name: str): self.name = name self.columns = columns @@ -461,15 +467,11 @@ def _get_generator_for_table( ), key=_constraint_sort_key, ) - primary_keys = [ - c for c in table.columns - if c.primary_key - ] + primary_keys = [c for c in table.columns if c.primary_key] if 1 < len(primary_keys): - unique_constraints.append(_PrimaryConstraint( - *primary_keys, - name=f"{table.name}_primary_key" - )) + unique_constraints.append( + _PrimaryConstraint(*primary_keys, name=f"{table.name}_primary_key") + ) column_choices = make_column_choices(table_config) if column_choices: nonnull_columns = { @@ -522,7 +524,7 @@ def make_vocabulary_tables( config: Mapping, overwrite_files: bool, compress: bool, - table_names: set[str] | None=None, + table_names: set[str] | None = None, ): """ Extracts the data from the source database for each @@ -539,7 +541,10 @@ def make_vocabulary_tables( else: invalid_names = table_names - vocab_names if invalid_names: - logger.error("The following names are not the names of vocabulary tables: %s", invalid_names) + logger.error( + "The following names are not the names of vocabulary tables: %s", + invalid_names, + ) logger.info("Valid names are: %s", vocab_names) return for table_name in table_names: @@ -584,7 +589,7 @@ def make_table_generators( # pylint: disable=too-many-locals tables: list[TableGeneratorInfo] = [] vocabulary_tables: list[VocabularyTableGeneratorInfo] = [] vocab_names = get_vocabulary_table_names(config) - for (table_name, table) in metadata.tables.items(): + for table_name, table in metadata.tables.items(): if table_name in vocab_names: related = get_related_table_names(table) related_non_vocab = related.difference(vocab_names) @@ -593,16 +598,18 @@ def make_table_generators( # pylint: disable=too-many-locals "Making table '%s' a vocabulary table requires that also the" " related tables (%s) be also vocabulary tables.", table.name, - related_non_vocab + related_non_vocab, ) vocabulary_tables.append( _get_generator_for_existing_vocabulary_table(table) ) else: - tables.append(_get_generator_for_table( - tables_config.get(table.name, {}), - table, - )) + tables.append( + _get_generator_for_table( + tables_config.get(table.name, {}), + table, + ) + ) story_generators = _get_story_generators(config) @@ -766,9 +773,7 @@ def fix_type(value): def fix_types(dics): - return [{ - k: fix_type(v) for k, v in dic.items() - } for dic in dics] + return [{k: fix_type(v) for k, v in dic.items()} for dic in dics] async def make_src_stats( @@ -793,7 +798,10 @@ async def make_src_stats( async with DbConnection(engine) as db_conn: return await make_src_stats_connection(config, db_conn, metadata) -async def make_src_stats_connection(config: Mapping, db_conn: DbConnection, metadata: MetaData): + +async def make_src_stats_connection( + config: Mapping, db_conn: DbConnection, metadata: MetaData +): date_string = datetime.today().strftime("%Y-%m-%d %H:%M:%S") query_blocks = 
config.get("src-stats", []) results = await asyncio.gather( diff --git a/datafaker/providers.py b/datafaker/providers.py index b07f2b7..1ebd5bf 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -5,8 +5,8 @@ from mimesis import Datetime, Text from mimesis.providers.base import BaseDataProvider, BaseProvider -from sqlalchemy import Connection, Column -from sqlalchemy.sql import functions, select, func +from sqlalchemy import Column, Connection +from sqlalchemy.sql import func, functions, select class ColumnValueProvider(BaseProvider): @@ -29,12 +29,12 @@ def column_value( return getattr(random_row, column_name) return None - def __init__(self, *, seed = None, **kwargs): + def __init__(self, *, seed=None, **kwargs): super().__init__(seed=seed, **kwargs) self.accumulators: dict[str, int] = {} def increment(self, db_connection: Connection, column: Column) -> int: - """ Return incrementing value for the column specified. """ + """Return incrementing value for the column specified.""" name = f"{column.table.name}.{column.name}" result = self.accumulators.get(name, None) if result == None: diff --git a/datafaker/remove.py b/datafaker/remove.py index c0a6c47..a316619 100644 --- a/datafaker/remove.py +++ b/datafaker/remove.py @@ -1,7 +1,7 @@ """Functions and classes to undo the operations in create.py.""" from typing import Any, Mapping -from sqlalchemy import delete, MetaData +from sqlalchemy import MetaData, delete from datafaker.settings import get_settings from datafaker.utils import ( @@ -9,23 +9,18 @@ get_sync_engine, get_vocabulary_table_names, logger, - remove_vocab_foreign_key_constraints, reinstate_vocab_foreign_key_constraints, + remove_vocab_foreign_key_constraints, sorted_non_vocabulary_tables, ) -def remove_db_data( - metadata: MetaData, config: Mapping[str, Any] -) -> None: +def remove_db_data(metadata: MetaData, config: Mapping[str, Any]) -> None: """Truncate the synthetic data tables but not the vocabularies.""" settings = get_settings() assert settings.dst_dsn, "Missing destination database settings" remove_db_data_from( - metadata, - config, - settings.dst_dsn, - schema_name=settings.dst_schema + metadata, config, settings.dst_dsn, schema_name=settings.dst_schema ) @@ -33,9 +28,7 @@ def remove_db_data_from( metadata: MetaData, config: Mapping[str, Any], db_dsn: str, schema_name: str | None ) -> None: """Truncate the synthetic data tables but not the vocabularies.""" - dst_engine = get_sync_engine( - create_db_engine(db_dsn, schema_name=schema_name) - ) + dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) with dst_engine.connect() as dst_conn: for table in reversed(sorted_non_vocabulary_tables(metadata, config)): @@ -44,7 +37,9 @@ def remove_db_data_from( dst_conn.commit() -def remove_db_vocab(metadata: MetaData, meta_dict: Mapping[str, Any], config: Mapping[str, Any]) -> None: +def remove_db_vocab( + metadata: MetaData, meta_dict: Mapping[str, Any], config: Mapping[str, Any] +) -> None: """Truncate the vocabulary tables.""" settings = get_settings() assert settings.dst_dsn, "Missing destination database settings" diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 51f4b03..303c2c7 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,14 +1,16 @@ +from typing import Callable + import parsy from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table from sqlalchemy.dialects import oracle, postgresql -from sqlalchemy.sql import sqltypes, schema -from 
typing import Callable +from sqlalchemy.sql import schema, sqltypes from datafaker.utils import make_foreign_key_name table_component_t = dict[str, any] table_t = dict[str, table_component_t] + def simple(type_): """ Parses a simple sqltypes type. @@ -17,29 +19,34 @@ def simple(type_): """ return parsy.string(type_.__name__).result(type_) + def integer(): """ Parses an integer, outputting that integer. """ return parsy.regex(r"-?[0-9]+").map(int) + def integer_arguments(): """ Parses a list of integers. The integers are surrounded by brackets and separated by a comma and space. """ - return parsy.string("(") >> ( - integer().sep_by(parsy.string(", ")) - ) << parsy.string(")") + return ( + parsy.string("(") >> (integer().sep_by(parsy.string(", "))) << parsy.string(")") + ) + def numeric_type(type_): """ Parses TYPE_NAME, TYPE_NAME(2) or TYPE_NAME(2,3) passing any arguments to the TYPE_NAME constructor. """ - return parsy.string(type_.__name__ - ) >> integer_arguments().optional([]).combine(type_) + return parsy.string(type_.__name__) >> integer_arguments().optional([]).combine( + type_ + ) + def string_type(type_): @parsy.generate(type_.__name__) @@ -56,8 +63,10 @@ def st_parser(): parsy.string(' COLLATE "') >> parsy.regex(r'[^"]*') << parsy.string('"') ).optional() return type_(length=length, collation=collation) + return st_parser + def time_type(type_, pg_type): @parsy.generate(type_.__name__) def pgt_parser(): @@ -70,18 +79,22 @@ def pgt_parser(): parsy.string("(") >> integer() << parsy.string(")") ).optional() timezone: str | None = yield ( - parsy.string(" WITH") >> ( - parsy.string(" ").result(True) | parsy.string("OUT ").result(False) - ) << parsy.string("TIME ZONE") + parsy.string(" WITH") + >> (parsy.string(" ").result(True) | parsy.string("OUT ").result(False)) + << parsy.string("TIME ZONE") ).optional(False) if precision is None and not timezone: # normal sql type return type_ return pg_type(precision=precision, timezone=timezone) + return pgt_parser + SIMPLE_TYPE_PARSER = parsy.alt( - parsy.string("DOUBLE PRECISION").result(sqltypes.DOUBLE_PRECISION), # must be before DOUBLE + parsy.string("DOUBLE PRECISION").result( + sqltypes.DOUBLE_PRECISION + ), # must be before DOUBLE simple(sqltypes.FLOAT), simple(sqltypes.DOUBLE), simple(sqltypes.INTEGER), @@ -110,6 +123,7 @@ def pgt_parser(): time_type(sqltypes.TIME, postgresql.types.TIME), ) + @parsy.generate def type_parser(): base = yield SIMPLE_TYPE_PARSER @@ -118,6 +132,7 @@ def type_parser(): return base return postgresql.ARRAY(base, dimensions=dimensions) + def column_to_dict(column: Column, dialect: Dialect) -> str: type_ = column.type if isinstance(type_, postgresql.DOMAIN): @@ -139,6 +154,7 @@ def column_to_dict(column: Column, dialect: Dialect) -> str: result["foreign_keys"] = foreign_keys return result + def dict_to_column( table_name, col_name, @@ -156,7 +172,7 @@ def dict_to_column( ForeignKey( fk, name=make_foreign_key_name(table_name, col_name), - ondelete='CASCADE', + ondelete="CASCADE", ) for fk in rep["foreign_keys"] if not ignore_fk(fk) @@ -171,21 +187,18 @@ def dict_to_column( nullable=rep.get("nullable", None), ) + def dict_to_unique(rep: dict) -> schema.UniqueConstraint: - return schema.UniqueConstraint( - *rep.get("columns", []), - name=rep.get("name", None) - ) + return schema.UniqueConstraint(*rep.get("columns", []), name=rep.get("name", None)) + def unique_to_dict(constraint: schema.UniqueConstraint) -> dict: return { "name": constraint.name, - "columns": [ - str(col.name) - for col in constraint.columns - ] + 
"columns": [str(col.name) for col in constraint.columns], } + def table_to_dict(table: Table, dialect: Dialect) -> table_t: """ Converts a SQL Alchemy Table object into a @@ -203,6 +216,7 @@ def table_to_dict(table: Table, dialect: Dialect) -> table_t: ], } + def dict_to_table( name: str, meta: MetaData, @@ -212,15 +226,17 @@ def dict_to_table( return Table( name, meta, - *[ dict_to_column(name, colname, col, ignore_fk) + *[ + dict_to_column(name, colname, col, ignore_fk) for (colname, col) in table_dict.get("columns", {}).items() ], - *[ dict_to_unique(constraint) - for constraint in table_dict.get("unique", []) - ], + *[dict_to_unique(constraint) for constraint in table_dict.get("unique", [])], ) -def metadata_to_dict(meta: MetaData, schema_name: str | None, engine: Engine) -> dict[str, table_t]: + +def metadata_to_dict( + meta: MetaData, schema_name: str | None, engine: Engine +) -> dict[str, table_t]: """ Converts a SQL Alchemy MetaData object into a Python object ready for conversion to YAML. @@ -248,10 +264,7 @@ def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]): return tables_dict[fk_bits[0]].get("ignore", False) -def dict_to_metadata( - obj: dict, - config_for_output: dict=None -) -> MetaData: +def dict_to_metadata(obj: dict, config_for_output: dict = None) -> MetaData: """ Converts a dict to a SQL Alchemy MetaData object. @@ -268,6 +281,6 @@ def dict_to_metadata( else: ignore_fk = lambda _: False meta = MetaData() - for (k, td) in tables_dict.items(): + for k, td in tables_dict.items(): dict_to_table(k, meta, td, ignore_fk) return meta diff --git a/datafaker/utils.py b/datafaker/utils.py index 33b8c84..883e096 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -1,27 +1,20 @@ """Utility functions.""" import ast +import gzip +import importlib.util import json import logging import sys -import importlib.util from pathlib import Path from types import ModuleType from typing import Any, Final, Mapping, Optional, Union -import gzip +import sqlalchemy import yaml from jsonschema.exceptions import ValidationError from jsonschema.validators import validate from psycopg2.errors import UndefinedObject -import sqlalchemy -from sqlalchemy import ( - Connection, - Engine, - ForeignKey, - create_engine, - event, - select, -) +from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select from sqlalchemy.engine.interfaces import DBAPIConnection from sqlalchemy.exc import IntegrityError, ProgrammingError from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine @@ -96,10 +89,15 @@ def open_compressed_file(file_name): def table_row_count(table: Table, conn: Connection) -> int: return conn.execute( - select(sqlalchemy.func.count()).select_from(sqlalchemy.table( - table.name, - *[sqlalchemy.column(col.name) for col in table.primary_key.columns.values()], - )) + select(sqlalchemy.func.count()).select_from( + sqlalchemy.table( + table.name, + *[ + sqlalchemy.column(col.name) + for col in table.primary_key.columns.values() + ], + ) + ) ).scalar_one() @@ -117,10 +115,7 @@ def download_table( rowcount = table_row_count(table, conn) count = 0 for row in conn.execute(stmt).mappings(): - result = { - str(col_name): value - for (col_name, value) in row.items() - } + result = {str(col_name): value for (col_name, value) in row.items()} yamlfile.write(yaml.dump([result]).encode()) count += 1 if count % MAKE_VOCAB_PROGRESS_REPORT_EVERY == 0: @@ -128,7 +123,7 @@ def download_table( "written row %d of %d, %.1f%%", count, rowcount, - 100*count/rowcount, + 100 * 
count / rowcount, ) @@ -213,6 +208,7 @@ class StdoutHandler(logging.Handler): A handler that writes to stdout. We aren't using StreamHandler because that confuses typer.testing.CliRunner """ + def flush(self): self.acquire() try: @@ -236,6 +232,7 @@ class StderrHandler(logging.Handler): A handler that writes to stderr. We aren't using StreamHandler because that confuses typer.testing.CliRunner """ + def flush(self): self.acquire() try: @@ -276,8 +273,8 @@ def conf_logger(verbose: bool) -> None: handlers=[stdout_handler, stderr_handler], force=True, ) - logging.getLogger('asyncio').setLevel(logging.WARNING) - logging.getLogger('blib2to3.pgen2.driver').setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.WARNING) def get_flag(maybe_dict, key): @@ -370,7 +367,9 @@ def remove_vocab_foreign_key_constraints(metadata, config, dst_engine): for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] for fk in vocab_table.foreign_key_constraints: - logger.debug("Dropping constraint %s from table %s", fk.name, vocab_table_name) + logger.debug( + "Dropping constraint %s from table %s", fk.name, vocab_table_name + ) with Session(dst_engine) as session: session.begin() try: @@ -378,7 +377,11 @@ def remove_vocab_foreign_key_constraints(metadata, config, dst_engine): session.commit() except IntegrityError: session.rollback() - logger.exception("Dropping table %s key constraint %s failed:", vocab_table_name, fk.name) + logger.exception( + "Dropping table %s key constraint %s failed:", + vocab_table_name, + fk.name, + ) except ProgrammingError as e: session.rollback() if type(e.orig) is UndefinedObject: @@ -392,7 +395,9 @@ def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_eng for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] try: - for (column_name, column_dict) in meta_dict["tables"][vocab_table_name]["columns"].items(): + for column_name, column_dict in meta_dict["tables"][vocab_table_name][ + "columns" + ].items(): fk_targets = column_dict.get("foreign_keys", []) if fk_targets: fk = ForeignKeyConstraint( @@ -407,7 +412,9 @@ def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_eng session.execute(AddConstraint(fk)) session.commit() except IntegrityError: - logger.exception("Restoring table %s foreign keys failed:", vocab_table_name) + logger.exception( + "Restoring table %s foreign keys failed:", vocab_table_name + ) def stream_yaml(yaml_file_handle): @@ -437,7 +444,7 @@ def topological_sort(input_nodes, get_dependencies_fn): Topoligically sort input_nodes and find any cycles. Returns a pair (sorted, cycles). - + 'sorted' is a list of all the elements of input_nodes sorted so that dependencies returned by get_dependencies_fn come after nodes that depend on them. 
Cycles are @@ -478,23 +485,21 @@ def topological_sort(input_nodes, get_dependencies_fn): elif n in grey: # n is in a cycle cycle_start = grey.index(n) - cycles.append(grey[cycle_start:len(grey)]) + cycles.append(grey[cycle_start : len(grey)]) return (black, cycles) def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Table]: - table_names = set( - metadata.tables.keys() - ).difference( + table_names = set(metadata.tables.keys()).difference( get_vocabulary_table_names(config) ) (sorted, cycles) = topological_sort( - table_names, - lambda tn: get_related_table_names(metadata.tables[tn]) + table_names, lambda tn: get_related_table_names(metadata.tables[tn]) ) for cycle in cycles: logger.warning(f"Cycle detected between tables: {cycle}") - return [ metadata.tables[tn] for tn in sorted ] + return [metadata.tables[tn] for tn in sorted] + def generators_require_stats(config: Mapping) -> bool: """ @@ -527,14 +532,16 @@ def generators_require_stats(config: Mapping) -> bool: if any(name == "SRC_STATS" for name in names): stats_required = True except SyntaxError as e: - errors.append(( - "Syntax error in argument %d of %s: %s\n%s\n%s", - n + 1, - where, - e.msg, - arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), - )) + errors.append( + ( + "Syntax error in argument %d of %s: %s\n%s\n%s", + n + 1, + where, + e.msg, + arg, + " " * e.offset + "^" * max(1, e.end_offset - e.offset), + ) + ) for k, arg in call.get("kwargs", {}).items(): if type(arg) is str: try: @@ -546,14 +553,16 @@ def generators_require_stats(config: Mapping) -> bool: if any(name == "SRC_STATS" for name in names): stats_required = True except SyntaxError as e: - errors.append(( - "Syntax error in argument %s of %s: %s\n%s\n%s", - k, - where, - e.msg, - arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), - )) + errors.append( + ( + "Syntax error in argument %s of %s: %s\n%s\n%s", + k, + where, + e.msg, + arg, + " " * e.offset + "^" * max(1, e.end_offset - e.offset), + ) + ) for error in errors: logger.error(*error) return stats_required diff --git a/docs/source/_static/config_schema.html b/docs/source/_static/config_schema.html index ca0baa0..e78949f 100644 --- a/docs/source/_static/config_schema.html +++ b/docs/source/_static/config_schema.html @@ -1 +1,2759 @@ - Datafaker Config
[rest of the old rendering omitted: one minified line of generated HTML for the datafaker config schema reference]
\ No newline at end of file
+[regenerated rendering omitted: 2,759 lines of formatted, generated HTML for the same schema reference. Beyond the reformatting, the regenerated schema now also documents: an object-instantiation section ("Objects that need to be instantiated from the row and story generators modules"); a per-query "comments" array ("Comments to be copied into the src-stats.yaml file describing the query results"); a per-table boolean ("Whether the table is a Primary Private table (perhaps a table of patients)"); and per-table missingness generators, each naming a function (e.g. column_presence.sampled), the keyword arguments to pass to it, and the column names that might be returned.]
\ No newline at end of file
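A minimal sketch of a configuration of the shape this schema documents, written as the Python config dict that the tests later in this patch construct. Every key below appears in the schema summary above and in the tests; the table names ("opinion", "lookup") and column name ("score") are illustrative assumptions, not taken from the patch:

.. code-block:: python

    # Sketch only: hypothetical table/column names, real config keys.
    config = {
        "tables": {
            "opinion": {
                "primary_private": True,
                "row_generators": [
                    {
                        "name": "dist_gen.normal",
                        "columns_assigned": ["score"],
                        "kwargs": {
                            "mean": 'SRC_STATS["auto__opinion"]["results"][0]["mean__score"]',
                            "sd": 'SRC_STATS["auto__opinion"]["results"][0]["stddev__score"]',
                        },
                    }
                ],
                "missingness_generators": [
                    {
                        "name": "column_presence.sampled",
                        "columns": ["score"],
                        "kwargs": {"patterns": 'SRC_STATS["missing_auto__opinion__0"]'},
                    }
                ],
            },
            "lookup": {"vocabulary_table": True},
        },
        "src-stats": [
            {
                "name": "auto__opinion",
                "comments": ["Mean and standard deviation of opinion.score."],
                "query": "SELECT AVG(score) AS mean__score, STDDEV(score) AS stddev__score FROM opinion",
            }
        ],
    }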
diff --git a/docs/source/custom_generators.rst b/docs/source/custom_generators.rst
index c29e340..73d04b6 100644
--- a/docs/source/custom_generators.rst
+++ b/docs/source/custom_generators.rst
@@ -60,4 +60,4 @@ Again, you must define your own; ``datafaker`` provides no built-in story genera
 
 You can put your story generators in their own Python file, or you can re-use your row generators file if you like.
 
-A story generator is a Python Generator function (a function that calls ``yield`` to return multiple values rather than ``return`` a single one).
\ No newline at end of file
+A story generator is a Python Generator function (a function that calls ``yield`` to return multiple values rather than ``return`` a single one).
diff --git a/docs/source/introduction.rst b/docs/source/introduction.rst
index c588c0b..8cf833a 100644
--- a/docs/source/introduction.rst
+++ b/docs/source/introduction.rst
@@ -124,7 +124,7 @@ Some of these functions take arguments, that we can assign like this:
 Anyway, we now need to remake the generators (``create-generators``) and re-run them (``create-data``):
 
 .. code-block:: console
-    
+
     $ datafaker create-generators --force
     $ datafaker create-data --num-passes 15
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index 48e52a0..43722aa 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -110,7 +110,7 @@ This command will start an interactive command shell. Don't be intimidated, just
     columns  data   help    next      peek      private   select  vocabulary
     counts   empty  ignore  generate  previous  quit      tables
-    (table: myfirsttable) 
+    (table: myfirsttable)
 
 You can also get help for any of the commands listed; for example to see help for the ``vocabulary`` command type ``? vocabulary`` or ``help vocabulary``:
 
@@ -134,7 +134,7 @@ Press the Tab key again to see these options:
 .. code-block:: console
 
     (table: actor) help p
-    peek      previous  private 
+    peek      previous  private
     (table: actor) help p
 
 Now you can continue with r-i-tab to get ``private``, r-e-tab to get ``previous`` or e-tab to get ``peek``. This can be very useful; try pressing Tab twice on an empty line to see quickly all the possible commands, for example!
@@ -372,7 +372,7 @@ To describe "null-partitioned grouped", let us make the generator much more comp
     | None     | None          | None          | Pencil on tracing paper                                                                          |
     | None     | 18.5          | 24.3          | Lithograph from an illustrated book of poems and four lithographs                                |
     +----------+---------------+---------------+--------------------------------------------------------------------------------------------------+
-    (artwork.depth_cm,width_cm,height_cm,medium) 
+    (artwork.depth_cm,width_cm,height_cm,medium)
 
 Here we can see that Moma understandably does not record depths for 2D artworks so we have many NULLs in that column. If we try to apply the standard normal or lognormal to data with many NULLs, it will ignore those rows with any NULLs.
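The custom_generators.rst hunk above describes a story generator as a Python generator function built around ``yield``. A minimal sketch, assuming the (table name, row dict) yield convention used by the example story_generators.py exercised in the functional tests; the function name and values are illustrative, not the repository's actual generators:

.. code-block:: python

    def example_story():
        # Each yield hands datafaker a (table_name, row_dict) pair to insert.
        # Assumption: datafaker sends the completed row back into the
        # generator, so later rows can reference generated values.
        player = yield "player", {"given_name": "Alex", "family_name": "Example"}
        yield "signature_model", {"player_id": player["id"], "based_on": None}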
diff --git a/tests/test_create.py b/tests/test_create.py index 0fe1bf3..333c01a 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,9 +1,9 @@ """Tests for the create module.""" import itertools as itt -from collections import Counter import os -from pathlib import Path import random +from collections import Counter +from pathlib import Path from typing import Any, Generator, Tuple from unittest.mock import MagicMock, call, patch @@ -11,23 +11,26 @@ from sqlalchemy.schema import Table from datafaker.base import TableGenerator -from datafaker.create import ( - create_db_vocab, - populate, -) +from datafaker.create import create_db_vocab, populate from datafaker.remove import remove_db_vocab from datafaker.serialize_metadata import metadata_to_dict from tests.utils import DatafakerTestCase, GeneratesDBTestCase + class TestCreate(GeneratesDBTestCase): """Test the make_table_generators function.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" def test_create_vocab(self) -> None: """Test the create_db_vocab function.""" - with patch.dict(os.environ, {"DST_DSN": self.dsn, "DST_SCHEMA": self.schema_name}, clear=True): + with patch.dict( + os.environ, + {"DST_DSN": self.dsn, "DST_SCHEMA": self.schema_name}, + clear=True, + ): config = { "tables": { "player": { @@ -83,7 +86,7 @@ def test_make_table_generators(self) -> None: class TestPopulate(DatafakerTestCase): - """ Test create.populate. """ + """Test create.populate.""" def test_populate(self) -> None: """Test the populate function.""" diff --git a/tests/test_dump.py b/tests/test_dump.py index 4293f28..2d5ed26 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -1,27 +1,32 @@ """Tests for the base module.""" -from sqlalchemy.schema import MetaData -from tests.utils import RequiresDBTestCase from unittest.mock import MagicMock, call, patch +from sqlalchemy.schema import MetaData + from datafaker.dump import dump_db_tables +from tests.utils import RequiresDBTestCase + class DumpTests(RequiresDBTestCase): """Testing configure-tables.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @patch("datafaker.dump._make_csv_writer") def test_dump_data(self, make_csv_writer: MagicMock) -> None: - """ Test dump-data. 
""" + """Test dump-data.""" TEST_OUTPUT_FILE = "test_output_file_object" metadata = MetaData() metadata.reflect(self.engine) dump_db_tables(metadata, self.dsn, self.schema_name, "player", TEST_OUTPUT_FILE) make_csv_writer.assert_called_once_with(TEST_OUTPUT_FILE) - make_csv_writer.assert_has_calls([ - call().writerow(["id", "given_name", "family_name"]), - call().writerow((1, 'Mark', 'Samson')), - call().writerow((2, 'Tim', 'Friedman')), - call().writerow((3, 'Pierre', 'Marchmont')), - ]) + make_csv_writer.assert_has_calls( + [ + call().writerow(["id", "given_name", "family_name"]), + call().writerow((1, "Mark", "Samson")), + call().writerow((2, "Tim", "Friedman")), + call().writerow((3, "Pierre", "Marchmont")), + ] + ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 00f4547..418b4f9 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -6,14 +6,15 @@ from sqlalchemy import create_engine, inspect from typer.testing import CliRunner -from tests.utils import RequiresDBTestCase - from datafaker.main import app +from tests.utils import RequiresDBTestCase # pylint: disable=subprocess-run-check + class DBFunctionalTestCase(RequiresDBTestCase): """End-to-end tests that require a database.""" + dump_file_path = "src.dump" database_name = "src" schema_name = "public" @@ -33,7 +34,7 @@ class DBFunctionalTestCase(RequiresDBTestCase): generator_file_paths = tuple( map(Path, ("story_generators.py", "row_generators.py")), ) - #dump_file_path = Path("dst.dump") + # dump_file_path = Path("dst.dump") config_file_path = Path("example_config2.yaml") stats_file_path = Path("example_stats.yaml") @@ -430,7 +431,7 @@ def test_workflow_maximal_args(self) -> None: completed_process.stdout, ) - def invoke(self, *args, expected_error: str=None, env={}): + def invoke(self, *args, expected_error: str = None, env={}): res = self.runner.invoke(app, args, env=env) if expected_error is None: self.assertNoException(res) @@ -513,12 +514,15 @@ def test_unique_constraint_fail(self) -> None: ) self.assertEqual("", completed_process.stderr) self.assertEqual( - ("Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.short_story'\n" - "Generating data for story 'story_generators.full_row_story'\n" - "Generating data for story 'story_generators.long_story'\n" - "Generating data for story 'story_generators.long_story'\n") * 3, + ( + "Generating data for story 'story_generators.short_story'\n" + "Generating data for story 'story_generators.short_story'\n" + "Generating data for story 'story_generators.short_story'\n" + "Generating data for story 'story_generators.full_row_story'\n" + "Generating data for story 'story_generators.long_story'\n" + "Generating data for story 'story_generators.long_story'\n" + ) + * 3, completed_process.stdout, ) @@ -529,7 +533,7 @@ def test_unique_constraint_fail(self) -> None: f"--orm-file={self.alt_orm_file_path}", f"--df-file={self.alt_datafaker_file_path}", "--num-passes=1", - expected_error = ( + expected_error=( "Failed to satisfy unique constraints for table unique_constraint_test" ), ) @@ -538,7 +542,7 @@ def test_unique_constraint_fail(self) -> None: def test_create_schema(self) -> None: """Check that we create a destination schema if it doesn't exist.""" - env = { "dst_schema": "doesntexistyetschema" } + env = {"dst_schema": "doesntexistyetschema"} engine = create_engine(self.env["dst_dsn"]) inspector = inspect(engine) diff --git 
a/tests/test_interactive.py b/tests/test_interactive.py index 386aeaa..9487284 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -1,58 +1,66 @@ """ Tests for the base module. """ import copy -from dataclasses import dataclass import random import re +from dataclasses import dataclass +from unittest.mock import MagicMock, Mock, patch + from sqlalchemy import insert, select +from datafaker.generators import NullPartitionedNormalGeneratorFactory from datafaker.interactive import ( DbCmd, - TableCmd, GeneratorCmd, MissingnessCmd, + TableCmd, update_config_generators, ) -from datafaker.generators import NullPartitionedNormalGeneratorFactory - -from tests.utils import RequiresDBTestCase, GeneratesDBTestCase -from unittest.mock import MagicMock, Mock, patch +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase class TestDbCmdMixin(DbCmd): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.reset() + def reset(self): self.messages: list[tuple[str, list, dict[str, any]]] = [] self.headings: list[str] = [] self.rows: list[list[str]] = [] self.column_items: list[str] = [] self.columns: dict[str, list[str]] = {} + def print(self, text: str, *args, **kwargs): self.messages.append((text, args, kwargs)) + def print_table(self, headings: list[str], rows: list[list[str]]): self.headings = headings self.rows = rows + def print_table_by_columns(self, columns: dict[str, list[str]]): self.columns = columns + def columnize(self, items: list[str]): self.column_items.append(items) + def ask_save(self) -> str: return "yes" class TestTableCmd(TableCmd, TestDbCmdMixin): - """ TableCmd but mocked """ + """TableCmd but mocked""" class ConfigureTablesTests(RequiresDBTestCase): """Testing configure-tables.""" + def _get_cmd(self, config) -> TestTableCmd: return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) class ConfigureTablesSrcTests(ConfigureTablesTests): """Testing configure-tables with src.dump.""" + dump_file_path = "src.dump" database_name = "src" schema_name = "public" @@ -70,11 +78,15 @@ def test_table_name_prompts(self) -> None: for t in reversed(table_names): self.assertIn(t, tc.prompt) tc.do_previous("") - self.assertListEqual(tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})]) + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})] + ) tc.reset() bad_table_name = "notarealtable" tc.do_next(bad_table_name) - self.assertListEqual(tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})]) + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})] + ) tc.reset() good_table_name = table_names[2] tc.do_next(good_table_name) @@ -107,9 +119,13 @@ def test_null_configuration(self) -> None: tc.do_private("") tc.do_quit("") tables = tc.config["tables"] - self.assertFalse(tables["unique_constraint_test"].get("vocabulary_table", False)) + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue(tables["unique_constraint_test"].get("primary_private", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) def test_null_table_configuration(self) -> None: """A table still works if its configuration is None.""" @@ -123,9 +139,13 @@ def test_null_table_configuration(self) -> None: tc.do_private("") tc.do_quit("") tables = tc.config["tables"] - 
self.assertFalse(tables["unique_constraint_test"].get("vocabulary_table", False)) + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue(tables["unique_constraint_test"].get("primary_private", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) def test_configure_tables(self) -> None: """Test that we can change columns to ignore, vocab or generate.""" @@ -142,7 +162,7 @@ def test_configure_tables(self) -> None: }, "empty_vocabulary": { "private": True, - } + }, }, } with self._get_cmd(config) as tc: @@ -159,9 +179,13 @@ def test_configure_tables(self) -> None: tc.do_empty("") tc.do_quit("") tables = tc.config["tables"] - self.assertFalse(tables["unique_constraint_test"].get("vocabulary_table", False)) + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertFalse(tables["unique_constraint_test"].get("primary_private", False)) + self.assertFalse( + tables["unique_constraint_test"].get("primary_private", False) + ) self.assertEqual(tables["unique_constraint_test"].get("num_passes", 1), 1) self.assertFalse(tables["no_pk_test"].get("vocabulary_table", False)) self.assertTrue(tables["no_pk_test"].get("ignore", False)) @@ -189,10 +213,7 @@ def test_print_data(self) -> None: person_table = self.metadata.tables["person"] with self.engine.connect() as conn: person_rows = conn.execute(select(person_table)).mappings().fetchall() - person_data = { - row["person_id"]: row - for row in person_rows - } + person_data = {row["person_id"]: row for row in person_rows} name_set = {row["name"] for row in person_rows} person_headings = ["person_id", "name", "research_opt_out", "stored_from"] with self._get_cmd({}) as tc: @@ -224,11 +245,15 @@ def test_print_data(self) -> None: tc.reset() tc.do_data(f"{to_get_count} name 13") self.assertEqual(len(tc.column_items), 1) - self.assertEqual(set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set))) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set)) + ) tc.reset() tc.do_data(f"{to_get_count} name 16") self.assertEqual(len(tc.column_items), 1) - self.assertEqual(set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set))) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set)) + ) def test_list_tables(self): """Test that we can list the tables""" @@ -252,7 +277,7 @@ def test_list_tables(self): person_listed = False unique_constraint_test_listed = False no_pk_test_listed = False - for (text, args, kwargs) in tc.messages: + for text, args, kwargs in tc.messages: if args[2] == "person": self.assertFalse(person_listed) person_listed = True @@ -277,7 +302,8 @@ def test_list_tables(self): class ConfigureTablesInstrumentsTests(ConfigureTablesTests): - """ Testing configure-tables with the instrument.sql database. 
""" + """Testing configure-tables with the instrument.sql database.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -300,10 +326,28 @@ def test_sanity_checks_both(self): tc.reset() tc.do_quit("") self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_NO_CHANGES, (), {})) - self.assertEqual(tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {})) - self.assertEqual(tc.messages[2], (TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, ("model", "manufacturer"), {})) - self.assertEqual(tc.messages[3], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {})) - self.assertEqual(tc.messages[4], (TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, ("signature_model", "player"), {})) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) + self.assertEqual( + tc.messages[3], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[4], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) def test_sanity_checks_warnings_only(self): config = { @@ -324,9 +368,25 @@ def test_sanity_checks_warnings_only(self): tc.do_vocabulary("") tc.reset() tc.do_quit("") - self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_CHANGING, ("manufacturer", "ignore", "vocabulary"), {})) - self.assertEqual(tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {})) - self.assertEqual(tc.messages[2], (TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, ("signature_model", "player"), {})) + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("manufacturer", "ignore", "vocabulary"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) def test_sanity_checks_errors_only(self): config = { @@ -347,16 +407,34 @@ def test_sanity_checks_errors_only(self): tc.do_empty("") tc.reset() tc.do_quit("") - self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_CHANGING, ("signature_model", "generate", "empty"), {})) - self.assertEqual(tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {})) - self.assertEqual(tc.messages[2], (TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, ("model", "manufacturer"), {})) + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("signature_model", "generate", "empty"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) class TestGeneratorCmd(GeneratorCmd, TestDbCmdMixin): - """ GeneratorCmd but mocked """ + """GeneratorCmd but mocked""" + def get_proposals(self) -> dict[str, tuple[int, str, str, list[str]]]: """ - Returns a dict of generator name to a tuple of (index, fit_string, [list,of,samples])""" + Returns a dict of generator name to a tuple of (index, fit_string, [list,of,samples]) + """ return { kw["name"]: (kw["index"], kw["fit"], kw["sample"].split("; ")) for (s, _, kw) in self.messages @@ -365,7 +443,8 @@ def get_proposals(self) -> dict[str, tuple[int, str, str, list[str]]]: class ConfigureGeneratorsTests(RequiresDBTestCase): - """ Testing configure-generators. 
""" + """Testing configure-generators.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -374,7 +453,7 @@ def _get_cmd(self, config) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) def test_null_configuration(self): - """ Test that the tables having null configuration does not break. """ + """Test that the tables having null configuration does not break.""" config = { "tables": None, } @@ -388,7 +467,7 @@ def test_null_configuration(self): self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) def test_null_table_configuration(self): - """ Test that a table having null configuration does not break. """ + """Test that a table having null configuration does not break.""" config = { "tables": { "model": None, @@ -415,10 +494,14 @@ def test_prompts(self) -> None: else: self.assertNotIn("[pk]", gc.prompt) gc.do_next("") - self.assertListEqual(gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})]) + self.assertListEqual( + gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})] + ) gc.reset() for table_name, table_meta in reversed(list(self.metadata.tables.items())): - for column_name, column_meta in reversed(list(table_meta.columns.items())): + for column_name, column_meta in reversed( + list(table_meta.columns.items()) + ): self.assertIn(table_name, gc.prompt) self.assertIn(column_name, gc.prompt) if column_meta.primary_key: @@ -426,19 +509,20 @@ def test_prompts(self) -> None: else: self.assertNotIn("[pk]", gc.prompt) gc.do_previous("") - self.assertListEqual(gc.messages, [(GeneratorCmd.ERROR_ALREADY_AT_START, (), {})]) + self.assertListEqual( + gc.messages, [(GeneratorCmd.ERROR_ALREADY_AT_START, (), {})] + ) gc.reset() bad_table_name = "notarealtable" gc.do_next(bad_table_name) - self.assertListEqual(gc.messages, [( - GeneratorCmd.ERROR_NO_SUCH_TABLE_OR_COLUMN, - (bad_table_name,), - {} - )]) + self.assertListEqual( + gc.messages, + [(GeneratorCmd.ERROR_NO_SUCH_TABLE_OR_COLUMN, (bad_table_name,), {})], + ) gc.reset() def test_set_generator_mimesis(self): - """ Test that we can set one generator to a mimesis generator. """ + """Test that we can set one generator to a mimesis generator.""" with self._get_cmd({}) as gc: TABLE = "model" COLUMN = "name" @@ -455,7 +539,7 @@ def test_set_generator_mimesis(self): ) def test_set_generator_distribution(self): - """ Test that we can set one generator to gaussian. 
""" + """Test that we can set one generator to gaussian.""" with self._get_cmd({}) as gc: TABLE = "string" COLUMN = "frequency" @@ -470,12 +554,17 @@ def test_set_generator_distribution(self): row_gen = row_gens[0] self.assertEqual(row_gen["name"], GENERATOR) self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) - self.assertDictEqual(row_gen["kwargs"], { - "mean": f'SRC_STATS["auto__{TABLE}"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__{TABLE}"]["results"][0]["stddev__{COLUMN}"]', - }) + self.assertDictEqual( + row_gen["kwargs"], + { + "mean": f'SRC_STATS["auto__{TABLE}"]["results"][0]["mean__{COLUMN}"]', + "sd": f'SRC_STATS["auto__{TABLE}"]["results"][0]["stddev__{COLUMN}"]', + }, + ) self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") self.assertEqual( gc.config["src-stats"][0]["query"], @@ -483,7 +572,7 @@ def test_set_generator_distribution(self): ) def test_set_generator_distribution_directly(self): - """ Test that we can set one generator to gaussian without going through propose. """ + """Test that we can set one generator to gaussian without going through propose.""" with self._get_cmd({}) as gc: TABLE = "string" COLUMN = "frequency" @@ -494,7 +583,9 @@ def test_set_generator_distribution_directly(self): self.assertListEqual(gc.messages, []) gc.do_quit("") self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") self.assertEqual( gc.config["src-stats"][0]["query"], @@ -502,7 +593,7 @@ def test_set_generator_distribution_directly(self): ) def test_set_generator_choice(self): - """ Test that we can set one generator to uniform choice. """ + """Test that we can set one generator to uniform choice.""" with self._get_cmd({}) as gc: TABLE = "string" COLUMN = "frequency" @@ -517,19 +608,26 @@ def test_set_generator_choice(self): row_gen = row_gens[0] self.assertEqual(row_gen["name"], GENERATOR) self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) - self.assertDictEqual(row_gen["kwargs"], { - "a": f'SRC_STATS["auto__{TABLE}__{COLUMN}"]["results"]', - }) + self.assertDictEqual( + row_gen["kwargs"], + { + "a": f'SRC_STATS["auto__{TABLE}__{COLUMN}"]["results"]', + }, + ) self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}__{COLUMN}") + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) + self.assertEqual( + gc.config["src-stats"][0]["name"], f"auto__{TABLE}__{COLUMN}" + ) self.assertEqual( gc.config["src-stats"][0]["query"], f"SELECT {COLUMN} AS value FROM {TABLE} WHERE {COLUMN} IS NOT NULL GROUP BY value ORDER BY COUNT({COLUMN}) DESC", ) def test_weighted_choice_generator_generates_choices(self): - """ Test that propose and compare show weighted_choice's values. 
""" + """Test that propose and compare show weighted_choice's values.""" with self._get_cmd({}) as gc: TABLE = "string" COLUMN = "position" @@ -546,7 +644,7 @@ def test_weighted_choice_generator_generates_choices(self): self.assertSubset(set(gc.columns[col_heading]), VALUES) def test_merge_columns(self): - """ Test that we can merge columns and set a multivariate generator """ + """Test that we can merge columns and set a multivariate generator""" TABLE = "string" COLUMN_1 = "frequency" COLUMN_2 = "position" @@ -586,7 +684,7 @@ def test_merge_columns(self): self.assertListEqual(row_gen["columns_assigned"], [COLUMN_1, COLUMN_2]) def test_unmerge_columns(self): - """ Test that we can unmerge columns and generators are removed """ + """Test that we can unmerge columns and generators are removed""" TABLE = "string" COLUMN_1 = "frequency" COLUMN_2 = "position" @@ -597,7 +695,7 @@ def test_unmerge_columns(self): TABLE: { "row_generators": [ {"name": "gen1", "columns_assigned": [COLUMN_1, COLUMN_2]}, - { "name": REMAINING_GEN, "columns_assigned": [COLUMN_3] }, + {"name": REMAINING_GEN, "columns_assigned": [COLUMN_3]}, ] } } @@ -625,24 +723,28 @@ def test_unmerge_columns(self): self.assertListEqual(row_gen["columns_assigned"], [COLUMN_3]) def test_old_generators_remain(self): - """ Test that we can set one generator and keep an old one. """ + """Test that we can set one generator and keep an old one.""" config = { "tables": { "string": { - "row_generators": [{ - "name": "dist_gen.normal", - "columns_assigned": ["frequency"], - "kwargs": { - "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', - }, - }] + "row_generators": [ + { + "name": "dist_gen.normal", + "columns_assigned": ["frequency"], + "kwargs": { + "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', + }, + } + ] } }, - "src-stats": [{ - "name": "auto__string", - "query": 'SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string', - }] + "src-stats": [ + { + "name": "auto__string", + "query": "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + } + ], } with self._get_cmd(config) as gc: TABLE = "model" @@ -663,18 +765,23 @@ def test_old_generators_remain(self): row_gen = row_gens[0] self.assertEqual(row_gen["name"], "dist_gen.normal") self.assertListEqual(row_gen["columns_assigned"], ["frequency"]) - self.assertDictEqual(row_gen["kwargs"], { - "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', - }) + self.assertDictEqual( + row_gen["kwargs"], + { + "mean": 'SRC_STATS["auto__string"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"][0]["stddev__frequency"]', + }, + ) self.assertEqual(len(gc.config["src-stats"]), 1) - self.assertSetEqual(set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"}) + self.assertSetEqual( + set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} + ) self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__string") self.assertEqual( gc.config["src-stats"][0]["query"], "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", ) - + def test_aggregate_queries_merge(self): """ Test that we can set a generator that requires select aggregate clauses @@ -683,20 +790,24 @@ def test_aggregate_queries_merge(self): config = { "tables": { "string": { - "row_generators": 
[{ - "name": "dist_gen.normal", - "columns_assigned": ["frequency"], - "kwargs": { - "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', - }, - }] + "row_generators": [ + { + "name": "dist_gen.normal", + "columns_assigned": ["frequency"], + "kwargs": { + "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', + }, + } + ] } }, - "src-stats": [{ - "name": "auto__string", - "query": 'SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string', - }] + "src-stats": [ + { + "name": "auto__string", + "query": "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + } + ], } with self._get_cmd(copy.deepcopy(config)) as gc: COLUMN = "position" @@ -706,7 +817,9 @@ def test_aggregate_queries_merge(self): proposals = gc.get_proposals() gc.do_set(str(proposals[f"{GENERATOR}"][0])) gc.do_quit("") - row_gens: list[dict[str,any]] = gc.config["tables"]["string"]["row_generators"] + row_gens: list[dict[str, any]] = gc.config["tables"]["string"][ + "row_generators" + ] self.assertEqual(len(row_gens), 2) if row_gens[0]["name"] == GENERATOR: row_gen0 = row_gens[0] @@ -717,28 +830,41 @@ def test_aggregate_queries_merge(self): self.assertEqual(row_gen0["name"], GENERATOR) self.assertEqual(row_gen1["name"], "dist_gen.normal") self.assertListEqual(row_gen0["columns_assigned"], [COLUMN]) - self.assertDictEqual(row_gen0["kwargs"], { - "mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{COLUMN}"]', - }) + self.assertDictEqual( + row_gen0["kwargs"], + { + "mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{COLUMN}"]', + "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{COLUMN}"]', + }, + ) self.assertListEqual(row_gen1["columns_assigned"], ["frequency"]) - self.assertDictEqual(row_gen1["kwargs"], { - "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', - "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', - }) + self.assertDictEqual( + row_gen1["kwargs"], + { + "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]', + "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]', + }, + ) self.assertEqual(len(gc.config["src-stats"]), 1) self.assertEqual(gc.config["src-stats"][0]["name"], "auto__string") - select_match = re.match(r'SELECT (.*) FROM string', gc.config["src-stats"][0]["query"]) - self.assertIsNotNone(select_match, "src_stats[0].query is not an aggregate select") - self.assertSetEqual(set(select_match.group(1).split(", ")), { - "AVG(frequency) AS mean__frequency", - "STDDEV(frequency) AS stddev__frequency", - f"AVG({COLUMN}) AS mean__{COLUMN}", - f"STDDEV({COLUMN}) AS stddev__{COLUMN}", - }) + select_match = re.match( + r"SELECT (.*) FROM string", gc.config["src-stats"][0]["query"] + ) + self.assertIsNotNone( + select_match, "src_stats[0].query is not an aggregate select" + ) + self.assertSetEqual( + set(select_match.group(1).split(", ")), + { + "AVG(frequency) AS mean__frequency", + "STDDEV(frequency) AS stddev__frequency", + f"AVG({COLUMN}) AS mean__{COLUMN}", + f"STDDEV({COLUMN}) AS stddev__{COLUMN}", + }, + ) def test_next_completion(self): - """ Test tab completion for the next command. 
""" + """Test tab completion for the next command.""" with self._get_cmd({}) as gc: self.assertSetEqual( set(gc.complete_next("m", "next m", 5, 6)), @@ -756,7 +882,9 @@ def test_next_completion(self): set(gc.complete_next("string.p", "next string.p", 5, 12)), {"string.position"}, ) - self.assertListEqual(gc.complete_next("string.q", "next string.q", 5, 12), []) + self.assertListEqual( + gc.complete_next("string.q", "next string.q", 5, 12), [] + ) self.assertListEqual(gc.complete_next("ww", "next ww", 5, 7), []) def test_compare_reports_privacy(self): @@ -799,10 +927,12 @@ def test_existing_configuration_remains(self): "primary_private": True, } }, - "src-stats": [{ - "name": "kraken", - "query": 'SELECT MAX(frequency) AS max_frequency FROM string', - }] + "src-stats": [ + { + "name": "kraken", + "query": "SELECT MAX(frequency) AS max_frequency FROM string", + } + ], } with self._get_cmd(config) as gc: COLUMN = "position" @@ -812,15 +942,12 @@ def test_existing_configuration_remains(self): proposals = gc.get_proposals() gc.do_set(str(proposals[f"{GENERATOR}"][0])) gc.do_quit("") - src_stats = { - stat["name"]: stat["query"] - for stat in gc.config["src-stats"] - } + src_stats = {stat["name"]: stat["query"] for stat in gc.config["src-stats"]} self.assertEqual(src_stats["kraken"], config["src-stats"][0]["query"]) self.assertTrue(gc.config["tables"]["string"]["primary_private"]) def test_empty_tables_are_not_configured(self): - """ Test that tables marked as empty are not configured. """ + """Test that tables marked as empty are not configured.""" config = { "tables": { "string": { @@ -830,13 +957,14 @@ def test_empty_tables_are_not_configured(self): } with self._get_cmd(copy.deepcopy(config)) as gc: gc.do_tables("") - table_names = { m[1][0] for m in gc.messages } + table_names = {m[1][0] for m in gc.messages} self.assertIn("model", table_names) self.assertNotIn("string", table_names) class GeneratorsOutputTests(GeneratesDBTestCase): - """ Testing choice generation. """ + """Testing choice generation.""" + dump_file_path = "choice.sql" database_name = "numbers" schema_name = "public" @@ -845,7 +973,7 @@ def _get_cmd(self, config) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) def test_create_with_sampled_choice(self): - """ Test that suppression works for choice and zipf_choice. """ + """Test that suppression works for choice and zipf_choice.""" table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") @@ -869,7 +997,9 @@ def test_create_with_sampled_choice(self): self.assertIn("dist_gen.zipf_choice [sampled]", proposals) self.assertIn("dist_gen.choice [sampled and suppressed]", proposals) self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) - gc.do_set(str(proposals["dist_gen.zipf_choice [sampled and suppressed]"][0])) + gc.do_set( + str(proposals["dist_gen.zipf_choice [sampled and suppressed]"][0]) + ) gc.do_next("number_table.three") gc.reset() gc.do_propose("") @@ -899,7 +1029,7 @@ def test_create_with_sampled_choice(self): self.assertSetEqual(threes, {1, 2, 3, 4, 5}) def test_create_with_choice(self): - """ Smoke test normal choice works. """ + """Smoke test normal choice works.""" table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") @@ -927,7 +1057,7 @@ def test_create_with_choice(self): self.assertSetEqual(twos, {1, 2, 3, 4, 5}) def test_create_with_weighted_choice(self): - """ Smoke test weighted choice. 
""" + """Smoke test weighted choice.""" table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") @@ -936,12 +1066,16 @@ def test_create_with_weighted_choice(self): proposals = gc.get_proposals() self.assertIn("dist_gen.weighted_choice", proposals) self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn("dist_gen.weighted_choice [sampled and suppressed]", proposals) + self.assertIn( + "dist_gen.weighted_choice [sampled and suppressed]", proposals + ) prop = proposals["dist_gen.weighted_choice [sampled and suppressed]"] self.assertSubset(set(prop[2]), {"1", "4"}) gc.reset() gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. dist_gen.weighted_choice [sampled and suppressed]" + col_heading = ( + f"{prop[0]}. dist_gen.weighted_choice [sampled and suppressed]" + ) self.assertIn(col_heading, set(gc.columns.keys())) self.assertSubset(set(gc.columns[col_heading]), {1, 4}) gc.do_set(str(prop[0])) @@ -951,7 +1085,9 @@ def test_create_with_weighted_choice(self): proposals = gc.get_proposals() self.assertIn("dist_gen.weighted_choice", proposals) self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn("dist_gen.weighted_choice [sampled and suppressed]", proposals) + self.assertIn( + "dist_gen.weighted_choice [sampled and suppressed]", proposals + ) prop = proposals["dist_gen.weighted_choice"] self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"}) gc.reset() @@ -966,7 +1102,9 @@ def test_create_with_weighted_choice(self): proposals = gc.get_proposals() self.assertIn("dist_gen.weighted_choice", proposals) self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertNotIn("dist_gen.weighted_choice [sampled and suppressed]", proposals) + self.assertNotIn( + "dist_gen.weighted_choice [sampled and suppressed]", proposals + ) prop = proposals["dist_gen.weighted_choice [sampled]"] self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"}) gc.do_compare(str(prop[0])) @@ -993,10 +1131,12 @@ def test_create_with_weighted_choice(self): class TestMissingnessCmd(MissingnessCmd, TestDbCmdMixin): - """ MissingnessCmd but mocked """ + """MissingnessCmd but mocked""" + class ConfigureMissingnessTests(RequiresDBTestCase): - """ Testing configure-missing. """ + """Testing configure-missing.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -1005,34 +1145,50 @@ def _get_cmd(self, config) -> TestMissingnessCmd: return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) def test_set_missingness_to_sampled(self): - """ Test that we can set one table to sampled missingness. 
""" + """Test that we can set one table to sampled missingness.""" with self._get_cmd({}) as mc: TABLE = "signature_model" mc.do_next(TABLE) mc.do_counts("") - self.assertListEqual(mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (6,), {})]) - self.assertListEqual(mc.rows, [['player_id', 3], ['based_on', 2]]) + self.assertListEqual( + mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (6,), {})] + ) + self.assertListEqual(mc.rows, [["player_id", 3], ["based_on", 2]]) mc.do_sampled("") mc.do_quit("") self.assertDictEqual( mc.config, - { "tables": {TABLE: {"missingness_generators": [{ - "columns": ["player_id", "based_on"], - "kwargs": {"patterns": 'SRC_STATS["missing_auto__signature_model__0"]'}, - "name": "column_presence.sampled", - }]}}, - "src-stats": [{ - "name": "missing_auto__signature_model__0", - "query": ("SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" + { + "tables": { + TABLE: { + "missingness_generators": [ + { + "columns": ["player_id", "based_on"], + "kwargs": { + "patterns": 'SRC_STATS["missing_auto__signature_model__0"]' + }, + "name": "column_presence.sampled", + } + ] + } + }, + "src-stats": [ + { + "name": "missing_auto__signature_model__0", + "query": ( + "SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" " (SELECT player_id IS NULL AS player_id__is_null, based_on IS NULL AS based_on__is_null FROM" - " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null") - }] - } + " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null" + ), + } + ], + }, ) class ConfigureMissingnessTests(GeneratesDBTestCase): - """ Testing configure-missing with generation. """ + """Testing configure-missing with generation.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -1041,7 +1197,7 @@ def _get_cmd(self, config) -> TestMissingnessCmd: return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) def test_create_with_missingness(self): - """ Test that we can sample real missingness and reproduce it. """ + """Test that we can sample real missingness and reproduce it.""" random.seed(45) # Configure the missingness table_name = "signature_model" @@ -1065,7 +1221,8 @@ def test_create_with_missingness(self): class GeneratorTests(GeneratesDBTestCase): - """ Testing configure-generators with generation. """ + """Testing configure-generators with generation.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" @@ -1074,7 +1231,7 @@ def _get_cmd(self, config) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) def test_set_null(self): - """ Test that we can sample real missingness and reproduce it. 
""" + """Test that we can sample real missingness and reproduce it.""" with self._get_cmd({}) as gc: gc.do_next("string.position") gc.do_set("dist_gen.constant") @@ -1091,7 +1248,9 @@ def test_set_null(self): gc.do_next("signature_model.based_on") gc.do_set("dist_gen.constant") # we have got to the end of the columns, but shouldn't have any errors - self.assertListEqual(gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})]) + self.assertListEqual( + gc.messages, [(GeneratorCmd.INFO_NO_MORE_TABLES, (), {})] + ) gc.reset() gc.do_quit("") config = gc.config @@ -1116,7 +1275,7 @@ def test_set_null(self): self.assertEqual(count, 3) def test_dist_gen_sampled_produces_ordered_src_stats(self): - """ Tests that choosing a sampled choice generator produces ordered src stats """ + """Tests that choosing a sampled choice generator produces ordered src stats""" with self._get_cmd({}) as gc: gc.do_next("signature_model.player_id") gc.do_set("dist_gen.zipf_choice [sampled]") @@ -1127,26 +1286,24 @@ def test_dist_gen_sampled_produces_ordered_src_stats(self): self.set_configuration(config) src_stats = self.get_src_stats(config) player_ids = [ - s["value"] - for s in src_stats["auto__signature_model__player_id"]["results"] + s["value"] for s in src_stats["auto__signature_model__player_id"]["results"] ] self.assertListEqual(player_ids, [2, 3, 1]) based_ons = [ - s["value"] - for s in src_stats["auto__signature_model__based_on"]["results"] + s["value"] for s in src_stats["auto__signature_model__based_on"]["results"] ] self.assertListEqual(based_ons, [1, 3, 2]) def assertAreTruncatedTo(self, xs, length): - maxlen = 0 - for x in xs: - newlen = len(x.strip("'\"")) - self.assertLessEqual(newlen, length) - maxlen = max(maxlen, newlen) - self.assertEqual(maxlen, length) + maxlen = 0 + for x in xs: + newlen = len(x.strip("'\"")) + self.assertLessEqual(newlen, length) + maxlen = max(maxlen, newlen) + self.assertEqual(maxlen, length) def test_varchar_ns_are_truncated(self): - """ Tests that mimesis generators for VARCHAR(N) truncate to N characters """ + """Tests that mimesis generators for VARCHAR(N) truncate to N characters""" GENERATOR = "generic.text.quote" TABLE = "signature_model" COLUMN = "name" @@ -1176,9 +1333,9 @@ def test_varchar_ns_are_truncated(self): @dataclass class Stat: - n: int=0 - x: float=0 - x2: float=0 + n: int = 0 + x: float = 0 + x2: float = 0 def add(self, x: float) -> None: self.n += 1 @@ -1193,14 +1350,14 @@ def x_mean(self) -> float: def x_var(self) -> float: x = self.x - return (self.x2 - x*x/self.n)/(self.n - 1) + return (self.x2 - x * x / self.n) / (self.n - 1) @dataclass class Correlation(Stat): - y: float=0 - y2: float=0 - xy: float=0 + y: float = 0 + y2: float = 0 + xy: float = 0 def add(self, x: float, y: float) -> None: self.n += 1 @@ -1215,14 +1372,15 @@ def y_mean(self) -> float: def y_var(self) -> float: y = self.y - return (self.y2 - y*y/self.n)/(self.n - 1) + return (self.y2 - y * y / self.n) / (self.n - 1) def covar(self) -> float: - return (self.xy - self.x*self.y/self.n)/(self.n - 1) + return (self.xy - self.x * self.y / self.n) / (self.n - 1) class NullPartitionedTests(GeneratesDBTestCase): - """ Testing null-partitioned grouped multivariate generation. 
""" + """Testing null-partitioned grouped multivariate generation.""" + dump_file_path = "eav.sql" database_name = "eav" schema_name = "public" @@ -1236,7 +1394,7 @@ def _get_cmd(self, config) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) def test_create_with_null_partitioned_grouped_multivariate(self): - """ Test EAV for all columns. """ + """Test EAV for all columns.""" table_name = "measurement" generate_count = 800 with self._get_cmd({}) as gc: @@ -1287,9 +1445,9 @@ def test_create_with_null_partitioned_grouped_multivariate(self): # yes or no self.assertIsNone(row.first_value) self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {'yes', 'no'}) + self.assertIn(row.third_value, {"yes", "no"}) one_count += 1 - if row.third_value == 'yes': + if row.third_value == "yes": one_yes_count += 1 elif row.type == 2: # positive correlation around 1.4, 1.8 @@ -1310,43 +1468,57 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.assertIsNone(row.third_value) four.add(row.first_value, row.second_value) elif row.type == 5: - self.assertIn(row.third_value, {'fish', 'fowl'}) + self.assertIn(row.third_value, {"fish", "fowl"}) self.assertIsNotNone(row.first_value) self.assertIsNone(row.second_value) - if row.third_value == 'fish': + if row.third_value == "fish": # mean 8.1 and sd 0.755 fish.add(row.first_value) else: # mean 11.2 and sd 1.114 fowl.add(row.first_value) # type 1 - self.assertAlmostEqual(one_count, generate_count * 5 / 20, delta=generate_count * 0.4) + self.assertAlmostEqual( + one_count, generate_count * 5 / 20, delta=generate_count * 0.4 + ) # about 40% are yes - self.assertAlmostEqual(one_yes_count / one_count, 0.4, delta=generate_count * 0.4) + self.assertAlmostEqual( + one_yes_count / one_count, 0.4, delta=generate_count * 0.4 + ) # type 2 - self.assertAlmostEqual(two.count(), generate_count * 3 / 20, delta=generate_count * 0.5) + self.assertAlmostEqual( + two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 + ) self.assertAlmostEqual(two.x_mean(), 1.4, delta=0.6) self.assertAlmostEqual(two.x_var(), 0.21, delta=0.4) self.assertAlmostEqual(two.y_mean(), 1.8, delta=0.8) self.assertAlmostEqual(two.y_var(), 0.07, delta=0.1) self.assertAlmostEqual(two.covar(), 0.5, delta=0.5) # type 3 - self.assertAlmostEqual(three.count(), generate_count * 3 / 20, delta=generate_count * 0.2) + self.assertAlmostEqual( + three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) self.assertAlmostEqual(two.covar(), -0.5, delta=0.5) # type 4 - self.assertAlmostEqual(four.count(), generate_count * 3 / 20, delta=generate_count * 0.2) + self.assertAlmostEqual( + four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) self.assertAlmostEqual(two.covar(), 0.5, delta=0.5) # type 5/fish - self.assertAlmostEqual(fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2) + self.assertAlmostEqual( + fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) # type 5/fowl - self.assertAlmostEqual(fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2) + self.assertAlmostEqual( + fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) self.assertAlmostEqual(fish.x_mean(), 11.2, delta=8.0) self.assertAlmostEqual(fish.x_var(), 1.24, delta=1.5) def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): - """ Test EAV for all columns with 
sampled and suppressed generation. """ + """Test EAV for all columns with sampled and suppressed generation.""" table_name = "measurement" table2_name = "observation" generate_count = 800 @@ -1360,8 +1532,13 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): proposals = gc.get_proposals() self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) self.assertIn("null-partitioned grouped_multivariate_normal", proposals) - self.assertIn("null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", proposals) - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + self.assertIn( + "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", + proposals, + ) + dist_to_choose = ( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + ) self.assertIn(dist_to_choose, proposals) prop = proposals[dist_to_choose] gc.reset() @@ -1377,7 +1554,9 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): gc.reset() gc.do_propose("") proposals = gc.get_proposals() - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + dist_to_choose = ( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + ) prop = proposals[dist_to_choose] gc.do_set(str(prop[0])) gc.do_quit("") @@ -1409,15 +1588,15 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): # yes or no self.assertIsNone(row.first_value) self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {'yes', 'no'}) - if row.third_value == 'yes': + self.assertIn(row.third_value, {"yes", "no"}) + if row.third_value == "yes": one_yes_count += 1 one_count += 1 elif row.type == 5: - self.assertIn(row.third_value, {'fish', 'fowl'}) + self.assertIn(row.third_value, {"fish", "fowl"}) self.assertIsNotNone(row.first_value) self.assertIsNone(row.second_value) - if row.third_value == 'fish': + if row.third_value == "fish": # mean 8.1 and sd 0.755 fish.add(row.first_value) else: @@ -1427,15 +1606,23 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): self.assertEqual(len(types), 4) self.assertSubset({1, 5}, types) # type 1 - self.assertAlmostEqual(one_count, generate_count * 5 / 11, delta=generate_count * 0.4) + self.assertAlmostEqual( + one_count, generate_count * 5 / 11, delta=generate_count * 0.4 + ) # about 40% are yes - self.assertAlmostEqual(one_yes_count / one_count, 0.4, delta=generate_count * 0.4) + self.assertAlmostEqual( + one_yes_count / one_count, 0.4, delta=generate_count * 0.4 + ) # type 5/fish - self.assertAlmostEqual(fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2) + self.assertAlmostEqual( + fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) # type 5/fowl - self.assertAlmostEqual(fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2) + self.assertAlmostEqual( + fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) self.assertAlmostEqual(fish.x_mean(), 11.2, delta=8.0) self.assertAlmostEqual(fish.x_var(), 1.24, delta=1.5) stmt = select(self.metadata.tables[table2_name]) @@ -1446,38 +1633,52 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): self.assertEqual(row.type, 1) self.assertIsNotNone(row.first_value) self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {'ham', 'eggs'}) + 
self.assertIn(row.third_value, {"ham", "eggs"}) firsts.add(row.first_value) self.assertEqual(firsts.count(), 800) - self.assertAlmostEqual(firsts.x_mean(), 1.3, delta = generate_count * 0.3) + self.assertAlmostEqual(firsts.x_mean(), 1.3, delta=generate_count * 0.3) class NonInteractiveTests(RequiresDBTestCase): """ Test the --spec SPEC_FILE option of configure-generators """ + dump_file_path = "eav.sql" database_name = "eav" schema_name = "public" @patch("datafaker.interactive.Path") - @patch("datafaker.interactive.csv.reader", return_value=iter([ - ["observation", "type", "dist_gen.weighted_choice"], - ["observation", "first_value", "dist_gen.weighted_choice"], - ["observation", "third_value", "dist_gen.weighted_choice"], - ])) - def test_non_interactive_configure_generators(self, mock_csv_reader: MagicMock, mock_path: MagicMock): + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + ["observation", "type", "dist_gen.weighted_choice"], + ["observation", "first_value", "dist_gen.weighted_choice"], + ["observation", "third_value", "dist_gen.weighted_choice"], + ] + ), + ) + def test_non_interactive_configure_generators( + self, mock_csv_reader: MagicMock, mock_path: MagicMock + ): """ test that we can set generators from a CSV file """ config = {} spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators(self.dsn, self.schema_name, self.metadata, config, spec_csv) + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) row_gens = { f"{table}{sorted(rg['columns_assigned'])}": rg["name"] for table, tables in config.get("tables", {}).items() for rg in tables.get("row_generators", []) } self.assertEqual(row_gens["observation['type']"], "dist_gen.weighted_choice") - self.assertEqual(row_gens["observation['first_value']"], "dist_gen.weighted_choice") - self.assertEqual(row_gens["observation['third_value']"], "dist_gen.weighted_choice") + self.assertEqual( + row_gens["observation['first_value']"], "dist_gen.weighted_choice" + ) + self.assertEqual( + row_gens["observation['third_value']"], "dist_gen.weighted_choice" + ) diff --git a/tests/test_main.py b/tests/test_main.py index c3652a2..a37eaf4 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -21,7 +21,13 @@ class TestCLI(DatafakerTestCase): @patch("datafaker.main.dict_to_metadata") @patch("datafaker.main.load_metadata_config") @patch("datafaker.main.create_db_vocab") - def test_create_vocab(self, mock_create: MagicMock, mock_mdict: MagicMock, mock_meta: MagicMock, mock_config: MagicMock) -> None: + def test_create_vocab( + self, + mock_create: MagicMock, + mock_mdict: MagicMock, + mock_meta: MagicMock, + mock_config: MagicMock, + ) -> None: """Test the create-vocab sub-command.""" result = runner.invoke( app, @@ -31,7 +37,9 @@ def test_create_vocab(self, mock_create: MagicMock, mock_mdict: MagicMock, mock_ catch_exceptions=False, ) - mock_create.assert_called_once_with(mock_meta.return_value, mock_mdict.return_value, mock_config.return_value) + mock_create.assert_called_once_with( + mock_meta.return_value, mock_mdict.return_value, mock_config.return_value + ) self.assertSuccess(result) @patch("datafaker.main.read_config_file") @@ -159,10 +167,13 @@ def test_create_generators_with_force_enabled( for force_option in ["--force", "-f"]: with self.subTest(f"Using option {force_option}"): - result: Result = runner.invoke(app, [ - "create-generators", - force_option, - ]) + result: Result = runner.invoke( + app, + [ + "create-generators", + force_option, + ], + ) 
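# For reference (illustrative, not part of the patch): the invocation above
# feeds the same argv to the CLI test runner that a user would type on the
# command line, roughly equivalent to:
#
#   datafaker create-generators --force
#
# (assuming the console entry point is installed under the package name).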
mock_make.assert_called_once_with( mock_load_meta.return_value, @@ -335,11 +346,14 @@ def test_make_tables_with_force_enabled( for force_option in ["--force", "-f"]: with self.subTest(f"Using option {force_option}"): - result: Result = runner.invoke(app, [ - "make-tables", - force_option, - "--orm-file=tests/examples/example_orm.yaml", - ]) + result: Result = runner.invoke( + app, + [ + "make-tables", + force_option, + "--orm-file=tests/examples/example_orm.yaml", + ], + ) mock_make_tables.assert_called_once_with( mock_get_settings.return_value.src_dsn, @@ -359,7 +373,11 @@ def test_make_tables_with_force_enabled( @patch("datafaker.main.get_settings") @patch("datafaker.main.load_metadata", side_effect=["ms"]) def test_make_stats( - self, _lm: MagicMock, mock_get_settings: MagicMock, mock_make: MagicMock, mock_path: MagicMock + self, + _lm: MagicMock, + mock_get_settings: MagicMock, + mock_make: MagicMock, + mock_path: MagicMock, ) -> None: """Test the make-stats sub-command.""" example_conf_path = "tests/examples/example_config.yaml" @@ -379,7 +397,9 @@ def test_make_stats( self.assertSuccess(result) with open(example_conf_path, "r", encoding="utf8") as f: config = yaml.safe_load(f) - mock_make.assert_called_once_with(get_test_settings().src_dsn, config, "ms", None) + mock_make.assert_called_once_with( + get_test_settings().src_dsn, config, "ms", None + ) mock_path.return_value.write_text.assert_called_once_with( "a: 1\n", encoding="utf-8" ) @@ -434,7 +454,11 @@ def test_make_stats_errors_if_no_src_dsn(self, mock_logger: MagicMock) -> None: @patch("datafaker.main.get_settings") @patch("datafaker.main.load_metadata") def test_make_stats_with_force_enabled( - self, mock_meta: MagicMock, mock_get_settings: MagicMock, mock_make: MagicMock, mock_path: MagicMock + self, + mock_meta: MagicMock, + mock_get_settings: MagicMock, + mock_make: MagicMock, + mock_path: MagicMock, ) -> None: """Tests that the make-stats command overwrite files when instructed.""" test_config_file: str = "tests/examples/example_config.yaml" @@ -461,7 +485,10 @@ def test_make_stats_with_force_enabled( ) mock_make.assert_called_once_with( - test_settings.src_dsn, config_file_content, mock_meta.return_value, None + test_settings.src_dsn, + config_file_content, + mock_meta.return_value, + None, ) mock_path.return_value.write_text.assert_called_once_with( "some_stat: 0\n", encoding="utf-8" @@ -507,7 +534,9 @@ def test_remove_data( catch_exceptions=False, ) self.assertEqual(0, result.exit_code) - mock_remove.assert_called_once_with(mock_meta.return_value, mock_config.return_value) + mock_remove.assert_called_once_with( + mock_meta.return_value, mock_config.return_value + ) @patch("datafaker.main.read_config_file") @patch("datafaker.main.remove_db_vocab") @@ -528,12 +557,18 @@ def test_remove_vocab( ) self.assertEqual(0, result.exit_code) mock_read_config.assert_called_once_with("config.yaml") - mock_remove.assert_called_once_with(mock_d2m.return_value, mock_load_metadata.return_value, mock_read_config.return_value) + mock_remove.assert_called_once_with( + mock_d2m.return_value, + mock_load_metadata.return_value, + mock_read_config.return_value, + ) @patch("datafaker.main.remove_db_tables") @patch("datafaker.main.load_metadata_for_output") @patch("datafaker.main.read_config_file") - def test_remove_tables(self, _: MagicMock, mock_meta: MagicMock, mock_remove: MagicMock) -> None: + def test_remove_tables( + self, _: MagicMock, mock_meta: MagicMock, mock_remove: MagicMock + ) -> None: """Test the remove-tables command.""" result = 
runner.invoke( app, diff --git a/tests/test_make.py b/tests/test_make.py index 6100aa7..f43588a 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -9,37 +9,38 @@ from sqlalchemy.dialects.mysql.types import INTEGER from sqlalchemy.dialects.postgresql import UUID -from datafaker.make import ( - _get_provider_for_column, - make_src_stats, -) -from tests.utils import RequiresDBTestCase, GeneratesDBTestCase +from datafaker.make import _get_provider_for_column, make_src_stats +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase class TestMakeGenerators(GeneratesDBTestCase): """Test the make_table_generators function.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" def test_make_table_generators(self) -> None: - """ Check that we can make a generators file. """ + """Check that we can make a generators file.""" config = { "tables": { "player": { - "row_generators": [{ - "name": "dist_gen.constant", - "kwargs": { - "value": '"Cave"', + "row_generators": [ + { + "name": "dist_gen.constant", + "kwargs": { + "value": '"Cave"', + }, + "columns_assigned": "given_name", }, - "columns_assigned": "given_name", - }, { - "name": "dist_gen.constant", - "kwargs": { - "value": '"Johnson"', + { + "name": "dist_gen.constant", + "kwargs": { + "value": '"Johnson"', + }, + "columns_assigned": "family_name", }, - "columns_assigned": "family_name", - }], + ], }, }, } @@ -96,7 +97,7 @@ def test_get_provider_for_column(self) -> None: ) self.assertEqual( generator_arguments, - { "length": "100" }, + {"length": "100"}, ) # UUID @@ -149,12 +150,15 @@ def check_make_stats_output(self, src_stats: dict) -> None: count_names = src_stats["count_names"]["results"] count_names.sort(key=lambda c: c["name"]) - self.assertListEqual(count_names, [ - {"num": 1, "name": "Miranda Rando-Generata"}, - {"num": 997, "name": "Randy Random"}, - {"num": 1, "name": "Testfried Testermann"}, - {"num": 1, "name": "Veronica Fyre"}, - ]) + self.assertListEqual( + count_names, + [ + {"num": 1, "name": "Miranda Rando-Generata"}, + {"num": 997, "name": "Randy Random"}, + {"num": 1, "name": "Testfried Testermann"}, + {"num": 1, "name": "Veronica Fyre"}, + ], + ) avg_person_id = src_stats["avg_person_id"]["results"] self.assertEqual(len(avg_person_id), 1) diff --git a/tests/test_providers.py b/tests/test_providers.py index 9cc03c5..aedb693 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.declarative import declarative_base from datafaker import providers -from tests.utils import RequiresDBTestCase, DatafakerTestCase +from tests.utils import DatafakerTestCase, RequiresDBTestCase # pylint: disable=invalid-name Base = declarative_base() @@ -37,6 +37,7 @@ def test_bytes(self) -> None: class ColumnValueProviderTestCase(RequiresDBTestCase): """Tests for the ColumnValueProvider class.""" + dump_file_path = "providers.dump" def setUp(self) -> None: diff --git a/tests/test_remove.py b/tests/test_remove.py index 660d6cb..bfbb787 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -1,25 +1,25 @@ """Tests for the remove module.""" from unittest.mock import MagicMock, patch +from sqlalchemy import func, inspect, select + from datafaker.remove import remove_db_data, remove_db_tables, remove_db_vocab from datafaker.serialize_metadata import metadata_to_dict from datafaker.settings import Settings -from sqlalchemy import func, inspect, select from tests.utils import RequiresDBTestCase class RemoveThingsTestCase(RequiresDBTestCase): - """ 
Tests for ``remove-`` commands. """ + """Tests for ``remove-`` commands.""" + dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" def count_rows(self, connection, table_name: str) -> int | None: - return connection.execute(select( - func.count() - ).select_from( - self.metadata.tables[table_name] - )).scalar() + return connection.execute( + select(func.count()).select_from(self.metadata.tables[table_name]) + ).scalar() @patch("datafaker.remove.get_settings") def test_remove_data(self, mock_get_settings: MagicMock): @@ -28,12 +28,15 @@ def test_remove_data(self, mock_get_settings: MagicMock): dst_dsn=self.dsn, _env_file=None, ) - remove_db_data(self.metadata, { - "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, - } - }) + remove_db_data( + self.metadata, + { + "tables": { + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, + } + }, + ) with self.engine.connect() as conn: self.assertGreater(self.count_rows(conn, "manufacturer"), 0) self.assertGreater(self.count_rows(conn, "model"), 0) @@ -43,19 +46,22 @@ def test_remove_data(self, mock_get_settings: MagicMock): @patch("datafaker.remove.get_settings") def test_remove_data_raises(self, mock_get_settings: MagicMock) -> None: - """ Test that remove-data raises if dst DSN is missing. """ + """Test that remove-data raises if dst DSN is missing.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: - remove_db_data(self.metadata, { - "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, - } - }) + remove_db_data( + self.metadata, + { + "tables": { + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, + } + }, + ) self.assertEqual( context_manager.exception.args[0], "Missing destination database settings" ) @@ -70,8 +76,8 @@ def test_remove_vocab(self, mock_get_settings: MagicMock): meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) config = { "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, } } remove_db_data(self.metadata, config) @@ -85,7 +91,7 @@ def test_remove_vocab(self, mock_get_settings: MagicMock): @patch("datafaker.remove.get_settings") def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: - """ Test that remove-vocab raises if dst DSN is missing. 
""" + """Test that remove-vocab raises if dst DSN is missing.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, @@ -93,12 +99,16 @@ def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: ) with self.assertRaises(AssertionError) as context_manager: meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) - remove_db_vocab(self.metadata, meta_dict, { - "tables": { - "manufacturer": { "vocabulary_table": True }, - "model": { "vocabulary_table": True }, - } - }) + remove_db_vocab( + self.metadata, + meta_dict, + { + "tables": { + "manufacturer": {"vocabulary_table": True}, + "model": {"vocabulary_table": True}, + } + }, + ) self.assertEqual( context_manager.exception.args[0], "Missing destination database settings" ) @@ -120,7 +130,7 @@ def test_remove_tables(self, mock_get_settings: MagicMock): @patch("datafaker.remove.get_settings") def test_remove_tables_raises(self, mock_get_settings: MagicMock) -> None: - """ Test that remove-vocab raises if dst DSN is missing. """ + """Test that remove-vocab raises if dst DSN is missing.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, diff --git a/tests/test_unique_generator.py b/tests/test_unique_generator.py index 41a7747..81e9eea 100644 --- a/tests/test_unique_generator.py +++ b/tests/test_unique_generator.py @@ -40,6 +40,7 @@ class UniqueGeneratorTestCase(RequiresDBTestCase): and b which are boolean, and c which is a text column. There is a joint unique constraint on a and b, and a separate unique constraint on c. """ + dump_file_path = "unique_generator.dump" def setUp(self) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py index 0eca2b1..2640a9e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import os import sys from pathlib import Path -from unittest.mock import patch, MagicMock, call +from unittest.mock import MagicMock, call, patch from sqlalchemy import Column, Integer, insert from sqlalchemy.orm import declarative_base @@ -13,7 +13,7 @@ import_file, read_config_file, ) -from tests.utils import RequiresDBTestCase, DatafakerTestCase +from tests.utils import DatafakerTestCase, RequiresDBTestCase # pylint: disable=invalid-name Base = declarative_base() @@ -85,7 +85,9 @@ def test_download_table(self) -> None: conn.execute(insert(MyTable).values({"id": 1})) conn.commit() - download_table(MyTable.__table__, self.engine, self.mytable_file_path, compress=False) + download_table( + MyTable.__table__, self.engine, self.mytable_file_path, compress=False + ) # The .strip() gets rid of any possible empty lines at the end of the file. with Path("../examples/expected.yaml").open(encoding="utf-8") as yamlfile: @@ -108,124 +110,219 @@ def test_warns_of_invalid_config(self) -> None: "The config file is invalid: %s", "'a' is not of type 'integer'" ) + class TestUtils(DatafakerTestCase): - """ Miscellaneous tests. """ + """Miscellaneous tests.""" + def test_generators_require_stats(self) -> None: - """ Test that we can tell if a configuration requires SRC_STATS or not. 
""" - self.assertTrue(generators_require_stats({ - "object_instantiation": { - "mygen": {"name": "MyGen", "kwargs": {"a": '1 + SRC_STATS["my"]["results"][0]'}} - } - })) - self.assertTrue(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "kwargs": {"a": '[None] + SRC_STATS["my"]["results"]'}, - }] - })) - self.assertTrue(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "args": ['(SRC_STATS["my"]["results"])'], - }] - })) - self.assertTrue(generators_require_stats({ - "tables": { - "things": { - "missingness_generators":[{ - "name": "msg", - "kwargs": {"a": '[SRC_STATS["my"], SRC_STATS["theirs"]]'}, - "columns_assigned": ["a"], - }] + """Test that we can tell if a configuration requires SRC_STATS or not.""" + self.assertTrue( + generators_require_stats( + { + "object_instantiation": { + "mygen": { + "name": "MyGen", + "kwargs": {"a": '1 + SRC_STATS["my"]["results"][0]'}, + } + } } - } - })) - self.assertTrue(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "kwargs": {"a": 'SRC_STATS["ifu"]["results"]'}, - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "kwargs": {"a": '[None] + SRC_STATS["my"]["results"]'}, + } + ] } - } - })) - self.assertTrue(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "args": ['SRC_STATS'], - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "args": ['(SRC_STATS["my"]["results"])'], + } + ] } - } - })) - self.assertFalse(generators_require_stats({ - "object_instantiation": { - "mygen": {"name": "MyGen", "kwargs": {"a": 1}} - } - })) - self.assertFalse(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "kwargs": {"a": '[None]'}, - }] - })) - self.assertFalse(generators_require_stats({ - "story_generators": [{ - "name": "msg", - "args": ['(SRC_STATS_["my"]["results"])'], - }] - })) - self.assertFalse(generators_require_stats({ - "missingness_generators": [{ - "name": "msg", - "kwargs": {"a": '"SRC_STATS"'}, - "columns_assigned": ["a"], - }] - })) - self.assertFalse(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "kwargs": {"a": 'SRC_STAT["ifu"]["results"]'}, - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "tables": { + "things": { + "missingness_generators": [ + { + "name": "msg", + "kwargs": { + "a": '[SRC_STATS["my"], SRC_STATS["theirs"]]' + }, + "columns_assigned": ["a"], + } + ] + } + } } - } - })) - self.assertFalse(generators_require_stats({ - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "args": ['SRC_STATSS'], - "columns_assigned": ["a"], - }] + ) + ) + self.assertTrue( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "kwargs": {"a": 'SRC_STATS["ifu"]["results"]'}, + "columns_assigned": ["a"], + } + ] + } + } } - } - })) + ) + ) + self.assertTrue( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "args": ["SRC_STATS"], + "columns_assigned": ["a"], + } + ] + } + } + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "object_instantiation": { + "mygen": {"name": "MyGen", "kwargs": {"a": 1}} + } + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "story_generators": [ + 
{ + "name": "msg", + "kwargs": {"a": "[None]"}, + } + ] + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "story_generators": [ + { + "name": "msg", + "args": ['(SRC_STATS_["my"]["results"])'], + } + ] + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "missingness_generators": [ + { + "name": "msg", + "kwargs": {"a": '"SRC_STATS"'}, + "columns_assigned": ["a"], + } + ] + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "kwargs": {"a": 'SRC_STAT["ifu"]["results"]'}, + "columns_assigned": ["a"], + } + ] + } + } + } + ) + ) + self.assertFalse( + generators_require_stats( + { + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "args": ["SRC_STATSS"], + "columns_assigned": ["a"], + } + ] + } + } + } + ) + ) @patch("datafaker.utils.logger") def test_testing_generators_finds_syntax_errors(self, logger: MagicMock): - generators_require_stats({ - "story_generators": [ - {"name": "my_story_gen", "kwargs": {"b": "'unclosed"}} - ], - "tables": { - "things": { - "row_generators": [{ - "name": "MyGen", - "args": ['1 2'], - "columns_assigned": ["a"], - }] - } + generators_require_stats( + { + "story_generators": [ + {"name": "my_story_gen", "kwargs": {"b": "'unclosed"}} + ], + "tables": { + "things": { + "row_generators": [ + { + "name": "MyGen", + "args": ["1 2"], + "columns_assigned": ["a"], + } + ] + } + }, } - }) - logger.error.assert_has_calls([ - call("Syntax error in argument %s of %s: %s\n%s\n%s", "b", "story_generators[0]", "unterminated string literal (detected at line 1)", "'unclosed", " ^"), - call("Syntax error in argument %d of %s: %s\n%s\n%s", 1, "tables.things.row_generators[0]", "invalid syntax", "1 2", " ^"), - ]) + ) + logger.error.assert_has_calls( + [ + call( + "Syntax error in argument %s of %s: %s\n%s\n%s", + "b", + "story_generators[0]", + "unterminated string literal (detected at line 1)", + "'unclosed", + " ^", + ), + call( + "Syntax error in argument %d of %s: %s\n%s\n%s", + 1, + "tables.things.row_generators[0]", + "invalid syntax", + "1 2", + " ^", + ), + ] + ) diff --git a/tests/utils.py b/tests/utils.py index 08850f8..a6eb593 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,25 +1,26 @@ """Utilities for testing.""" import asyncio -from functools import lru_cache import os -from pathlib import Path import shutil -from sqlalchemy.schema import MetaData -from subprocess import run -import testing.postgresql import traceback +from functools import lru_cache +from pathlib import Path +from subprocess import run +from tempfile import mkstemp from typing import Any from unittest import TestCase, skipUnless -import yaml +import testing.postgresql +import yaml from sqlalchemy import MetaData -from tempfile import mkstemp +from sqlalchemy.schema import MetaData from datafaker import settings from datafaker.create import create_db_data_into -from datafaker.make import make_tables_file, make_src_stats, make_table_generators +from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from -from datafaker.utils import import_file, sorted_non_vocabulary_tables, create_db_engine +from datafaker.utils import create_db_engine, import_file, sorted_non_vocabulary_tables + class SysExit(Exception): """To force the function to exit as sys.exit() would.""" @@ -68,10 +69,10 @@ def assertFailure(self, result: Any) -> None: # pylint: disable=invalid-name self.assertReturnCode(result, 1) 
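# For reference, the ``generators_require_stats`` battery in test_utils.py
# above boils down to this rule (illustrative): an argument needs source
# statistics only when it parses as Python and references the SRC_STATS name
# itself.
#
#   generators_require_stats(
#       {"story_generators": [{"name": "g", "args": ['SRC_STATS["x"]']}]}
#   )  # True: a genuine reference
#   generators_require_stats(
#       {"story_generators": [{"name": "g", "args": ['"SRC_STATS"']}]}
#   )  # False: merely a string literal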
def assertNoException(self, result: Any) -> None: # pylint: disable=invalid-name - """ Assert that the result has no exception. """ + """Assert that the result has no exception.""" if result.exception is None: return - self.fail(''.join(traceback.format_exception(result.exception))) + self.fail("".join(traceback.format_exception(result.exception))) def assertSubset(self, set1, set2, msg=None): """Assert a set is a (non-strict) subset. @@ -85,22 +86,23 @@ def assertSubset(self, set1, set2, msg=None): try: difference = set1.difference(set2) except TypeError as e: - self.fail('invalid type when attempting set difference: %s' % e) + self.fail("invalid type when attempting set difference: %s" % e) except AttributeError as e: - self.fail('first argument does not support set difference: %s' % e) + self.fail("first argument does not support set difference: %s" % e) if not difference: return lines = [] if difference: - lines.append('Items in the first set but not the second:') + lines.append("Items in the first set but not the second:") for item in difference: lines.append(repr(item)) - standardMsg = '\n'.join(lines) + standardMsg = "\n".join(lines) self.fail(self._formatMessage(msg, standardMsg)) + @skipUnless(shutil.which("psql"), "need to find 'psql': install PostgreSQL to enable") class RequiresDBTestCase(DatafakerTestCase): """ @@ -112,6 +114,7 @@ class RequiresDBTestCase(DatafakerTestCase): to get an engine to access the database and self.metadata to get metadata reflected from that engine. """ + schema_name = None use_asyncio = False examples_dir = "tests/examples" @@ -201,13 +204,15 @@ def get_src_stats(self, config) -> dict[str, any]: make_src_stats(self.dsn, config, self.metadata, self.schema_name) ) loop.close() - (self.stats_fd, self.stats_file_path) = mkstemp(".yaml", "src_stats_", text=True) + (self.stats_fd, self.stats_file_path) = mkstemp( + ".yaml", "src_stats_", text=True + ) with os.fdopen(self.stats_fd, "w", encoding="utf-8") as stats_fh: stats_fh.write(yaml.dump(src_stats)) return src_stats def create_generators(self, config) -> None: - """ ``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py`` """ + """``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py``""" datafaker_content = make_table_generators( self.metadata, config, @@ -220,12 +225,12 @@ def create_generators(self, config) -> None: datafaker_fh.write(datafaker_content) def remove_data(self, config): - """ Remove source data from the DB. """ + """Remove source data from the DB.""" # `remove-data` so we don't have to use a separate database for the destination remove_db_data_from(self.metadata, config, self.dsn, self.schema_name) def create_data(self, config, num_passes=1): - """ Create fake data in the DB. 
""" + """Create fake data in the DB.""" # `create-data` with all this stuff datafaker_module = import_file(self.generators_file_path) table_generator_dict = datafaker_module.table_generator_dict From bb14ced8f721077c0a548f3de569a39660a5aa2f Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 3 Oct 2025 23:02:28 +0100 Subject: [PATCH 02/35] Fixed a test --- tests/test_interactive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 9487284..8bfb7bc 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -1514,8 +1514,8 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.assertAlmostEqual( fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) - self.assertAlmostEqual(fish.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fish.x_var(), 1.24, delta=1.5) + self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): """Test EAV for all columns with sampled and suppressed generation.""" From cda6164fbfe606f92b028520be0051e51b5c55de Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 3 Oct 2025 23:31:03 +0100 Subject: [PATCH 03/35] base and create mypy fixes --- datafaker/base.py | 72 +++++++++++++++++++++--------------- datafaker/create.py | 26 ++++++++----- datafaker/templates/df.py.j2 | 2 +- tests/test_interactive.py | 4 +- 4 files changed, 62 insertions(+), 42 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 56315a4..0a3cf41 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -8,7 +8,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import Any, Callable, Generator, TypeVar import numpy as np import yaml @@ -24,13 +24,16 @@ ) +_T = TypeVar("_T") + + @functools.cache -def zipf_weights(size): +def zipf_weights(size: int) -> list[float]: total = sum(map(lambda n: 1 / n, range(1, size + 1))) return [1 / (n * total) for n in range(1, size + 1)] -def merge_with_constants(xs: list, constants_at: dict[int, any]): +def merge_with_constants(xs: list[_T], constants_at: dict[int, _T]) -> Generator[_T, None, None]: """ Merge a list of items with other items that must be placed at certain indices. 
:param constants_at: A map of indices to objects that must be placed at @@ -59,41 +62,41 @@ def merge_with_constants(xs: list, constants_at: dict[int, any]): class NothingToGenerateException(Exception): - def __init__(self, message): + def __init__(self, message: str): super().__init__(message) class DistributionGenerator: root3 = math.sqrt(3) - def __init__(self): + def __init__(self) -> None: self.np_gen = np.random.default_rng() - def uniform(self, low, high) -> float: + def uniform(self, low: float, high: float) -> float: return random.uniform(float(low), float(high)) - def uniform_ms(self, mean, sd) -> float: + def uniform_ms(self, mean: float, sd: float) -> float: m = float(mean) h = self.root3 * float(sd) return random.uniform(m - h, m + h) - def normal(self, mean, sd) -> float: + def normal(self, mean: float, sd: float) -> float: return random.normalvariate(float(mean), float(sd)) - def lognormal(self, logmean, logsd) -> float: + def lognormal(self, logmean: float, logsd: float) -> float: return random.lognormvariate(float(logmean), float(logsd)) - def choice(self, a): + def choice(self, a: list[_T]) -> _T: c = random.choice(a) return c["value"] if type(c) is dict and "value" in c else c - def zipf_choice(self, a, n=None): + def zipf_choice(self, a: list[_T], n: int | None=None) -> _T: if n is None: n = len(a) c = random.choices(a, weights=zipf_weights(n))[0] return c["value"] if type(c) is dict and "value" in c else c - def weighted_choice(self, a: list[dict[str, any]]) -> list[any]: + def weighted_choice(self, a: list[dict[str, Any]]) -> Any: """ Choice weighted by the count in the original dataset. :param a: a list of dicts, each with a ``value`` key @@ -110,10 +113,10 @@ def weighted_choice(self, a: list[dict[str, any]]) -> list[any]: c = random.choices(vs, weights=counts)[0] return c - def constant(self, value): + def constant(self, value: _T) -> _T: return value - def multivariate_normal_np(self, cov): + def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray: rank = int(cov["rank"]) if rank == 0: return np.empty(shape=(0,)) @@ -127,7 +130,7 @@ def multivariate_normal_np(self, cov): ] return self.np_gen.multivariate_normal(mean, covs) - def _select_group(self, alts: list[dict[str, any]]): + def _select_group(self, alts: list[dict[str, Any]]) -> Any: """ Choose one of the ``alts`` weighted by their ``"count"`` elements. """ @@ -148,14 +151,14 @@ def _select_group(self, alts: list[dict[str, any]]): return alt raise Exception("Internal error: ran out of choices in _select_group") - def _find_constants(self, result: dict[str, any]): + def _find_constants(self, result: dict[str, Any]) -> dict[int, Any]: """ Find all keys ``kN``, returning a dictionary of ``N: kNN``. This can be passed into ``merge_with_constants`` as the ``constants_at`` argument. """ - out: dict[int, any] = {} + out: dict[int, Any] = {} for k, v in result.items(): if k.startswith("k") and k[1:].isnumeric(): out[int(k[1:])] = v @@ -171,7 +174,7 @@ def _find_constants(self, result: dict[str, any]): "with_constants_at", } - def multivariate_normal(self, cov): + def multivariate_normal(self, cov: dict[str, Any]) -> list[float]: """ Produce a list of values pulled from a multivariate distribution. @@ -182,9 +185,10 @@ def multivariate_normal(self, cov): ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. 
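        For example (illustrative), a rank-2 description would be::

            {"rank": 2, "m0": 0.0, "m1": 1.0,
             "c0_0": 1.0, "c0_1": 0.3, "c1_1": 2.0}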
:return: list of ``rank`` floating point values """ - return self.multivariate_normal_np(cov).tolist() + out: list[float] = self.multivariate_normal_np(cov).tolist() + return out - def multivariate_lognormal(self, cov): + def multivariate_lognormal(self, cov: dict[str, Any]) -> list[float]: """ Produce a list of values pulled from a multivariate distribution. @@ -196,16 +200,23 @@ def multivariate_lognormal(self, cov): are all the means and covariants of the logs of the data. :return: list of ``rank`` floating point values """ - return np.exp(self.multivariate_normal_np(cov)).tolist() + out: list[Any] = np.exp(self.multivariate_normal_np(cov)).tolist() + return out - def grouped_multivariate_normal(self, covs): + def grouped_multivariate_normal(self, covs: list[dict[str, Any]]) -> list[Any]: + """ + Produce a list of values pulled from a set of multivariate distributions. + """ cov = self._select_group(covs) logger.debug("Multivariate normal group selected: %s", cov) constants = self._find_constants(cov) nums = self.multivariate_normal(cov) return list(merge_with_constants(nums, constants)) - def grouped_multivariate_lognormal(self, covs): + def grouped_multivariate_lognormal(self, covs: list[dict[str, Any]]) -> list[Any]: + """ + Produce a list of values pulled from a set of multivariate distributions. + """ cov = self._select_group(covs) logger.debug("Multivariate lognormal group selected: %s", cov) constants = self._find_constants(cov) @@ -217,8 +228,8 @@ def _check_generator_name(self, name: str) -> None: raise Exception("%s is not a permitted generator", name) def alternatives( - self, alternative_configs: list[dict[str, any]], counts: list[int] | None - ): + self, alternative_configs: list[dict[str, Any]], counts: list[dict[str, int]] | None + ) -> Any: """ A generator that picks between other generators. @@ -227,6 +238,9 @@ def alternatives( how often to use this alternative; "name" -- which generator for this partition, for example "composite"; "params" -- the parameters for this alternative. + :param counts: A list of weights for each alternative. If None, the + "count" value of each alternative is used. Each count is a dict + with a "count" key. :return: list of values """ if counts is not None: @@ -246,8 +260,8 @@ def alternatives( return getattr(self, name)(**alt["params"]) def with_constants_at( - self, constants_at: list[int], subgen: str, params: dict[str, any] - ): + self, constants_at: dict[int, _T], subgen: str, params: dict[str, _T] + ) -> list[_T]: if subgen not in self.PERMITTED_SUBGENS: logger.error( "subgenerator %s is not a valid name. 
Valid names are %s.", @@ -258,7 +272,7 @@ def with_constants_at( logger.debug("Merging constants %s", constants_at) return list(merge_with_constants(subout, constants_at)) - def truncated_string(self, subgen_fn, params, length): + def truncated_string(self, subgen_fn: Callable[..., list[_T]], params: dict, length: int) -> list[_T]: """Calls ``subgen_fn(**params)`` and truncates the results to ``length``.""" result = subgen_fn(**params) if result is None: @@ -340,7 +354,7 @@ def load(self, connection: Connection, base_path: Path = Path(".")) -> None: class ColumnPresence: - def sampled(self, patterns): + def sampled(self, patterns: list[dict[str, Any]]) -> set[Any]: total = 0 for pattern in patterns: total += pattern.get("row_count", 0) diff --git a/datafaker/create.py b/datafaker/create.py index f522876..e902ec3 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -53,15 +53,15 @@ def create_db_vocab( metadata: MetaData, meta_dict: dict[str, Any], config: Mapping, - base_path: pathlib.Path | None = pathlib.Path("."), -) -> int: + base_path: pathlib.Path = pathlib.Path("."), +) -> list[str]: """ Load vocabulary tables from files. - arguments: - metadata: The schema of the database - meta_dict: The simple description of the schema from --orm-file - config: The configuration from --config-file + :param metadata: The schema of the database + :param meta_dict: The simple description of the schema from --orm-file + :param config: The configuration from --config-file + :return: List of table names loaded. """ settings = get_settings() dst_dsn: str = settings.dst_dsn or "" @@ -151,6 +151,8 @@ def __init__( self._table_dict: Mapping[str, Table] = table_dict self._table_generator_dict: Mapping[str, TableGenerator] = table_generator_dict self._dst_conn: Connection = dst_conn + self._table_name: str | None + self._final_values: dict[str, Any] | None = None try: name, self._story = next(self._stories) logger.info("Generating data for story '%s'", name) @@ -165,13 +167,13 @@ def is_ended(self) -> bool: """ return self._table_name is None - def has_table(self, table_name: str): + def has_table(self, table_name: str) -> bool: """ Do we have a row for table table_name? """ return table_name == self._table_name - def table_name(self) -> str: + def table_name(self) -> str | None: """ The name of the current table (or None if no more stories to process) """ @@ -182,10 +184,12 @@ def insert(self) -> None: Perform the insert. Call this after __init__ or next, and after checking that is_ended returns False. 
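        A typical driving loop looks like (sketch; compare ``populate``
        below, parameter names assumed)::

            it = StoryIterator(stories, table_dict, table_generator_dict, conn)
            while not it.is_ended():
                it.insert()
                it.next()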
""" + if self._table_name is None: + raise StopIteration("StoryIterator.insert after is_ended") table = self._table_dict[self._table_name] if table.name in self._table_generator_dict: table_generator = self._table_generator_dict[table.name] - default_values = table_generator(self._dst_conn, random.random) + default_values = table_generator(self._dst_conn) else: default_values = {} insert_values = {**default_values, **self._provided_values} @@ -273,7 +277,7 @@ def populate( with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): stmt = insert(table).values( - table_generator(dst_conn, random.random) + table_generator(dst_conn) ) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 @@ -286,6 +290,8 @@ def populate( while not story_iterator.is_ended(): story_iterator.insert() t = story_iterator.table_name() + if t is None: + raise Exception("Internal error") row_counts[t] = row_counts.get(t, 0) + 1 story_iterator.next() diff --git a/datafaker/templates/df.py.j2 b/datafaker/templates/df.py.j2 index 28c9582..87c84e4 100644 --- a/datafaker/templates/df.py.j2 +++ b/datafaker/templates/df.py.j2 @@ -55,7 +55,7 @@ class {{ table_data.class_name }}(TableGenerator): def __init__(self): self.initialized = False - def __call__(self, dst_db_conn, get_random): + def __call__(self, dst_db_conn): if not self.initialized: {% for constraint in table_data.unique_constraints %} query_text = f"SELECT {% diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 8bfb7bc..5116377 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -1623,8 +1623,8 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): self.assertAlmostEqual( fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 ) - self.assertAlmostEqual(fish.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fish.x_var(), 1.24, delta=1.5) + self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) stmt = select(self.metadata.tables[table2_name]) rows = conn.execute(stmt).fetchall() firsts = Stat() From e63c8bc24d80c6195547c7109e6e9313b38df3a0 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 3 Oct 2025 23:51:48 +0100 Subject: [PATCH 04/35] Some mypy fixes to generator.py --- datafaker/generators.py | 83 +++++++++++++++++++++-------------------- 1 file changed, 43 insertions(+), 40 deletions(-) diff --git a/datafaker/generators.py b/datafaker/generators.py index 1ea0a8f..0a3f8c0 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -5,12 +5,13 @@ import decimal import math import re +import typing from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass from functools import lru_cache from itertools import chain, combinations -from typing import Callable, Iterable, TypeVar +from typing import Any, Callable, Iterable, Self, TypeVar import mimesis import mimesis.locales @@ -108,18 +109,18 @@ def custom_queries(self) -> dict[str, dict[str, str]]: return {} @abstractmethod - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: """ The kwargs (summary statistics) this generator is instantiated with. """ @abstractmethod - def generate_data(self, count) -> list[any]: + def generate_data(self, count: int) -> list[Any]: """ Generate 'count' random data points for this column. 
""" - def fit(self, default=None) -> float | None: + def fit(self, default: float = None) -> float | None: """ Return a value representing how well the distribution fits the real source data. @@ -138,7 +139,7 @@ class PredefinedGenerator(Generator): AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*') - def _get_src_stats_mentioned(self, val) -> set[str]: + def _get_src_stats_mentioned(self, val: Any) -> set[str]: if not val: return set() if type(val) is str: @@ -159,8 +160,8 @@ def _get_src_stats_mentioned(self, val) -> set[str]: def __init__( self, table_name: str, - generator_object: Mapping[str, any], - config: Mapping[str, any], + generator_object: Mapping[str, Any], + config: Mapping[str, Any], ): """ Initialise a generator from a config.yaml. @@ -226,13 +227,13 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: def custom_queries(self) -> dict[str, dict[str, str]]: return self._custom_queries - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: # Run the queries from nominal_kwargs # ... logger.error("PredefinedGenerator.actual_kwargs not implemented yet") return {} - def generate_data(self, count) -> list[any]: + def generate_data(self, count: int) -> list[Any]: # Call the function if we can. This could be tricky... # ... logger.error("PredefinedGenerator.generate_data not implemented yet") @@ -286,7 +287,9 @@ def __init__( self.stddev = stddev @classmethod - def make_buckets(_cls, engine: Engine, table_name: str, column_name: str): + def make_buckets( + _cls, engine: Engine, table_name: str, column_name: str + ) -> Self | None: """ Construct a Buckets object. @@ -391,7 +394,7 @@ class MimesisGenerator(MimesisGeneratorBase): def __init__( self, function_name: str, - value_fn: Callable[[any], float] | None = None, + value_fn: Callable[[Any], float] | None = None, buckets: Buckets | None = None, ): """ @@ -430,16 +433,16 @@ def __init__( self, function_name: str, length: int, - value_fn: Callable[[any], float] | None = None, + value_fn: Callable[[Any], float] | None = None, buckets: Buckets | None = None, ): self._length = length super().__init__(function_name, value_fn, buckets) - def function_name(self): + def function_name(self) -> str: return "dist_gen.truncated_string" - def name(self): + def name(self) -> str: return f"{self._name} [truncated to {self._length}]" def nominal_kwargs(self): @@ -998,7 +1001,7 @@ def __init__( self._annotation = "sampled and suppressed" @abstractmethod - def get_estimated_counts(counts): + def get_estimated_counts(self, counts): """ The counts that we would expect if this distribution was the correct one. 
""" @@ -1008,7 +1011,7 @@ def nominal_kwargs(self): "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', } - def name(self): + def name(self) -> str: n = super().name() if self._annotation is None: return n @@ -1029,24 +1032,24 @@ def custom_queries(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default=None): + def fit(self, default=None) -> float | None: return default if self._fit is None else self._fit class ZipfChoiceGenerator(ChoiceGenerator): - def get_estimated_counts(self, counts): + def get_estimated_counts(self, counts: list[int]) -> list[int]: return list(zipf_distribution(sum(counts), len(counts))) - def function_name(self): + def function_name(self) -> str: return "dist_gen.zipf_choice" - def generate_data(self, count): + def generate_data(self, count: int) -> list[float]: return [ dist_gen.zipf_choice(self.values, len(self.values)) for _ in range(count) ] -def uniform_distribution(total, bins): +def uniform_distribution(total, bins: int) -> typing.Generator[int, None, None]: p = total // bins n = total % bins for _ in range(0, n): @@ -1108,7 +1111,7 @@ def get_generators(self, columns: list[Column], engine: Engine): values = [] # The values found counts = [] # The number or each value cvs: list[ - dict[str, any] + dict[str, Any] ] = [] # list of dicts with keys "v" and "count" for result in results: c = result.f @@ -1138,14 +1141,14 @@ def get_generators(self, columns: list[Column], engine: Engine): values = [] # All values found counts = [] # The number or each value cvs: list[ - dict[str, any] + dict[str, Any] ] = [] # list of dicts with keys "v" and "count" values_not_suppressed = ( [] ) # All values found more than SUPPRESS_COUNT times counts_not_suppressed = [] # The number for each value not suppressed cvs_not_suppressed: list[ - dict[str, any] + dict[str, Any] ] = [] # list of dicts with keys "v" and "count" for result in results: c = result.f @@ -1229,10 +1232,10 @@ def function_name(self) -> str: def nominal_kwargs(self) -> dict[str, str]: return {"value": self.repr} - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: return {"value": self.value} - def generate_data(self, count) -> list[any]: + def generate_data(self, count) -> list[Any]: return [self.value for _ in range(count)] @@ -1289,13 +1292,13 @@ def custom_queries(self): } } - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: """ The kwargs (summary statistics) this generator is instantiated with. """ return {"cov": self._covariates} - def generate_data(self, count) -> list[any]: + def generate_data(self, count) -> list[Any]: """ Generate 'count' random data points for this column. 
""" @@ -1372,7 +1375,7 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1443,9 +1446,9 @@ class RowPartition: # list of included column values (so once the generator has # been run and the included_choice values have been # added): {index: value} - constant_outputs: dict[int, any] + constant_outputs: dict[int, Any] # The actual covariates from the source database - covariates: dict[str, float] + covariates: list[dict[str, float]] def comment(self) -> str: caveat = "" @@ -1506,10 +1509,10 @@ def __init__( else: self._name = f"null-partitioned {function_name}" - def name(self): + def name(self) -> str: return self._name - def function_name(self): + def function_name(self) -> str: return "dist_gen.alternatives" def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition): @@ -1599,7 +1602,7 @@ def _actual_kwargs_with_combinations(self, partition: RowPartition): }, } - def actual_kwargs(self) -> dict[str, any]: + def actual_kwargs(self) -> dict[str, Any]: """ The kwargs (summary statistics) this generator is instantiated with. """ @@ -1611,7 +1614,7 @@ def actual_kwargs(self) -> dict[str, any]: "counts": self._partition_counts, } - def generate_data(self, count) -> list[any]: + def generate_data(self, count) -> list[Any]: """ Generate 'count' random data points for this column. """ @@ -1630,7 +1633,7 @@ def is_numeric(col: Column) -> bool: T = TypeVar("T") -def powerset(input: Iterable[T]) -> Iterable[Iterable[T]]: +def powerset(input: list[T]) -> Iterable[Iterable[T]]: """Returns a list of all sublists of""" return chain.from_iterable(combinations(input, n) for n in range(len(input) + 1)) @@ -1736,7 +1739,7 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) @@ -1784,7 +1787,7 @@ def get_generators(self, columns: list[Column], engine: Engine): partition_def.nones, {}, ) - gens = [] + gens: list[Generator] = [] try: with engine.connect() as connection: partition_query_max = self.get_partition_count_query( @@ -1834,7 +1837,7 @@ def _execute_partition_queries( self, connection: Connection, partitions: dict[int, RowPartition], - ): + ) -> bool: """ Execute the query in each partition, filling in the covariates. :return: True if all the partitions work, False if any of them fail. 
@@ -1864,7 +1867,7 @@ def query_var(self, column: str) -> str: @lru_cache(1) -def everything_factory(): +def everything_factory() -> GeneratorFactory: return MultiGeneratorFactory( [ MimesisStringGeneratorFactory(), From f28d7b9bd752f7e48d87da14c37190fb138a1929 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Sun, 5 Oct 2025 23:38:28 +0100 Subject: [PATCH 05/35] Fixed variances in tests --- tests/examples/eav.sql | 4 ++-- tests/test_interactive.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/examples/eav.sql b/tests/examples/eav.sql index fe5879b..c5af320 100644 --- a/tests/examples/eav.sql +++ b/tests/examples/eav.sql @@ -22,8 +22,8 @@ INSERT INTO public.measurement_type VALUES (5, 'matter'); CREATE TABLE public.measurement ( id INTEGER NOT NULL, type INTEGER NOT NULL, - first_value INTEGER, - second_value INTEGER, + first_value FLOAT, + second_value FLOAT, third_value TEXT ); diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 5116377..284e04d 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -1490,32 +1490,32 @@ def test_create_with_null_partitioned_grouped_multivariate(self): two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 ) self.assertAlmostEqual(two.x_mean(), 1.4, delta=0.6) - self.assertAlmostEqual(two.x_var(), 0.21, delta=0.4) + self.assertAlmostEqual(two.x_var(), 0.315, delta=0.18) self.assertAlmostEqual(two.y_mean(), 1.8, delta=0.8) - self.assertAlmostEqual(two.y_var(), 0.07, delta=0.1) - self.assertAlmostEqual(two.covar(), 0.5, delta=0.5) + self.assertAlmostEqual(two.y_var(), 0.105, delta=0.06) + self.assertAlmostEqual(two.covar(), 0.105, delta=0.07) # type 3 self.assertAlmostEqual( three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) - self.assertAlmostEqual(two.covar(), -0.5, delta=0.5) + self.assertAlmostEqual(three.covar(), -2.085, delta=1.1) # type 4 self.assertAlmostEqual( four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) - self.assertAlmostEqual(two.covar(), 0.5, delta=0.5) + self.assertAlmostEqual(four.covar(), 3.33, delta=1) # type 5/fish self.assertAlmostEqual( fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) + self.assertAlmostEqual(fish.x_var(), 0.855, delta=0.6) # type 5/fowl self.assertAlmostEqual( fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) + self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1) def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): """Test EAV for all columns with sampled and suppressed generation.""" @@ -1618,13 +1618,13 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 ) self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(fish.x_var(), 0.57, delta=0.8) + self.assertAlmostEqual(fish.x_var(), 0.855, delta=0.5) # type 5/fowl self.assertAlmostEqual( fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 ) self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(fowl.x_var(), 1.24, delta=1.5) + self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1) stmt = select(self.metadata.tables[table2_name]) rows = conn.execute(stmt).fetchall() firsts = Stat() From 
20ff9eddebc84f6df5f23de941471e3fb7e06107 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 6 Oct 2025 16:59:55 +0100 Subject: [PATCH 06/35] Mypy fixes in generators.py --- datafaker/generators.py | 379 +++++++++++++++++++++++----------------- 1 file changed, 216 insertions(+), 163 deletions(-) diff --git a/datafaker/generators.py b/datafaker/generators.py index 0a3f8c0..dc1fdc0 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -11,17 +11,20 @@ from dataclasses import dataclass from functools import lru_cache from itertools import chain, combinations -from typing import Any, Callable, Iterable, Self, TypeVar +from typing import Any, Callable, Iterable, Sequence, TypeVar, Union +from typing_extensions import Self import mimesis import mimesis.locales import sqlalchemy -from sqlalchemy import Column, Connection, Engine, RowMapping, Sequence, text -from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time +from sqlalchemy import Column, Connection, CursorResult, Engine, RowMapping, text +from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time, TypeEngine from datafaker.base import DistributionGenerator from datafaker.utils import logger +numeric = Union[int, float] + # How many distinct values can we have before we consider a # choice distribution to be infeasible? MAXIMUM_CHOICES = 500 @@ -120,7 +123,7 @@ def generate_data(self, count: int) -> list[Any]: Generate 'count' random data points for this column. """ - def fit(self, default: float = None) -> float | None: + def fit(self, default: float | None=None) -> float | None: """ Return a value representing how well the distribution fits the real source data. @@ -179,7 +182,7 @@ def __init__( self._src_stats_mentioned = self._get_src_stats_mentioned(self._kwn) # Need to deal with this somehow (or remove it from the schema) self._argn: list[str] = generator_object.get("args", []) - self._select_aggregate_clauses = {} + self._select_aggregate_clauses: dict[str, dict[str, str | Any]] = {} self._custom_queries = {} for sstat in config.get("src-stats", []): name: str = sstat["name"] @@ -246,7 +249,7 @@ class GeneratorFactory(ABC): """ @abstractmethod - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: """ Returns all the generators that might be appropriate for this column. """ @@ -278,7 +281,7 @@ def __init__( ) ) ) - self.buckets = [0] * 10 + self.buckets: Sequence[int] = [0] * 10 for rb in raw_buckets: if rb.b is not None: bucket = min(9, max(0, int(rb.b) + 1)) @@ -288,7 +291,7 @@ def __init__( @classmethod def make_buckets( - _cls, engine: Engine, table_name: str, column_name: str + cls, engine: Engine, table_name: str, column_name: str ) -> Self | None: """ Construct a Buckets object. 
@@ -308,23 +311,23 @@ def make_buckets( ) ) ).first() - if result is None or result.stddev is None or result.count < 2: + if result is None or result.stddev is None or getattr(result, "count") < 2: return None try: - buckets = Buckets( + buckets = cls( engine, table_name, column_name, result.mean, result.stddev, - result.count, + getattr(result, "count"), ) except sqlalchemy.exc.DatabaseError as exc: logger.debug("Failed to instantiate Buckets object: %s", exc) return None return buckets - def fit_from_counts(self, bucket_counts: list[float]) -> float: + def fit_from_counts(self, bucket_counts: Sequence[float]) -> float: """ Figure out the fit from bucket counts from the generator distribution. """ @@ -350,7 +353,7 @@ def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: return [ generator for factory in self.factories @@ -383,10 +386,10 @@ def __init__( self._name = "generic." + function_name self._generator_function = f - def function_name(self): + def function_name(self) -> str: return self._name - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [self._generator_function() for _ in range(count)] @@ -415,16 +418,16 @@ def __init__( samples = [value_fn(s) for s in samples] self._fit = buckets.fit_from_values(samples) - def function_name(self): + def function_name(self) -> str: return self._name - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return {} - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: return {} - def fit(self, default=None): + def fit(self, default: float | None=None) -> float | None: return default if self._fit is None else self._fit @@ -445,21 +448,21 @@ def function_name(self) -> str: def name(self) -> str: return f"{self._name} [truncated to {self._length}]" - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "subgen_fn": self._name, "params": {}, "length": self._length, } - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: return { "subgen_fn": self._name, "params": {}, "length": self._length, } - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [self._generator_function()[: self._length] for _ in range(count)] @@ -472,7 +475,7 @@ def __init__( max_year: str, start: int, end: int, - ): + ) -> None: """ :param column: The column to generate into :param function_name: The name of the mimesis function @@ -489,7 +492,7 @@ def __init__( self._end = end @classmethod - def make_singleton(_cls, column: Column, engine: Engine, function_name: str): + def make_singleton(_cls, column: Column, engine: Engine, function_name: str) -> list[Generator]: extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" max_year = f"MAX({extract_year})" min_year = f"MIN({extract_year})" @@ -512,13 +515,13 @@ def make_singleton(_cls, column: Column, engine: Engine, function_name: str): ) ] - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "start": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__start"]', "end": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__end"]', } - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: return { "start": self._start, "end": self._end, @@ 
-536,14 +539,14 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [ self._generator_function(start=self._start, end=self._end) for _ in range(count) ] -def get_column_type(column: Column): +def get_column_type(column: Column) -> TypeEngine: try: return column.type.as_generic() except NotImplementedError: @@ -589,7 +592,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.word", ] - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -632,7 +635,7 @@ class MimesisFloatGeneratorFactory(GeneratorFactory): All Mimesis generators that return floating point numbers. """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -653,7 +656,7 @@ class MimesisDateGeneratorFactory(GeneratorFactory): All Mimesis generators that return dates. """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -668,7 +671,7 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return datetimes. """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -685,7 +688,7 @@ class MimesisTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return times. """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -700,7 +703,7 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): All Mimesis generators that return integers. 
""" - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -710,26 +713,28 @@ def get_generators(self, columns: list[Column], engine: Engine): return [MimesisGenerator("person.weight")] -def fit_from_buckets(xs: list[float], ys: list[float]): +def fit_from_buckets(xs: Sequence[numeric], ys: Sequence[numeric]) -> float: sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) return sum_diff_squared / (count * count) class ContinuousDistributionGenerator(Generator): + expected_buckets: Sequence[numeric] = [] + def __init__(self, table_name: str, column_name: str, buckets: Buckets): super().__init__() self.table_name = table_name self.column_name = column_name self.buckets = buckets - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "mean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["mean__{self.column_name}"]', "sd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["stddev__{self.column_name}"]', } - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: if self.buckets is None: return {} return { @@ -751,7 +756,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default=None): + def fit(self, default: float | None=None) -> float | None: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -771,10 +776,10 @@ class GaussianGenerator(ContinuousDistributionGenerator): 0.0227, ] - def function_name(self): + def function_name(self) -> str: return "dist_gen.normal" - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [ dist_gen.normal(self.buckets.mean, self.buckets.stddev) for _ in range(count) @@ -795,10 +800,10 @@ class UniformGenerator(ContinuousDistributionGenerator): 0, ] - def function_name(self): + def function_name(self) -> str: return "dist_gen.uniform_ms" - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [ dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) for _ in range(count) @@ -822,7 +827,7 @@ def _get_generators_from_buckets( UniformGenerator(table_name, column_name, buckets), ] - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -869,19 +874,19 @@ def __init__( self.logmean = logmean self.logstddev = logstddev - def function_name(self): + def function_name(self) -> str: return "dist_gen.lognormal" - def generate_data(self, count): + def generate_data(self, count: int) -> list[Any]: return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "logmean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logmean__{self.column_name}"]', "logsd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logstddev__{self.column_name}"]', } - def actual_kwargs(self): + def actual_kwargs(self) -> dict[str, Any]: return { "logmean": self.logmean, "logsd": self.logstddev, @@ -901,7 +906,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default=None): + def fit(self, default: float | None=None) -> float | None: if self.buckets is None: return default return 
self.buckets.fit_from_counts(self.expected_buckets)
@@ -941,7 +946,15 @@ def _get_generators_from_buckets(
     ]
 
-def zipf_distribution(total, bins):
+def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None]:
+    """
+    Get a zipf distribution for a certain number of items distributed
+    in a certain number of bins.
+    :param total: The total number of items to be distributed.
+    :param bins: The total number of bins to distribute the items into.
+    :return: A generator of the number of items in each bin, from the
+    largest to the smallest.
+    """
     basic_dist = list(map(lambda n: 1 / n, range(1, bins + 1)))
     bd_remaining = sum(basic_dist)
     for b in basic_dist:
@@ -960,13 +973,13 @@ class ChoiceGenerator(Generator):
 
     def __init__(
         self,
-        table_name,
-        column_name,
-        values,
-        counts,
-        sample_count=None,
-        suppress_count=0,
-    ):
+        table_name: str,
+        column_name: str,
+        values: list[Any],
+        counts: list[int],
+        sample_count: int | None=None,
+        suppress_count: int=0,
+    ) -> None:
         super().__init__()
         self.table_name = table_name
         self.column_name = column_name
@@ -1001,12 +1014,12 @@ def __init__(
             self._annotation = "sampled and suppressed"
 
     @abstractmethod
-    def get_estimated_counts(self, counts):
+    def get_estimated_counts(self, counts: list[int]) -> list[int]:
         """
-        The counts that we would expect if this distribution was the
+        The counts that we would expect if this distribution were the
         correct one.
         """
 
-    def nominal_kwargs(self):
+    def nominal_kwargs(self) -> dict[str, Any]:
         return {
             "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]',
         }
@@ -1017,7 +1030,7 @@ def name(self) -> str:
             return n
         return f"{n} [{self._annotation}]"
 
-    def actual_kwargs(self):
+    def actual_kwargs(self) -> dict[str, Any]:
         return {
             "a": self.values,
         }
@@ -1032,7 +1045,7 @@ def custom_queries(self) -> dict[str, dict[str, str]]:
             },
         }
 
-    def fit(self, default=None) -> float | None:
+    def fit(self, default: float | None=None) -> float | None:
         return default if self._fit is None else self._fit
 
@@ -1049,7 +1062,12 @@ def generate_data(self, count: int) -> list[float]:
         ]
 
-def uniform_distribution(total, bins: int) -> typing.Generator[int, None, None]:
+def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, None]:
+    """
+    A generator putting ``total`` items uniformly into ``bins`` bins.
+    If they don't fit exactly evenly, the earlier bins will have one more
+    item than the later bins so the total is as required.
+    """
     p = total // bins
     n = total % bins
     for _ in range(0, n):
@@ -1059,29 +1077,86 @@ def uniform_distribution(total, bins: int) -> typing.Generator[int, None, None]:
 
 class UniformChoiceGenerator(ChoiceGenerator):
-    def get_estimated_counts(self, counts):
+    """
+    A generator producing each value with roughly equal frequency.
+    """
+    def get_estimated_counts(self, counts: list[int]) -> list[int]:
         return list(uniform_distribution(sum(counts), len(counts)))
 
-    def function_name(self):
+    def function_name(self) -> str:
         return "dist_gen.choice"
 
-    def generate_data(self, count):
+    def generate_data(self, count: int) -> list[Any]:
         return [dist_gen.choice(self.values) for _ in range(count)]
 
 
 class WeightedChoiceGenerator(ChoiceGenerator):
     STORE_COUNTS = True
 
-    def get_estimated_counts(self, counts):
+    def get_estimated_counts(self, counts: list[int]) -> list[int]:
         return counts
 
-    def function_name(self):
+    def function_name(self) -> str:
         return "dist_gen.weighted_choice"
 
-    def generate_data(self, count):
+    def generate_data(self, count: int) -> list[Any]:
         return [dist_gen.weighted_choice(self.values) for _ in range(count)]
 
 
+class ValueGatherer:
+    """
+    Gathers values from a query of values and counts.
+
+    The query must return columns ``v`` for a value and ``f`` for the
+    count of how many of those values there are.
+    These values will be gathered into a number of properties:
+    ``values``: the list of ``v`` values, ``counts``: the list of ``f`` counts
+    in the same order as ``v``, ``cvs``: list of dicts with keys ``value`` and
+    ``count`` giving these values and counts. ``counts_not_suppressed``,
+    ``values_not_suppressed`` and ``cvs_not_suppressed`` are the
+    equivalents with the values whose counts are less than or equal to
+    ``suppress_count`` removed.
+
+    :param suppress_count: values with a count of this or fewer are excluded
+        from the ``*_not_suppressed`` properties.
+    """
+    def __init__(self, results: CursorResult, suppress_count: int=0) -> None:
+        values = []  # All values found
+        counts = []  # The number of each value
+        cvs: list[
+            dict[str, Any]
+        ] = []  # list of dicts with keys "value" and "count"
+        values_not_suppressed = (
+            []
+        )  # All values found more than suppress_count times
+        counts_not_suppressed = []  # The number for each value not suppressed
+        cvs_not_suppressed: list[
+            dict[str, Any]
+        ] = []  # list of dicts with keys "value" and "count"
+        for result in results:
+            c = result.f
+            if c != 0:
+                counts.append(c)
+                v = result.v
+                if type(v) is decimal.Decimal:
+                    v = float(v)
+                values.append(v)
+                cvs.append({"value": v, "count": c})
+                if suppress_count < c:
+                    counts_not_suppressed.append(c)
+                    values_not_suppressed.append(v)
+                    cvs_not_suppressed.append({"value": v, "count": c})
+        self.values = values
+        self.counts = counts
+        self.cvs = cvs
+        self.values_not_suppressed = values_not_suppressed
+        self.counts_not_suppressed = counts_not_suppressed
+        self.cvs_not_suppressed = cvs_not_suppressed
+
+
 class ChoiceGeneratorFactory(GeneratorFactory):
     """
     All generators that want an average and standard deviation.
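A self-contained sketch (not from the patch) of the suppression rule `ValueGatherer` applies: `SimpleNamespace` rows stand in for a SQLAlchemy `CursorResult`, and the hypothetical `gather` helper mirrors only the values bookkeeping, not the decimal conversion or the dict-building.

    from types import SimpleNamespace

    def gather(rows, suppress_count=0):
        # Zero-count rows are dropped outright; rows whose count is at most
        # suppress_count are additionally dropped from the *_not_suppressed list.
        values, values_not_suppressed = [], []
        for r in rows:
            if r.f != 0:
                values.append(r.v)
                if suppress_count < r.f:
                    values_not_suppressed.append(r.v)
        return values, values_not_suppressed

    rows = [SimpleNamespace(v="a", f=10), SimpleNamespace(v="b", f=3), SimpleNamespace(v="c", f=0)]
    assert gather(rows, suppress_count=5) == (["a", "b"], ["a"])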
@@ -1090,7 +1165,7 @@ class ChoiceGeneratorFactory(GeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1108,25 +1183,12 @@ def get_generators(self, columns: list[Column], engine: Engine): ) ) if results is not None and results.rowcount <= MAXIMUM_CHOICES: - values = [] # The values found - counts = [] # The number or each value - cvs: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - for result in results: - c = result.f - if c != 0: - counts.append(c) - v = result.v - if type(v) is decimal.Decimal: - v = float(v) - values.append(v) - cvs.append({"value": v, "count": c}) - if counts: + vg = ValueGatherer(results, self.SUPPRESS_COUNT) + if vg.counts: generators += [ - ZipfChoiceGenerator(table_name, column_name, values, counts), - UniformChoiceGenerator(table_name, column_name, values, counts), - WeightedChoiceGenerator(table_name, column_name, cvs, counts), + ZipfChoiceGenerator(table_name, column_name, vg.values, vg.counts), + UniformChoiceGenerator(table_name, column_name, vg.values, vg.counts), + WeightedChoiceGenerator(table_name, column_name, vg.cvs, vg.counts), ] results = connection.execute( text( @@ -1138,81 +1200,59 @@ def get_generators(self, columns: list[Column], engine: Engine): ) ) if results is not None: - values = [] # All values found - counts = [] # The number or each value - cvs: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - values_not_suppressed = ( - [] - ) # All values found more than SUPPRESS_COUNT times - counts_not_suppressed = [] # The number for each value not suppressed - cvs_not_suppressed: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - for result in results: - c = result.f - if c != 0: - counts.append(c) - v = result.v - if type(v) is decimal.Decimal: - v = float(v) - values.append(v) - cvs.append({"value": v, "count": c}) - if self.SUPPRESS_COUNT < c: - counts_not_suppressed.append(c) - v = result.v - if type(v) is decimal.Decimal: - v = float(v) - values_not_suppressed.append(v) - cvs_not_suppressed.append({"value": v, "count": c}) - if counts: + vg = ValueGatherer(results, self.SUPPRESS_COUNT) + if vg.counts: + generators += [ + ZipfChoiceGenerator(table_name, column_name, vg.values, vg.counts), + UniformChoiceGenerator(table_name, column_name, vg.values, vg.counts), + WeightedChoiceGenerator(table_name, column_name, vg.cvs, vg.counts), + ] generators += [ ZipfChoiceGenerator( table_name, column_name, - values, - counts, + vg.values, + vg.counts, sample_count=self.SAMPLE_COUNT, ), UniformChoiceGenerator( table_name, column_name, - values, - counts, + vg.values, + vg.counts, sample_count=self.SAMPLE_COUNT, ), WeightedChoiceGenerator( table_name, column_name, - cvs, - counts, + vg.cvs, + vg.counts, sample_count=self.SAMPLE_COUNT, ), ] - if counts_not_suppressed: + if vg.counts_not_suppressed: generators += [ ZipfChoiceGenerator( table_name, column_name, - values_not_suppressed, - counts_not_suppressed, + vg.values_not_suppressed, + vg.counts_not_suppressed, sample_count=self.SAMPLE_COUNT, suppress_count=self.SUPPRESS_COUNT, ), UniformChoiceGenerator( table_name, column_name, - values_not_suppressed, - counts_not_suppressed, + vg.values_not_suppressed, + vg.counts_not_suppressed, sample_count=self.SAMPLE_COUNT, suppress_count=self.SUPPRESS_COUNT, ), 
WeightedChoiceGenerator( table_name=table_name, column_name=column_name, - values=cvs_not_suppressed, - counts=counts, + values=vg.cvs_not_suppressed, + counts=vg.counts_not_suppressed, sample_count=self.SAMPLE_COUNT, suppress_count=self.SUPPRESS_COUNT, ), @@ -1221,7 +1261,7 @@ def get_generators(self, columns: list[Column], engine: Engine): class ConstantGenerator(Generator): - def __init__(self, value): + def __init__(self, value: Any) -> None: super().__init__() self.value = value self.repr = repr(value) @@ -1235,7 +1275,7 @@ def nominal_kwargs(self) -> dict[str, str]: def actual_kwargs(self) -> dict[str, Any]: return {"value": self.value} - def generate_data(self, count) -> list[Any]: + def generate_data(self, count: int) -> list[Any]: return [self.value for _ in range(count)] @@ -1244,7 +1284,7 @@ class ConstantGeneratorFactory(GeneratorFactory): Just the null generator """ - def get_generators(self, columns: list[Column], engine: Engine): + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1261,29 +1301,32 @@ def get_generators(self, columns: list[Column], engine: Engine): class MultivariateNormalGenerator(Generator): + """ + Generator of multiple values drawn from a multivariate normal distribution. + """ def __init__( self, - table_name: list[str], + table_name: str, column_names: list[str], query: str, - covariates: dict[str, float], + covariates: RowMapping, function_name: str, - ): + ) -> None: self._table = table_name self._columns = column_names self._query = query self._covariates = covariates self._function_name = function_name - def function_name(self): + def function_name(self) -> str: return "dist_gen." + self._function_name - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', } - def custom_queries(self): + def custom_queries(self) -> dict[str, Any]: cols = ", ".join(self._columns) return { f"auto__cov__{self._table}": { @@ -1298,7 +1341,7 @@ def actual_kwargs(self) -> dict[str, Any]: """ return {"cov": self._covariates} - def generate_data(self, count) -> list[Any]: + def generate_data(self, count: int) -> list[Any]: """ Generate 'count' random data points for this column. """ @@ -1307,7 +1350,7 @@ def generate_data(self, count) -> list[Any]: for _ in range(count) ] - def fit(self, default=None) -> float | None: + def fit(self, default: float | None=None) -> float | None: return default @@ -1375,7 +1418,7 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1417,17 +1460,23 @@ def query_var(self, column: str) -> str: return f"LN({column})" -def text_list(items: list[str]) -> str: +def text_list(items: Iterable[str]) -> str: """ Concatenate the items with commas and one "and". 
""" - if not hasattr(items, "__getitem__"): - items = list(items) - if len(items) == 0: + item_i = iter(items) + try: + last_item = next(item_i) + except StopIteration: return "" - if len(items) == 1: - return items[0] - return ", ".join(items[:-1]) + " and " + items[-1] + try: + so_far = next(item_i) + except StopIteration: + return last_item + for item in item_i: + so_far += ", " + last_item + last_item = item + return so_far + " and " + last_item @dataclass @@ -1448,7 +1497,7 @@ class RowPartition: # added): {index: value} constant_outputs: dict[int, Any] # The actual covariates from the source database - covariates: list[dict[str, float]] + covariates: Sequence[RowMapping] def comment(self) -> str: caveat = "" @@ -1495,7 +1544,7 @@ def __init__( function_name: str = "grouped_multivariate_lognormal", name_suffix: str | None = None, partition_count_query: str | None = None, - partition_counts: Sequence[RowMapping] | None = None, + partition_counts: Iterable[RowMapping] = [], partition_count_comment: str | None = None, ): self._query_name = query_name @@ -1515,7 +1564,7 @@ def name(self) -> str: def function_name(self) -> str: return "dist_gen.alternatives" - def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition): + def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition) -> dict[str, Any]: count = f'sum(r["count"] for r in SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' if not partition.included_numeric and not partition.included_choice: return { @@ -1542,12 +1591,10 @@ def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition) }, } - def _count_query_name(self): - if self._partition_count_query: - return f"auto__cov__{self._query_name}__counts" - return None + def _count_query_name(self) -> str: + return f"auto__cov__{self._query_name}__counts" - def nominal_kwargs(self): + def nominal_kwargs(self) -> dict[str, Any]: return { "alternative_configs": [ self._nominal_kwargs_with_combinations(index, self._partitions[index]) @@ -1556,7 +1603,7 @@ def nominal_kwargs(self): "counts": f'SRC_STATS["{self._count_query_name()}"]["results"]', } - def custom_queries(self): + def custom_queries(self) -> dict[str, Any]: partitions = { f"auto__cov__{self._query_name}__alt_{index}": { "comment": partition.comment(), @@ -1574,7 +1621,7 @@ def custom_queries(self): **partitions, } - def _actual_kwargs_with_combinations(self, partition: RowPartition): + def _actual_kwargs_with_combinations(self, partition: RowPartition) -> dict[str, Any]: count = sum(row["count"] for row in partition.covariates) if not partition.included_numeric and not partition.included_choice: return { @@ -1614,14 +1661,14 @@ def actual_kwargs(self) -> dict[str, Any]: "counts": self._partition_counts, } - def generate_data(self, count) -> list[Any]: + def generate_data(self, count: int) -> list[Any]: """ Generate 'count' random data points for this column. 
""" kwargs = self.actual_kwargs() return [dist_gen.alternatives(**kwargs) for _ in range(count)] - def fit(self, default=None) -> float | None: + def fit(self, default: float | None=None) -> float | None: return default @@ -1690,6 +1737,12 @@ def __init__( class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 + EMPTY_RESULT = [RowMapping( + parent=sqlalchemy.engine.result.ResultMetaData(), + processors=None, + key_to_index={"count": 0}, + data=(0,) + )] def function_name(self) -> str: return "grouped_multivariate_normal" @@ -1739,7 +1792,7 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) @@ -1767,7 +1820,7 @@ def get_generators(self, columns: list[Column], engine: Engine) -> list[Generato partition_def.included_choice, partition_def.excluded, partition_def.nones, - {}, + [], ) query = self.query( table=table, @@ -1785,7 +1838,7 @@ def get_generators(self, columns: list[Column], engine: Engine) -> list[Generato partition_def.included_choice, partition_def.excluded, partition_def.nones, - {}, + [], ) gens: list[Generator] = [] try: @@ -1846,7 +1899,7 @@ def _execute_partition_queries( for rp in partitions.values(): rp.covariates = connection.execute(text(rp.query)).mappings().fetchall() if not rp.covariates or rp.covariates[0]["count"] is None: - rp.covariates = [{"count": 0}] + rp.covariates = self.EMPTY_RESULT else: found_nonzero = True return found_nonzero From 49954de69cea1fb005369413631919320e1f9e9c Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 6 Oct 2025 18:59:32 +0100 Subject: [PATCH 07/35] mypy fixes in interactive.py --- datafaker/base.py | 20 ++- datafaker/generators.py | 40 ++--- datafaker/interactive.py | 336 +++++++++++++++++++++++++-------------- datafaker/utils.py | 4 +- 4 files changed, 250 insertions(+), 150 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 0a3cf41..acdc9b9 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -8,7 +8,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Generator, TypeVar +from typing import Any, Callable, Generator import numpy as np import yaml @@ -18,22 +18,20 @@ from datafaker.utils import ( MAKE_VOCAB_PROGRESS_REPORT_EVERY, + T, logger, stream_yaml, table_row_count, ) -_T = TypeVar("_T") - - @functools.cache def zipf_weights(size: int) -> list[float]: total = sum(map(lambda n: 1 / n, range(1, size + 1))) return [1 / (n * total) for n in range(1, size + 1)] -def merge_with_constants(xs: list[_T], constants_at: dict[int, _T]) -> Generator[_T, None, None]: +def merge_with_constants(xs: list[T], constants_at: dict[int, T]) -> Generator[T, None, None]: """ Merge a list of items with other items that must be placed at certain indices. 
:param constants_at: A map of indices to objects that must be placed at @@ -86,11 +84,11 @@ def normal(self, mean: float, sd: float) -> float: def lognormal(self, logmean: float, logsd: float) -> float: return random.lognormvariate(float(logmean), float(logsd)) - def choice(self, a: list[_T]) -> _T: + def choice(self, a: list[T]) -> T: c = random.choice(a) return c["value"] if type(c) is dict and "value" in c else c - def zipf_choice(self, a: list[_T], n: int | None=None) -> _T: + def zipf_choice(self, a: list[T], n: int | None=None) -> T: if n is None: n = len(a) c = random.choices(a, weights=zipf_weights(n))[0] @@ -113,7 +111,7 @@ def weighted_choice(self, a: list[dict[str, Any]]) -> Any: c = random.choices(vs, weights=counts)[0] return c - def constant(self, value: _T) -> _T: + def constant(self, value: T) -> T: return value def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray: @@ -260,8 +258,8 @@ def alternatives( return getattr(self, name)(**alt["params"]) def with_constants_at( - self, constants_at: dict[int, _T], subgen: str, params: dict[str, _T] - ) -> list[_T]: + self, constants_at: dict[int, T], subgen: str, params: dict[str, T] + ) -> list[T]: if subgen not in self.PERMITTED_SUBGENS: logger.error( "subgenerator %s is not a valid name. Valid names are %s.", @@ -272,7 +270,7 @@ def with_constants_at( logger.debug("Merging constants %s", constants_at) return list(merge_with_constants(subout, constants_at)) - def truncated_string(self, subgen_fn: Callable[..., list[_T]], params: dict, length: int) -> list[_T]: + def truncated_string(self, subgen_fn: Callable[..., list[T]], params: dict, length: int) -> list[T]: """Calls ``subgen_fn(**params)`` and truncates the results to ``length``.""" result = subgen_fn(**params) if result is None: diff --git a/datafaker/generators.py b/datafaker/generators.py index dc1fdc0..57a3c64 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -123,7 +123,7 @@ def generate_data(self, count: int) -> list[Any]: Generate 'count' random data points for this column. """ - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: """ Return a value representing how well the distribution fits the real source data. @@ -249,7 +249,7 @@ class GeneratorFactory(ABC): """ @abstractmethod - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: """ Returns all the generators that might be appropriate for this column. 
""" @@ -353,7 +353,7 @@ def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: return [ generator for factory in self.factories @@ -427,7 +427,7 @@ def nominal_kwargs(self) -> dict[str, Any]: def actual_kwargs(self) -> dict[str, Any]: return {} - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: return default if self._fit is None else self._fit @@ -592,7 +592,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.word", ] - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -635,7 +635,7 @@ class MimesisFloatGeneratorFactory(GeneratorFactory): All Mimesis generators that return floating point numbers. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -656,7 +656,7 @@ class MimesisDateGeneratorFactory(GeneratorFactory): All Mimesis generators that return dates. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -671,7 +671,7 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return datetimes. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -688,7 +688,7 @@ class MimesisTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return times. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -703,7 +703,7 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): All Mimesis generators that return integers. 
""" - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -756,7 +756,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -827,7 +827,7 @@ def _get_generators_from_buckets( UniformGenerator(table_name, column_name, buckets), ] - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -906,7 +906,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -1045,7 +1045,7 @@ def custom_queries(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: return default if self._fit is None else self._fit @@ -1165,7 +1165,7 @@ class ChoiceGeneratorFactory(GeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1284,7 +1284,7 @@ class ConstantGeneratorFactory(GeneratorFactory): Just the null generator """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1350,7 +1350,7 @@ def generate_data(self, count: int) -> list[Any]: for _ in range(count) ] - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: return default @@ -1418,7 +1418,7 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1668,7 +1668,7 @@ def generate_data(self, count: int) -> list[Any]: kwargs = self.actual_kwargs() return [dist_gen.alternatives(**kwargs) for _ in range(count)] - def fit(self, default: float | None=None) -> float | None: + def fit(self, default: float=-1) -> float: return default @@ -1792,7 +1792,7 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 
4f5e8c5..ee9b940 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -3,10 +3,12 @@ import functools import re from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Collection, Mapping from dataclasses import dataclass from enum import Enum from pathlib import Path +from typing import Any, Callable, Iterable, cast +from typing_extensions import Self import sqlalchemy from prettytable import PrettyTable @@ -108,12 +110,12 @@ def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry: ... def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] ): super().__init__() - self.config = config + self.config: Mapping[str, Any] = config self.metadata = metadata - self.table_entries: list[TableEntry] = [] + self._table_entries: Collection[TableEntry] = [] tables_config: Mapping = config.get("tables", {}) if type(tables_config) is not dict: tables_config = {} @@ -127,16 +129,16 @@ def __init__( self.table_index = 0 self.engine = create_db_engine(src_dsn, schema_name=src_schema) - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.engine.dispose() - def print(self, text: str, *args, **kwargs): + def print(self, text: str, *args, **kwargs) -> None: print(text.format(*args, **kwargs)) - def print_table(self, headings: list[str], rows: list[list[str]]): + def print_table(self, headings: list[str], rows: list[list[str]]) -> None: output = PrettyTable() output.field_names = headings for row in rows: @@ -159,7 +161,7 @@ def ask_save(self): return ask.result def set_table_index(self, index) -> bool: - if 0 <= index and index < len(self.table_entries): + if 0 <= index and index < len(self._table_entries): self.table_index = index self.set_prompt() return True @@ -172,7 +174,7 @@ def next_table(self, report="No more tables"): return True def table_name(self): - return self.table_entries[self.table_index].name + return self._table_entries[self.table_index].name def table_metadata(self) -> Table: return self.metadata.tables[self.table_name()] @@ -195,21 +197,21 @@ def report_columns(self): ], ) - def get_table_config(self, table_name: str) -> dict[str, any]: + def get_table_config(self, table_name: str) -> dict[str, Any]: ts = self.config.get("tables", None) if type(ts) is not dict: return {} t = ts.get(table_name) return t if type(t) is dict else {} - def set_table_config(self, table_name: str, config: dict[str, any]): + def set_table_config(self, table_name: str, config: dict[str, Any]): ts = self.config.get("tables", None) if type(ts) is not dict: self.config["tables"] = {table_name: config} return ts[table_name] = config - def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, any]]: + def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, Any]]: src_stats = self.config.get("src-stats", []) new_src_stats = [] for stat in src_stats: @@ -218,7 +220,7 @@ def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, any]]: self.config["src-stats"] = new_src_stats return new_src_stats - def get_nonnull_columns(self, table_name: str): + def get_nonnull_columns(self, table_name: str) -> list[str]: metadata_table = self.metadata.tables[table_name] return [ str(name) @@ -230,21 +232,21 @@ def find_entry_index_by_table_name(self, table_name) -> int | None: return next( ( i 
-                for i, entry in enumerate(self.table_entries)
+                for i, entry in enumerate(self._table_entries)
                 if entry.name == table_name
             ),
             None,
         )
 
     def find_entry_by_table_name(self, table_name) -> TableEntry | None:
-        for e in self.table_entries:
+        for e in self._table_entries:
             if e.name == table_name:
                 return e
         return None
 
     def do_counts(self, _arg):
         "Report the column names with the counts of nulls in them"
-        if len(self.table_entries) <= self.table_index:
+        if len(self._table_entries) <= self.table_index:
             return
         table_name = self.table_name()
         nonnull_columns = self.get_nonnull_columns(table_name)
@@ -289,14 +291,14 @@ def do_select(self, arg):
             rows = [row._tuple() for row in result.fetchmany(MAX_SELECT_ROWS)]
             self.print_table(fields, rows)
 
-    def do_peek(self, arg: str):
+    def do_peek(self, arg: str) -> None:
         """
         Use 'peek col1 col2 col3' to see a sample of values from columns
         col1, col2 and col3 in the current table.
         Use 'peek' to see a sample of the current column(s).
-        Rows that are enitrely null are suppressed.
+        Rows that are entirely null are suppressed.
         """
         MAX_PEEK_ROWS = 25
-        if len(self.table_entries) <= self.table_index:
+        if len(self._table_entries) <= self.table_index:
             return
         table_name = self.table_name()
         col_names = arg.split()
@@ -319,8 +321,8 @@ def do_peek(self, arg: str):
             rows = [row._tuple() for row in result.fetchmany(MAX_PEEK_ROWS)]
             self.print_table(list(result.keys()), rows)
 
-    def complete_peek(self, text: str, _line: str, _begidx: int, _endidx: int):
-        if len(self.table_entries) <= self.table_index:
+    def complete_peek(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]:
+        if len(self._table_entries) <= self.table_index:
             return []
         return [
             col for col in self.table_metadata().columns.keys() if col.startswith(text)
@@ -369,18 +371,22 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry:
 
     def __init__(
         self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping
-    ):
+    ) -> None:
         super().__init__(src_dsn, src_schema, metadata, config)
         self.set_prompt()
 
-    def set_prompt(self):
+    @property
+    def table_entries(self) -> list[TableCmdTableEntry]:
+        return cast(list[TableCmdTableEntry], self._table_entries)
+
+    def set_prompt(self) -> None:
         if self.table_index < len(self.table_entries):
             entry = self.table_entries[self.table_index]
             self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name)
         else:
             self.prompt = "(table) "
 
-    def set_type(self, t_type: TableType):
+    def set_type(self, t_type: TableType) -> None:
         if self.table_index < len(self.table_entries):
             entry = self.table_entries[self.table_index]
             entry.new_type = t_type
@@ -428,7 +434,6 @@ def _sanity_check_failures(self) -> list[tuple[str, str, str]]:
         """Find tables that reference each other that should not given their types."""
         failures = []
         for from_entry in self.table_entries:
-            from_entry: TableCmdTableEntry
             from_t = from_entry.new_type
             if from_t == TableType.VOCABULARY:
                 referenced = self._get_referenced_tables(from_entry.name)
@@ -451,7 +456,6 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]:
         """Find tables that reference each other that might cause problems given their types."""
         warnings = []
         for from_entry in self.table_entries:
-            from_entry: TableCmdTableEntry
             from_t = from_entry.new_type
             if from_t in {TableType.GENERATE, TableType.PRIVATE}:
                 referenced = self._get_referenced_tables(from_entry.name)
@@ -470,7 +474,7 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]:
         )
         return warnings
 
-    def do_quit(self, _arg):
+    def do_quit(self, _arg: str) -> bool:
         "Check the updates, save them if desired and quit
the configurer." count = 0 for entry in self.table_entries: @@ -502,7 +506,7 @@ def do_quit(self, _arg): return True return False - def do_tables(self, _arg): + def do_tables(self, _arg: str) -> None: "list the tables with their types" for entry in self.table_entries: old = entry.old_type @@ -510,7 +514,7 @@ def do_tables(self, _arg): becomes = " " if old == new else "->" + TYPE_LETTER[new] self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) - def do_next(self, arg): + def do_next(self, arg: str) -> None: "'next' = go to the next table, 'next tablename' = go to table 'tablename'" if arg: # Find the index of the table called _arg, if any @@ -522,51 +526,51 @@ def do_next(self, arg): return self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text, line, begidx, endidx): + def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] - def do_previous(self, _arg): + def do_previous(self, _arg: str) -> None: "Go to the previous table" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) - def do_ignore(self, _arg): + def do_ignore(self, _arg: str) -> None: "Set the current table as ignored, and go to the next table" self.set_type(TableType.IGNORE) self.print("Table {} set as ignored", self.table_name()) self.next_table() - def do_vocabulary(self, _arg): + def do_vocabulary(self, _arg: str) -> None: "Set the current table as a vocabulary table, and go to the next table" self.set_type(TableType.VOCABULARY) self.print("Table {} set to be a vocabulary table", self.table_name()) self.next_table() - def do_private(self, _arg): + def do_private(self, _arg: str) -> None: "Set the current table as a primary private table (such as the table of patients)" self.set_type(TableType.PRIVATE) self.print("Table {} set to be a primary private table", self.table_name()) self.next_table() - def do_generate(self, _arg): + def do_generate(self, _arg: str) -> None: "Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table" self.set_type(TableType.GENERATE) self.print("Table {} generate", self.table_name()) self.next_table() - def do_empty(self, _arg): + def do_empty(self, _arg: str) -> None: "Set the current table as empty; no generators will be run for it" self.set_type(TableType.EMPTY) self.print("Table {} empty", self.table_name()) self.next_table() - def do_columns(self, _arg): + def do_columns(self, _arg: str) -> None: "Report the column names and metadata" self.report_columns() - def do_data(self, arg: str): + def do_data(self, arg: str) -> None: """ Report some data. 
        'data' = report a random ten lines,
@@ -606,14 +610,14 @@ def do_data(self, arg: str):
                 number = 48
         self.print_column_data(column, number, min_length)
 
-    def complete_data(self, text, line, begidx, endidx):
+    def complete_data(self, text: str, line: str, begidx: int, _endidx: int) -> list[str]:
         previous_parts = line[: begidx - 1].split()
         if len(previous_parts) != 2:
             return []
         table_metadata = self.table_metadata()
         return [k for k in table_metadata.columns.keys() if k.startswith(text)]
 
-    def print_column_data(self, column: str, count: int, min_length: int):
+    def print_column_data(self, column: str, count: int, min_length: int) -> None:
         where = f"WHERE {column} IS NOT NULL"
         if 0 < min_length:
             where = "WHERE LENGTH({column}) >= {len}".format(
@@ -633,7 +637,7 @@ def print_column_data(self, column: str, count: int, min_length: int):
             )
             self.columnize([str(x[0]) for x in result.all()])
 
-    def print_row_data(self, count: int):
+    def print_row_data(self, count: int) -> None:
         with self.engine.connect() as connection:
             result = connection.execute(
                 text(
@@ -651,7 +655,7 @@ def print_row_data(self, count: int):
 
 def update_config_tables(
     src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping
-):
+) -> Mapping[str, Any]:
     with TableCmd(src_dsn, src_schema, metadata, config) as tc:
         tc.cmdloop()
         return tc.config
@@ -671,7 +675,7 @@ class MissingnessType:
     columns: list[str]
 
     @classmethod
-    def sampled_query(cls, table, count, column_names) -> str:
+    def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str:
         result_names = ", ".join(["{0}__is_null".format(c) for c in column_names])
         column_is_nulls = ", ".join(
             ["{0} IS NULL AS {0}__is_null".format(c) for c in column_names]
@@ -713,9 +717,9 @@ def find_missingness_query(
         for src_stat in self.config["src-stats"]:
             if src_stat.get("name") == key:
                 return (src_stat.get("query", None), src_stat.get("comment", None))
-        return None
+        return None
 
-    def make_table_entry(self, name: str, table: Mapping) -> TableEntry:
+    def make_table_entry(self, name: str, table: Mapping) -> TableEntry | None:
         if table.get("ignore", False):
             return None
         if table.get("vocabulary_table", False):
@@ -737,7 +741,7 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry:
         elif len(mgs) == 1:
             mg = mgs[0]
             mg_name = mg.get("name", None)
-            if mg_name is not None:
+            if type(mg_name) is str:
                 query_comment = self.find_missingness_query(mg)
                 if query_comment is not None:
                     (query, comment) = query_comment
@@ -758,10 +762,24 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry:
     def __init__(
         self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping
     ):
+        """
+        Initialise a MissingnessCmd.
+        :param src_dsn: connection string for the source database.
+        :param src_schema: schema name for the source database.
+        :param metadata: SQLAlchemy metadata for the source database.
+        :param config: Configuration from the ``config.yaml`` file.
+        """
         super().__init__(src_dsn, src_schema, metadata, config)
         self.set_prompt()
 
-    def set_prompt(self):
+    @property
+    def table_entries(self) -> list[MissingnessCmdTableEntry]:
+        return cast(list[MissingnessCmdTableEntry], self._table_entries)
+
+    def set_prompt(self) -> None:
+        """
+        Sets the prompt according to the current table and missingness.
+ """ if self.table_index < len(self.table_entries): entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] nt = entry.new_type @@ -772,7 +790,7 @@ def set_prompt(self): else: self.prompt = "(missingness) " - def set_type(self, t_type: TableType): + def set_type(self, t_type: TableType) -> None: if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type @@ -780,7 +798,6 @@ def set_type(self, t_type: TableType): def _copy_entries(self) -> None: src_stats = self._remove_prefix_src_stats("missing_auto__") for entry in self.table_entries: - entry: MissingnessCmdTableEntry table = self.get_table_config(entry.name) if entry.new_type is None or entry.new_type.name == "none": table.pop("missingness_generators", None) @@ -808,7 +825,7 @@ def _copy_entries(self) -> None: ) self.set_table_config(entry.name, table) - def do_quit(self, _arg): + def do_quit(self, _arg: str) -> bool: "Check the updates, save them if desired and quit the configurer." count = 0 for entry in self.table_entries: @@ -843,7 +860,7 @@ def do_quit(self, _arg): return True return False - def do_tables(self, arg): + def do_tables(self, _arg: str) -> None: "list the tables with their types" for entry in self.table_entries: old = "-" if entry.old_type is None else entry.old_type.name @@ -851,7 +868,7 @@ def do_tables(self, arg): desc = new if old == new else "{0}->{1}".format(old, new) self.print("{0} {1}", entry.name, desc) - def do_next(self, arg): + def do_next(self, arg: str) -> None: "'next' = go to the next table, 'next tablename' = go to table 'tablename'" if arg: # Find the index of the table called _arg, if any @@ -866,17 +883,20 @@ def do_next(self, arg): return self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text, line, begidx, endidx): + def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] - def do_previous(self, _arg): + def do_previous(self, _arg: str) -> None: "Go to the previous table" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) - def _set_type(self, name, query, comment): + def _set_type(self, name: str, query: str, comment: str) -> None: + """ + Set the current table entry's query. + """ if len(self.table_entries) <= self.table_index: return entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] @@ -887,13 +907,15 @@ def _set_type(self, name, query, comment): columns=self.get_nonnull_columns(entry.name), ) - def _set_none(self): + def _set_none(self) -> None: + """ + Sets the current table to have no missingness applied. + """ if len(self.table_entries) <= self.table_index: return - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] - entry.new_type = None + self.table_entries[self.table_index].new_type = None - def do_sampled(self, arg: str): + def do_sampled(self, arg: str) -> None: """ Set the current table missingness as 'sampled', and go to the next table. "sampled 3000" means sample 3000 rows at random and choose the missingness @@ -903,7 +925,7 @@ def do_sampled(self, arg: str): if len(self.table_entries) <= self.table_index: self.print("Error! 
not on a table") return - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] + entry = self.table_entries[self.table_index] if arg == "": count = 1000 elif arg.isdecimal(): @@ -926,7 +948,7 @@ def do_sampled(self, arg: str): self.print("Table {} set to sampled missingness", self.table_name()) self.next_table() - def do_none(self, _arg): + def do_none(self, _arg: str) -> None: "Set the current table to have no missingness, and go to the next table" self._set_none() self.print("Table {} set to have no missingness", self.table_name()) @@ -934,8 +956,8 @@ def do_none(self, _arg): def update_missingness( - src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping -): + src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] +) -> Mapping[str, Any]: with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() return mc.config @@ -943,17 +965,28 @@ def update_missingness( @dataclass class GeneratorInfo: + """ + A generator and the columns it assigns to. + """ columns: list[str] gen: Generator | None @dataclass class GeneratorCmdTableEntry(TableEntry): + """ + List of generators set for a table. + Includes the original setting and the currently configured + generators. + """ old_generators: list[GeneratorInfo] new_generators: list[GeneratorInfo] class GeneratorCmd(DbCmd): + """ + Interactive command shell for setting generators. + """ intro = "Interactive generator configuration. Type ? for help.\n" doc_leader = """Use command 'propose' for a list of generators applicable to the current column, then command 'compare' to see how these perform @@ -1052,21 +1085,40 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping - ): + self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] + ) -> None: + """ + Initialise a GeneratorCmd + :param src_dsn: connection address for source database + :param src_schema: database schema name + :param metadata: SQLAlchemy metadata for the source database + :param config: Configuration loaded from ``config.yaml`` + """ super().__init__(src_dsn, src_schema, metadata, config) self.generator_index = 0 self.generators_valid_columns = None self.set_prompt() - def set_table_index(self, index): + @property + def table_entries(self) -> list[GeneratorCmdTableEntry]: + return cast(GeneratorCmdTableEntry, self._table_entries) + + def set_table_index(self, index: int) -> bool: + """ + Moves to a new table. + :param index: table index to move to. + """ ret = super().set_table_index(index) if ret: self.generator_index = 0 self.set_prompt() return ret - def previous_table(self): + def previous_table(self) -> bool: + """ + Move to the table before the current one. + :return: True if there is a previous table to go to. + """ ret = self.set_table_index(self.table_index - 1) if ret: table = self.get_table() @@ -1082,29 +1134,44 @@ def previous_table(self): return ret def get_table(self) -> GeneratorCmdTableEntry | None: + """ + Get the current table entry. + """ if self.table_index < len(self.table_entries): return self.table_entries[self.table_index] return None def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: + """ + Gets a pair; the table name then the generator information. 
+ """ if self.table_index < len(self.table_entries): - entry: GeneratorCmdTableEntry = self.table_entries[self.table_index] + entry = self.table_entries[self.table_index] if self.generator_index < len(entry.new_generators): return (entry.name, entry.new_generators[self.generator_index]) return (entry.name, None) return (None, None) def get_column_names(self) -> list[str]: + """ + Gets the (unqualified) names for all the current columns. + """ (_, generator_info) = self.get_table_and_generator() return generator_info.columns if generator_info else [] def column_metadata(self) -> list[Column]: + """ + Gets the metadata for all the current columns. + """ table = self.table_metadata() if table is None: return [] return [table.columns[name] for name in self.get_column_names()] - def set_prompt(self): + def set_prompt(self) -> None: + """ + Set the prompt according to the current table, column and generator. + """ (table_name, gen_info) = self.get_table_and_generator() if table_name is None: self.prompt = "(generators) " @@ -1119,13 +1186,18 @@ def set_prompt(self): gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" self.prompt = f"({table_name}.{','.join(columns)}{gen}) " - def _remove_auto_src_stats(self) -> list[dict[str, any]]: + def _remove_auto_src_stats(self) -> list[dict[str, Any]]: + """ + Remove all automatic source stats (which we assume is + every source stats query whose name begins with ``auto__`)""" return self._remove_prefix_src_stats("auto__") def _copy_entries(self) -> None: + """ + Set generator and query information in the configuration. + """ src_stats = self._remove_auto_src_stats() - tes: list[GeneratorCmdTableEntry] = self.table_entries - for entry in tes: + for entry in self.table_entries: rgs = [] new_gens: list[Generator] = [] for generator in entry.new_generators: @@ -1173,7 +1245,7 @@ def _copy_entries(self) -> None: self.config["src-stats"] = src_stats def _find_old_generator( - self, entry: GeneratorCmdTableEntry, columns + self, entry: GeneratorCmdTableEntry, columns: Iterable[list] ) -> Generator | None: """Find any generator that previously assigned to these exact same columns.""" fc = frozenset(columns) @@ -1182,12 +1254,12 @@ def _find_old_generator( return gen.gen return None - def do_quit(self, arg): + def do_quit(self, arg: str) -> bool: "Check the updates, save them if desired and quit the configurer." 
count = 0 for entry in self.table_entries: header_shown = False - g_entry: GeneratorCmdTableEntry = entry + g_entry = cast(GeneratorCmdTableEntry, entry) for gen in g_entry.new_generators: old_gen = self._find_old_generator(g_entry, gen.columns) new_gen = None if gen is None else gen.gen @@ -1215,19 +1287,20 @@ def do_quit(self, arg): return True return False - def do_tables(self, arg): + def do_tables(self, arg: str) -> None: "list the tables" - for entry in self.table_entries: + for t_entry in self.table_entries: + entry = cast(GeneratorCmdTableEntry, t_entry) gen_count = len(entry.new_generators) how_many = "one generator" if gen_count == 1 else f"{gen_count} generators" self.print("{0} ({1})", entry.name, how_many) - def do_list(self, arg): + def do_list(self, arg: str) -> None: "list the generators in the current table" if len(self.table_entries) <= self.table_index: self.print("Error: no table {0}", self.table_index) return - g_entry: GeneratorCmdTableEntry = self.table_entries[self.table_index] + g_entry = cast(GeneratorCmdTableEntry, self.table_entries[self.table_index]) table = self.table_metadata() for gen in g_entry.new_generators: old_gen = self._find_old_generator(g_entry, gen.columns) @@ -1245,11 +1318,11 @@ def do_list(self, arg): primary = "[primary-key]" self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns) - def do_columns(self, _arg): + def do_columns(self, _arg: str) -> None: "Report the column names and metadata" self.report_columns() - def do_info(self, _arg): + def do_info(self, _arg: str) -> None: "Show information about the current column" for cm in self.column_metadata(): self.print( @@ -1279,14 +1352,14 @@ def _get_table_index(self, table_name: str) -> int | None: return n return None - def _get_generator_index(self, table_index, column_name): - entry: GeneratorCmdTableEntry = self.table_entries[table_index] + def _get_generator_index(self, table_index: int, column_name: str) -> int | None: + entry = self.table_entries[table_index] for n, gen in enumerate(entry.new_generators): if column_name in gen.columns: return n return None - def go_to(self, target): + def go_to(self, target: str) -> bool: parts = target.split(".", 1) table_index = self._get_table_index(parts[0]) if table_index is None: @@ -1310,7 +1383,7 @@ def go_to(self, target): self.set_prompt() return True - def do_next(self, arg): + def do_next(self, arg: str) -> None: """ Go to the next generator. Or go to a named table: 'next tablename'. 
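The do_*/complete_* pairs in these hunks follow the standard library's cmd protocol: cmdloop() dispatches "next foo" to do_next("foo"), tab completion calls the matching complete_* method with the partial word, and a do_* handler returning True ends the loop. A minimal self-contained sketch of the pattern (the table names here are invented for illustration):

import cmd

class DemoCmd(cmd.Cmd):
    """Toy shell showing the do_*/complete_* pairing used by the configurers."""

    prompt = "(demo) "
    tables = ["person", "visit", "measurement"]  # hypothetical names

    def do_next(self, arg: str) -> None:
        """'next' or 'next tablename': go to the next (or a named) table."""
        print(f"jumping to {arg or 'the next table'}")

    def do_quit(self, _arg: str) -> bool:
        """Returning True from a do_* handler terminates cmdloop()."""
        return True

    def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]:
        # cmd passes the word under the cursor; return the matching candidates.
        return [t for t in self.tables if t.startswith(text)]

if __name__ == "__main__":
    DemoCmd().cmdloop()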
@@ -1322,14 +1395,14 @@ def do_next(self, arg): else: self._go_next() - def do_n(self, arg): + def do_n(self, arg: str) -> None: """Synonym for next""" self.do_next(arg) - def complete_n(self, text: str, line: str, begidx: int, endidx: int): + def complete_n(self, text: str, line: str, begidx: int, endidx: int) -> list[str]: return self.complete_next(text, line, begidx, endidx) - def _go_next(self): + def _go_next(self) -> None: table = self.get_table() if table is None: self.print("No more tables") @@ -1340,7 +1413,7 @@ def _go_next(self): self.generator_index = next_gi self.set_prompt() - def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int): + def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: parts = text.split(".", 1) first_part = parts[0] if 1 < len(parts): @@ -1348,7 +1421,7 @@ def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int): table_index = self._get_table_index(first_part) if table_index is None: return [] - table_entry: GeneratorCmdTableEntry = self.table_entries[table_index] + table_entry = self.table_entries[table_index] return [ f"{first_part}.{column}" for gen in table_entry.new_generators @@ -1374,7 +1447,7 @@ def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int): column_names = [] return table_names + column_names - def do_previous(self, _arg): + def do_previous(self, _arg: str) -> None: """Go to the previous generator""" if self.generator_index == 0: self.previous_table() @@ -1382,17 +1455,24 @@ def do_previous(self, _arg): self.generator_index -= 1 self.set_prompt() - def do_b(self, arg): + def do_b(self, arg: str) -> None: """Synonym for previous""" self.do_previous(arg) def _generators_valid(self) -> bool: + """ + Return True if the self.generators property is still correct for the + table and columns currently being examined. + """ return self.generators_valid_columns == ( self.table_index, self.get_column_names(), ) def _get_generator_proposals(self) -> list[Generator]: + """ + Get a list of acceptable generators, sorted by decreasing fit to the actual data. + """ if not self._generators_valid(): self.generators = None if self.generators is None: @@ -1406,7 +1486,10 @@ def _get_generator_proposals(self) -> list[Generator]: ) return self.generators - def _print_privacy(self): + def _print_privacy(self) -> None: + """ + Print the privacy status of the current table. + """ table = self.table_metadata() if table is None: return @@ -1419,7 +1502,7 @@ def _print_privacy(self): return self.print(self.SECONDARY_PRIVATE_TEXT, pfks) - def do_compare(self, arg: str): + def do_compare(self, arg: str) -> None: """ Compare the real data with some generators. @@ -1448,11 +1531,11 @@ def do_compare(self, arg: str): self._print_values_queried(table_name, n, gen) self.print_table_by_columns(comparison) - def do_c(self, arg): + def do_c(self, arg: str) -> None: """Synonym for compare.""" self.do_compare(arg) - def _print_values_queried(self, table_name: str, n: int, gen: Generator): + def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None: """ Print the values queried from the database for this generator. 
""" @@ -1478,7 +1561,7 @@ def _print_custom_queries(self, gen: Generator) -> None: cqs = gen.custom_queries() if not cqs: return - cq_key2args = {} + cq_key2args: dict[str, Any] = {} nominal = gen.nominal_kwargs() actual = gen.actual_kwargs() self._get_custom_queries_from( @@ -1493,7 +1576,7 @@ def _print_custom_queries(self, gen: Generator) -> None: cq_key2args[cq_key], ) - def _get_custom_queries_from(self, out, nominal, actual): + def _get_custom_queries_from(self, out: dict[str, Any], nominal: Any, actual: Any) -> None: if type(nominal) is str: src_stat_groups = self.SRC_STAT_RE.search(nominal) if src_stat_groups: @@ -1524,7 +1607,7 @@ def _get_aggregate_query( return None return f"SELECT {', '.join(clauses)} FROM {table_name}" - def _print_select_aggregate_query(self, table_name, gen: Generator) -> None: + def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None: """ Prints the select aggregate query and all the values it gets in this case. """ @@ -1554,7 +1637,7 @@ def _print_select_aggregate_query(self, table_name, gen: Generator) -> None: select_q = self._get_aggregate_query([gen], table_name) self.print("{0}; providing the following values: {1}", select_q, vals) - def _get_column_data(self, count: int, to_str=repr): + def _get_column_data(self, count: int, to_str: Callable[[Any], str]=repr) -> list[list[str]]: columns = self.get_column_names() columns_string = ", ".join(columns) pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) @@ -1566,7 +1649,7 @@ def _get_column_data(self, count: int, to_str=repr): ) return [[to_str(x) for x in xs] for xs in result.all()] - def do_propose(self, _arg): + def do_propose(self, _arg: str) -> None: """ Display a list of possible generators for this column. @@ -1585,8 +1668,8 @@ def do_propose(self, _arg): if not gens: self.print(self.PROPOSE_NOTHING) for index, gen in enumerate(gens): - fit = gen.fit() - if fit is None: + fit = gen.fit(-1) + if fit == -1: fit_s = "(no fit)" elif fit < 100: fit_s = f"(fit: {fit:.3g})" @@ -1600,7 +1683,7 @@ def do_propose(self, _arg): sample="; ".join(map(repr, gen.generate_data(limit))), ) - def do_p(self, arg): + def do_p(self, arg: str) -> None: """Synonym for propose""" self.do_propose(arg) @@ -1610,7 +1693,7 @@ def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: return gen return None - def do_set(self, arg: str): + def do_set(self, arg: str) -> None: """ Set one of the proposals as a generator. Takes a single integer argument. @@ -1619,6 +1702,7 @@ def do_set(self, arg: str): self.print("Please run 'propose' before 'set '") return gens = self._get_generator_proposals() + new_gen: Generator | None if arg.isdigit(): index = int(arg) if index < 1: @@ -1639,7 +1723,10 @@ def do_set(self, arg: str): self.set_generator(new_gen) self._go_next() - def set_generator(self, gen: Generator): + def set_generator(self, gen: Generator | None) -> None: + """ + Set the current column's generator. + """ (table, gen_info) = self.get_table_and_generator() if table is None: self.print("Error: no table") @@ -1649,23 +1736,23 @@ def set_generator(self, gen: Generator): return gen_info.gen = gen - def do_s(self, arg): + def do_s(self, arg: str) -> None: """Synonym for set""" self.do_set(arg) - def do_unset(self, _arg): + def do_unset(self, _arg: str) -> None: """ Remove any generator set for this column. 
""" self.set_generator(None) self._go_next() - def do_merge(self, arg: str): + def do_merge(self, arg: str) -> None: """Add this column(s) to the specified column(s), so one generator covers them all.""" cols = arg.split() if not cols: self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: self.print(self.ERROR_NO_SUCH_TABLE) return @@ -1713,9 +1800,9 @@ def do_merge(self, arg: str): table_entry.new_generators = new_new_generators self.set_prompt() - def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int): + def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: return [] return [ @@ -1726,12 +1813,12 @@ def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int): if column.startswith(last_arg) ] - def do_unmerge(self, arg: str): + def do_unmerge(self, arg: str) -> None: """Remove this column(s) from this generator, make them a separate generator.""" cols = arg.split() if not cols: self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: self.print(self.ERROR_NO_SUCH_TABLE) return @@ -1766,9 +1853,9 @@ def do_unmerge(self, arg: str): ) self.set_prompt() - def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int): + def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: return [] return [ @@ -1782,9 +1869,22 @@ def update_config_generators( src_dsn: str, src_schema: str, metadata: MetaData, - config: Mapping, + config: Mapping[str, Any], spec_path: Path | None, -): +) -> Mapping[str, Any]: + """ + Update configuration with the specification from a CSV file. + The specification is a headerless CSV file with columns: Table name, + Column name (or space-separated list of column names), Generator + name required, Second choice generator name, Third choice generator + name, etcetera. + :param src_dsn: Address of the source database + :param src_schema: Name of the source database schema to read from + :param metadata: SQLAlchemy representation of the source database + :param config: Existing configuration (will not be destructively updated) + :param spec_path: The path of the CSV file containing the specification + :return: Updated configuration. + """ with GeneratorCmd(src_dsn, src_schema, metadata, config) as gc: if spec_path is None: gc.cmdloop() diff --git a/datafaker/utils.py b/datafaker/utils.py index 883e096..a94edf7 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -7,7 +7,7 @@ import sys from pathlib import Path from types import ModuleType -from typing import Any, Final, Mapping, Optional, Union +from typing import Any, Final, Mapping, Optional, TypeVar, Union import sqlalchemy import yaml @@ -38,6 +38,8 @@ Path(__file__).parent / "json_schemas/config_schema.json" ) +T = TypeVar("T") + def read_config_file(path: str) -> dict: """Read a config file, warning if it is invalid. 
From 3e01c574ddaa0ee07780a527f3fa451c9edbd31a Mon Sep 17 00:00:00 2001 From: Tim Band Date: Mon, 6 Oct 2025 21:23:17 +0100 Subject: [PATCH 08/35] More mypy fixes in interactive.py --- datafaker/interactive.py | 121 +++++++++++++++++++++++---------------- 1 file changed, 73 insertions(+), 48 deletions(-) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index ee9b940..c95dd92 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -3,11 +3,12 @@ import functools import re from abc import ABC, abstractmethod -from collections.abc import Collection, Mapping +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Any, Callable, Iterable, cast +from types import TracebackType +from typing import Any, Callable, Iterable, Optional, Type, cast from typing_extensions import Self import sqlalchemy @@ -16,6 +17,7 @@ from datafaker.generators import Generator, PredefinedGenerator, everything_factory from datafaker.utils import ( + T, create_db_engine, fk_refers_to_ignored_table, logger, @@ -30,12 +32,12 @@ import readline if not hasattr(readline, "backend"): - readline.backend = "readline" + setattr(readline, "backend", "readline") except: pass -def or_default(v, d): +def or_default(v: T | None, d: T) -> T: """Returns v if it isn't None, otherwise d.""" return d if v is None else v @@ -75,27 +77,27 @@ class AskSaveCmd(cmd.Cmd): prompt = "(yes/no/cancel) " file = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.result = "" - def do_yes(self, _arg): + def do_yes(self, _arg: str) -> bool: self.result = "yes" return True - def do_no(self, _arg): + def do_no(self, _arg: str) -> bool: self.result = "no" return True - def do_cancel(self, _arg): + def do_cancel(self, _arg: str) -> bool: self.result = "cancel" return True -def fk_column_name(fk: ForeignKey): +def fk_column_name(fk: ForeignKey) -> str: if fk_refers_to_ignored_table(fk): return f"{fk.target_fullname} (ignored)" - return fk.target_fullname + return str(fk.target_fullname) class DbCmd(ABC, cmd.Cmd): @@ -106,16 +108,16 @@ class DbCmd(ABC, cmd.Cmd): ROW_COUNT_MSG = "Total row count: {}" @abstractmethod - def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry: + def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | None: ... 
def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] + self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] ): super().__init__() - self.config: Mapping[str, Any] = config + self.config: MutableMapping[str, Any] = config self.metadata = metadata - self._table_entries: Collection[TableEntry] = [] + self._table_entries: list[TableEntry] = [] tables_config: Mapping = config.get("tables", {}) if type(tables_config) is not dict: tables_config = {} @@ -125,56 +127,79 @@ def __init__( table_config = {} entry = self.make_table_entry(name, table_config) if entry is not None: - self.table_entries.append(entry) + self._table_entries.append(entry) self.table_index = 0 self.engine = create_db_engine(src_dsn, schema_name=src_schema) def __enter__(self) -> Self: return self - def __exit__(self, exc_type, exc_val, exc_tb) -> None: + def __exit__( + self, + _exc_type: Optional[Type[BaseException]], + _exc_val: Optional[BaseException], + _exc_tb: Optional[TracebackType], + ) -> None: self.engine.dispose() - def print(self, text: str, *args, **kwargs) -> None: + def print(self, text: str, *args: Any, **kwargs: Any) -> None: print(text.format(*args, **kwargs)) - def print_table(self, headings: list[str], rows: list[list[str]]) -> None: + def print_table(self, headings: list[str], rows: list[list[Any]]) -> None: output = PrettyTable() output.field_names = headings for row in rows: output.add_row(row) print(output) - def print_table_by_columns(self, columns: dict[str, list[str]]): + def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: output = PrettyTable() row_count = max([len(col) for col in columns.values()]) for field_name, data in columns.items(): output.add_column(field_name, data + [None] * (row_count - len(data))) print(output) - def print_results(self, result): + def print_results(self, result: sqlalchemy.CursorResult) -> None: self.print_table(list(result.keys()), [list(row) for row in result.all()]) - def ask_save(self): + def ask_save(self) -> str: + """ + Ask the user if they want to save. + :return: ``yes``, ``no`` or ``cancel``. + """ ask = AskSaveCmd() ask.cmdloop() return ask.result - def set_table_index(self, index) -> bool: + @abstractmethod + def set_prompt(self) -> None: + ... + + def set_table_index(self, index: int) -> bool: + """ + Move to a different table. + :param index: Index of the table to move to. + :return: True if there is a table with such an index to move to. + """ if 0 <= index and index < len(self._table_entries): self.table_index = index self.set_prompt() return True return False - def next_table(self, report="No more tables"): + def next_table(self, report: str="No more tables") -> bool: + """ + Move to the next table + :return: True if there is another table to move to. + """ if not self.set_table_index(self.table_index + 1): self.print(report) return False return True - def table_name(self): - return self._table_entries[self.table_index].name + def table_name(self) -> str: + """ Get the name of the current table. 
""" + return str(self._table_entries[self.table_index].name) def table_metadata(self) -> Table: return self.metadata.tables[self.table_name()] @@ -182,7 +207,7 @@ def table_metadata(self) -> Table: def get_column_names(self) -> list[str]: return [col.name for col in self.table_metadata().columns] - def report_columns(self): + def report_columns(self) -> None: self.print_table( ["name", "type", "primary", "nullable", "foreign key"], [ @@ -204,7 +229,7 @@ def get_table_config(self, table_name: str) -> dict[str, Any]: t = ts.get(table_name) return t if type(t) is dict else {} - def set_table_config(self, table_name: str, config: dict[str, Any]): + def set_table_config(self, table_name: str, config: dict[str, Any]) -> None: ts = self.config.get("tables", None) if type(ts) is not dict: self.config["tables"] = {table_name: config} @@ -228,7 +253,7 @@ def get_nonnull_columns(self, table_name: str) -> list[str]: if column.nullable ] - def find_entry_index_by_table_name(self, table_name) -> int | None: + def find_entry_index_by_table_name(self, table_name: str) -> int | None: return next( ( i @@ -238,13 +263,13 @@ def find_entry_index_by_table_name(self, table_name) -> int | None: None, ) - def find_entry_by_table_name(self, table_name) -> TableEntry | None: + def find_entry_by_table_name(self, table_name: str) -> TableEntry | None: for e in self._table_entries: if e.name == table_name: return e return None - def do_counts(self, _arg): + def do_counts(self, _arg: str) -> None: "Report the column names with the counts of nulls in them" if len(self._table_entries) <= self.table_index: return @@ -274,7 +299,7 @@ def do_counts(self, _arg): ], ) - def do_select(self, arg): + def do_select(self, arg: str) -> None: "Run a select query over the database and show the first 50 results" MAX_SELECT_ROWS = 50 with self.engine.connect() as connection: @@ -358,7 +383,7 @@ class TableCmd(DbCmd): NOTE_TEXT_NO_CHANGES = "You have made no changes." 
NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" - def make_table_entry(self, name: str, table: Mapping) -> TableEntry: + def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | None: if table.get("ignore", False): return TableCmdTableEntry(name, TableType.IGNORE, TableType.IGNORE) if table.get("vocabulary_table", False): @@ -370,14 +395,14 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry: return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] ) -> None: super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @property def table_entries(self) -> list[TableCmdTableEntry]: - return cast(TableCmdTableEntry, self._table_entries) + return cast(list[TableCmdTableEntry], self._table_entries) def set_prompt(self) -> None: if self.table_index < len(self.table_entries): @@ -393,7 +418,6 @@ def set_type(self, t_type: TableType) -> None: def _copy_entries(self) -> None: for entry in self.table_entries: - entry: TableCmdTableEntry if entry.old_type != entry.new_type: table = self.get_table_config(entry.name) if ( @@ -654,7 +678,7 @@ def print_row_data(self, count: int) -> None: def update_config_tables( - src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping ) -> Mapping[str, Any]: with TableCmd(src_dsn, src_schema, metadata, config) as tc: tc.cmdloop() @@ -719,7 +743,7 @@ def find_missingness_query( return (src_stat.get("query", None), src_stat.get("comment", None)) return None - def make_table_entry(self, name: str, table: Mapping) -> TableEntry | None: + def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntry | None: if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -760,7 +784,7 @@ def make_table_entry(self, name: str, table: Mapping) -> TableEntry | None: ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping + self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping ): """ Initialise a MissingnessCmd. @@ -774,7 +798,7 @@ def __init__( @property def table_entries(self) -> list[MissingnessCmdTableEntry]: - return cast(MissingnessCmdTableEntry, self._table_entries) + return cast(list[MissingnessCmdTableEntry], self._table_entries) def set_prompt(self) -> None: """ @@ -956,7 +980,7 @@ def do_none(self, _arg: str) -> None: def update_missingness( - src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] + src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] ) -> Mapping[str, Any]: with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() @@ -1018,7 +1042,7 @@ class GeneratorCmd(DbCmd): r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' 
) - def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None: + def make_table_entry(self, table_name: str, table: Mapping) -> GeneratorCmdTableEntry | None: if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -1028,7 +1052,8 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None metadata_table = self.metadata.tables[table_name] columns = [str(colname) for colname in metadata_table.columns.keys()] column_set = frozenset(columns) - columns_assigned_so_far = set() + columns_assigned_so_far: set[str] = set() + new_generator_infos: list[GeneratorInfo] = [] old_generator_infos: list[GeneratorInfo] = [] for rg in table.get("row_generators", []): @@ -1054,7 +1079,7 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None ) actual_collist = [c for c in collist if c in columns] if actual_collist: - gen = PredefinedGenerator(table, rg, self.config) + gen = PredefinedGenerator(table_name, rg, self.config) new_generator_infos.append( GeneratorInfo( columns=actual_collist.copy(), @@ -1085,7 +1110,7 @@ def make_table_entry(self, table_name: str, table: Mapping) -> TableEntry | None ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: Mapping[str, Any] + self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] ) -> None: """ Initialise a GeneratorCmd @@ -1096,12 +1121,12 @@ def __init__( """ super().__init__(src_dsn, src_schema, metadata, config) self.generator_index = 0 - self.generators_valid_columns = None + self.generators_valid_columns: Optional[tuple[int, list[str]]] = None self.set_prompt() @property def table_entries(self) -> list[GeneratorCmdTableEntry]: - return cast(GeneratorCmdTableEntry, self._table_entries) + return cast(list[GeneratorCmdTableEntry], self._table_entries) def set_table_index(self, index: int) -> bool: """ @@ -1245,7 +1270,7 @@ def _copy_entries(self) -> None: self.config["src-stats"] = src_stats def _find_old_generator( - self, entry: GeneratorCmdTableEntry, columns: Iterable[list] + self, entry: GeneratorCmdTableEntry, columns: Iterable[str] ) -> Generator | None: """Find any generator that previously assigned to these exact same columns.""" fc = frozenset(columns) @@ -1869,7 +1894,7 @@ def update_config_generators( src_dsn: str, src_schema: str, metadata: MetaData, - config: Mapping[str, Any], + config: MutableMapping[str, Any], spec_path: Path | None, ) -> Mapping[str, Any]: """ From ea02cf3a0dfd6260b1bfe651d9043dd84f4caf1c Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 11:22:38 +0100 Subject: [PATCH 09/35] mypy clean: dump, generators, interactive, providers --- datafaker/dump.py | 5 ++-- datafaker/generators.py | 34 +++++++++++++------------- datafaker/interactive.py | 53 ++++++++++++++++++++++++++++++---------- datafaker/providers.py | 2 +- 4 files changed, 61 insertions(+), 33 deletions(-) diff --git a/datafaker/dump.py b/datafaker/dump.py index 36ca046..5f81923 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,14 +1,15 @@ +from _csv import Writer import csv import io import sqlalchemy from sqlalchemy.schema import MetaData -from datafaker.settings import get_settings from datafaker.utils import create_db_engine, get_sync_engine, logger -def _make_csv_writer(file): +def _make_csv_writer(file: io.TextIOBase) -> Writer: + """Make the standard CSV file writer""" return csv.writer(file, quoting=csv.QUOTE_MINIMAL) diff --git a/datafaker/generators.py 
b/datafaker/generators.py index 57a3c64..ff72858 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -11,7 +11,7 @@ from dataclasses import dataclass from functools import lru_cache from itertools import chain, combinations -from typing import Any, Callable, Iterable, Sequence, TypeVar, Union +from typing import Any, Callable, Iterable, MutableSequence, Sequence, TypeVar, Union from typing_extensions import Self import mimesis @@ -249,7 +249,7 @@ class GeneratorFactory(ABC): """ @abstractmethod - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: """ Returns all the generators that might be appropriate for this column. """ @@ -353,7 +353,7 @@ def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: return [ generator for factory in self.factories @@ -492,7 +492,7 @@ def __init__( self._end = end @classmethod - def make_singleton(_cls, column: Column, engine: Engine, function_name: str) -> list[Generator]: + def make_singleton(_cls, column: Column, engine: Engine, function_name: str) -> Sequence[Generator]: extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" max_year = f"MAX({extract_year})" min_year = f"MIN({extract_year})" @@ -592,7 +592,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.word", ] - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -635,7 +635,7 @@ class MimesisFloatGeneratorFactory(GeneratorFactory): All Mimesis generators that return floating point numbers. """ - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -656,7 +656,7 @@ class MimesisDateGeneratorFactory(GeneratorFactory): All Mimesis generators that return dates. """ - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -671,7 +671,7 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return datetimes. """ - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -688,7 +688,7 @@ class MimesisTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return times. """ - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -703,7 +703,7 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): All Mimesis generators that return integers. 
""" - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -821,13 +821,13 @@ def _get_generators_from_buckets( table_name: str, column_name: str, buckets: Buckets, - ) -> list[Generator]: + ) -> Sequence[Generator]: return [ GaussianGenerator(table_name, column_name, buckets), UniformGenerator(table_name, column_name, buckets), ] - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -923,7 +923,7 @@ def _get_generators_from_buckets( table_name: str, column_name: str, buckets: Buckets, - ) -> list[Generator]: + ) -> Sequence[Generator]: with engine.connect() as connection: result = connection.execute( text( @@ -1165,7 +1165,7 @@ class ChoiceGeneratorFactory(GeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1284,7 +1284,7 @@ class ConstantGeneratorFactory(GeneratorFactory): Just the null generator """ - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1418,7 +1418,7 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1792,7 +1792,7 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine) -> list[Generator]: + def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index c95dd92..d0355a1 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -13,12 +13,13 @@ import sqlalchemy from prettytable import PrettyTable -from sqlalchemy import Column, ForeignKey, MetaData, Table, text +from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table, text from datafaker.generators import Generator, PredefinedGenerator, everything_factory from datafaker.utils import ( T, create_db_engine, + get_sync_engine, fk_refers_to_ignored_table, logger, primary_private_fks, @@ -131,6 +132,10 @@ def __init__( self.table_index = 0 self.engine = create_db_engine(src_dsn, schema_name=src_schema) + @property + def sync_engine(self) -> Engine: + return get_sync_engine(self.engine) + def __enter__(self) -> Self: return self @@ -276,7 +281,7 @@ def do_counts(self, _arg: str) -> None: table_name = self.table_name() nonnull_columns = self.get_nonnull_columns(table_name) colcounts = [", COUNT({0}) AS 
{0}".format(nnc) for nnc in nonnull_columns] - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: result = connection.execute( text( "SELECT COUNT(*) AS row_count{colcounts} FROM {table}".format( @@ -302,7 +307,7 @@ def do_counts(self, _arg: str) -> None: def do_select(self, arg: str) -> None: "Run a select query over the database and show the first 50 results" MAX_SELECT_ROWS = 50 - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: try: result = connection.execute(text("SELECT " + arg)) except sqlalchemy.exc.DatabaseError as exc: @@ -330,7 +335,7 @@ def do_peek(self, arg: str) -> None: if not col_names: col_names = self.get_column_names() nonnulls = [cn + " IS NOT NULL" for cn in col_names] - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: query = "SELECT {cols} FROM {table} {where} {nonnull} ORDER BY RANDOM() LIMIT {max}".format( cols=",".join(col_names), table=table_name, @@ -404,6 +409,12 @@ def __init__( def table_entries(self) -> list[TableCmdTableEntry]: return cast(list[TableCmdTableEntry], self._table_entries) + def find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: + entry = super().find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(TableCmdTableEntry, entry) + def set_prompt(self) -> None: if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] @@ -648,7 +659,7 @@ def print_column_data(self, column: str, count: int, min_length: int) -> None: column=column, len=min_length, ) - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: result = connection.execute( text( "SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format( @@ -662,7 +673,7 @@ def print_column_data(self, column: str, count: int, min_length: int) -> None: self.columnize([str(x[0]) for x in result.all()]) def print_row_data(self, count: int) -> None: - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: result = connection.execute( text( "SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format( @@ -715,7 +726,7 @@ def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> s @dataclass class MissingnessCmdTableEntry(TableEntry): old_type: MissingnessType - new_type: MissingnessType + new_type: MissingnessType | None class MissingnessCmd(DbCmd): @@ -731,7 +742,7 @@ class MissingnessCmd(DbCmd): def find_missingness_query( self, missingness_generator: Mapping - ) -> tuple[str | None, str | None] | None: + ) -> tuple[str, str] | None: """Find query and comment from src-stats for the passed missingness generator.""" kwargs = missingness_generator.get("kwargs", {}) patterns = kwargs.get("patterns", "") @@ -740,7 +751,10 @@ def find_missingness_query( key = pattern_match.group(1) for src_stat in self.config["src-stats"]: if src_stat.get("name") == key: - return (src_stat.get("query", None), src_stat.get("comment", None)) + query = src_stat.get("query", None) + if type(query) is not str: + return None + return (query, src_stat.get("comment", "")) return None def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntry | None: @@ -800,6 +814,12 @@ def __init__( def table_entries(self) -> list[MissingnessCmdTableEntry]: return cast(list[MissingnessCmdTableEntry], self._table_entries) + def find_entry_by_table_name(self, table_name: str) -> MissingnessCmdTableEntry | 
None: + entry = super().find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(MissingnessCmdTableEntry, entry) + def set_prompt(self) -> None: """ Sets the prompt according to the current table and missingness. @@ -814,7 +834,7 @@ def set_prompt(self) -> None: else: self.prompt = "(missingness) " - def set_type(self, t_type: TableType) -> None: + def set_type(self, t_type: MissingnessType) -> None: if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type @@ -1128,6 +1148,12 @@ def __init__( def table_entries(self) -> list[GeneratorCmdTableEntry]: return cast(list[GeneratorCmdTableEntry], self._table_entries) + def find_entry_by_table_name(self, table_name: str) -> GeneratorCmdTableEntry | None: + entry = super().find_entry_by_table_name(table_name) + if entry is None: + return None + return cast(GeneratorCmdTableEntry, entry) + def set_table_index(self, index: int) -> bool: """ Moves to a new table. @@ -1239,7 +1265,7 @@ def _copy_entries(self) -> None: else [], } ) - rg = { + rg: dict[str, Any] = { "name": generator.gen.function_name(), "columns_assigned": generator.columns, } @@ -1431,6 +1457,7 @@ def _go_next(self) -> None: table = self.get_table() if table is None: self.print("No more tables") + return next_gi = self.generator_index + 1 if next_gi == len(table.new_generators): self.next_table(self.INFO_NO_MORE_TABLES) @@ -1502,7 +1529,7 @@ def _get_generator_proposals(self) -> list[Generator]: self.generators = None if self.generators is None: columns = self.column_metadata() - gens = everything_factory().get_generators(columns, self.engine) + gens = everything_factory().get_generators(columns, self.sync_engine) gens.sort(key=lambda g: g.fit(9999)) self.generators = gens self.generators_valid_columns = ( @@ -1666,7 +1693,7 @@ def _get_column_data(self, count: int, to_str: Callable[[Any], str]=repr) -> lis columns = self.get_column_names() columns_string = ", ".join(columns) pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: result = connection.execute( text( f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}" diff --git a/datafaker/providers.py b/datafaker/providers.py index 1ebd5bf..9639f1b 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -29,7 +29,7 @@ def column_value( return getattr(random_row, column_name) return None - def __init__(self, *, seed=None, **kwargs): + def __init__(self, *, seed: int | None=None, **kwargs: Any) -> None: super().__init__(seed=seed, **kwargs) self.accumulators: dict[str, int] = {} From 63f781e5c0f37c1c41b2852b7121041fcf9bf8da Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 12:42:22 +0100 Subject: [PATCH 10/35] Mypy fixed dump, interactive, main, serialize_metadata --- datafaker/dump.py | 6 +++-- datafaker/interactive.py | 18 +++++++------- datafaker/main.py | 36 +++++++++++++++------------- datafaker/serialize_metadata.py | 42 ++++++++++++++++++--------------- 4 files changed, 56 insertions(+), 46 deletions(-) diff --git a/datafaker/dump.py b/datafaker/dump.py index 5f81923..4c30911 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,14 +1,16 @@ -from _csv import Writer import csv import io +from typing import TYPE_CHECKING import sqlalchemy from sqlalchemy.schema import MetaData from datafaker.utils import create_db_engine, get_sync_engine, logger +if TYPE_CHECKING: + 
from _csv import Writer -def _make_csv_writer(file: io.TextIOBase) -> Writer: +def _make_csv_writer(file: io.TextIOBase) -> "Writer": """Make the standard CSV file writer""" return csv.writer(file, quoting=csv.QUOTE_MINIMAL) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index d0355a1..111a277 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -113,7 +113,7 @@ def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | Non ... def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] + self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] ): super().__init__() self.config: MutableMapping[str, Any] = config @@ -400,7 +400,7 @@ def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | No return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] + self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] ) -> None: super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @@ -689,7 +689,7 @@ def print_row_data(self, count: int) -> None: def update_config_tables( - src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping + src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping ) -> Mapping[str, Any]: with TableCmd(src_dsn, src_schema, metadata, config) as tc: tc.cmdloop() @@ -798,7 +798,7 @@ def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntr ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping + self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping ): """ Initialise a MissingnessCmd. 
@@ -1000,7 +1000,7 @@ def do_none(self, _arg: str) -> None: def update_missingness( - src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] + src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] ) -> Mapping[str, Any]: with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() @@ -1130,7 +1130,7 @@ def make_table_entry(self, table_name: str, table: Mapping) -> GeneratorCmdTable ) def __init__( - self, src_dsn: str, src_schema: str, metadata: MetaData, config: MutableMapping[str, Any] + self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] ) -> None: """ Initialise a GeneratorCmd @@ -1530,8 +1530,8 @@ def _get_generator_proposals(self) -> list[Generator]: if self.generators is None: columns = self.column_metadata() gens = everything_factory().get_generators(columns, self.sync_engine) - gens.sort(key=lambda g: g.fit(9999)) - self.generators = gens + sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) + self.generators = sorted_gens self.generators_valid_columns = ( self.table_index, self.get_column_names().copy(), @@ -1919,7 +1919,7 @@ def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int) -> def update_config_generators( src_dsn: str, - src_schema: str, + src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any], spec_path: Path | None, diff --git a/datafaker/main.py b/datafaker/main.py index c22f097..65b0102 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -1,11 +1,12 @@ """Entrypoint for the datafaker package.""" import asyncio +import io import json import sys from enum import Enum from importlib import metadata from pathlib import Path -from typing import Final, Optional +from typing import Any, Final, Optional import yaml from jsonschema.exceptions import ValidationError @@ -68,7 +69,7 @@ def _require_src_db_dsn(settings: Settings) -> str: return src_dsn -def load_metadata_config(orm_file_name, config: dict | None = None): +def load_metadata_config(orm_file_name: str, config: dict | None = None) -> Any: with open(orm_file_name) as orm_fh: meta_dict = yaml.load(orm_fh, yaml.Loader) tables_dict = meta_dict.get("tables", {}) @@ -80,12 +81,12 @@ def load_metadata_config(orm_file_name, config: dict | None = None): return meta_dict -def load_metadata(orm_file_name, config: dict | None = None): +def load_metadata(orm_file_name: str, config: dict | None = None) -> Any: meta_dict = load_metadata_config(orm_file_name, config) return dict_to_metadata(meta_dict, None) -def load_metadata_for_output(orm_file_name, config: dict | None = None): +def load_metadata_for_output(orm_file_name: str, config: dict | None = None) -> Any: """ Load metadata excluding any foreign keys pointing to ignored tables. 
""" @@ -96,7 +97,7 @@ def load_metadata_for_output(orm_file_name, config: dict | None = None): @app.callback() def main( verbose: bool = Option(False, "--verbose", "-v", help="Print more information.") -): +) -> None: conf_logger(verbose) @@ -202,7 +203,7 @@ def create_tables( def create_generators( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), df_file: str = Option(DF_FILENAME, help="Path to write Python generators to."), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), + config_file: str = Option(CONFIG_FILENAME, help="The configuration file"), stats_file: Optional[str] = Option( None, help=( @@ -348,11 +349,11 @@ def make_tables( @app.command() def configure_tables( - config_file: Optional[str] = Option( + config_file: str = Option( CONFIG_FILENAME, help="Path to write the configuration file to" ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), -): +) -> None: """ Interactively set tables to ignored, vocabulary or primary private. """ @@ -380,11 +381,11 @@ def configure_tables( @app.command() def configure_missing( - config_file: Optional[str] = Option( + config_file: str = Option( CONFIG_FILENAME, help="Path to write the configuration file to" ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), -): +) -> None: """ Interactively set the missingness of the generated data. """ @@ -392,11 +393,13 @@ def configure_missing( settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) config_file_path = Path(config_file) - config = {} + config: dict[str, Any] = {} if config_file_path.exists(): - config = yaml.load( + config_any = yaml.load( config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader ) + if type(config_any) is dict: + config = config_any metadata = load_metadata(orm_file, config) config_updated = update_missingness(src_dsn, settings.src_schema, metadata, config) if config_updated is None: @@ -409,7 +412,7 @@ def configure_missing( @app.command() def configure_generators( - config_file: Optional[str] = Option( + config_file: str = Option( CONFIG_FILENAME, help="Path of the configuration file to alter" ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), @@ -417,7 +420,7 @@ def configure_generators( None, help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively", ), -): +) -> None: """ Interactively set generators for column data. 
""" @@ -450,7 +453,7 @@ def dump_data( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), table: str = Argument(help="The table to dump"), output: str | None = Option(None, help="output CSV file name"), -): +) -> None: """Dump a whole table as a CSV file (or to the console) from the destination database.""" settings = get_settings() dst_dsn: str = settings.dst_dsn or "" @@ -459,7 +462,8 @@ def dump_data( config = read_config_file(config_file) if config_file is not None else {} metadata = load_metadata_for_output(orm_file, config) if output == None: - dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) + if isinstance(sys.stdout, io.TextIOBase): + dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) return with open(output, "wt", newline="") as out: dump_db_tables(metadata, dst_dsn, schema_name, table, out) diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 303c2c7..936eb9f 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,17 +1,21 @@ -from typing import Callable +from typing import Callable, Protocol import parsy from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table from sqlalchemy.dialects import oracle, postgresql from sqlalchemy.sql import schema, sqltypes +import typing from datafaker.utils import make_foreign_key_name -table_component_t = dict[str, any] -table_t = dict[str, table_component_t] +table_t = dict[str, typing.Any] -def simple(type_): +# We will change this to parsy.Parser when parsy exports its types properly +ParserType = typing.Any + + +def simple(type_: type) -> ParserType: """ Parses a simple sqltypes type. For example, simple(sqltypes.UUID) takes the string "UUID" and outputs @@ -20,14 +24,14 @@ def simple(type_): return parsy.string(type_.__name__).result(type_) -def integer(): +def integer() -> ParserType: """ Parses an integer, outputting that integer. """ return parsy.regex(r"-?[0-9]+").map(int) -def integer_arguments(): +def integer_arguments() -> ParserType: """ Parses a list of integers. The integers are surrounded by brackets and separated by @@ -38,7 +42,7 @@ def integer_arguments(): ) -def numeric_type(type_): +def numeric_type(type_: type) -> ParserType: """ Parses TYPE_NAME, TYPE_NAME(2) or TYPE_NAME(2,3) passing any arguments to the TYPE_NAME constructor. 
@@ -48,9 +52,9 @@ def numeric_type(type_): ) -def string_type(type_): +def string_type(type_: type) -> ParserType: @parsy.generate(type_.__name__) - def st_parser(): + def st_parser() -> typing.Generator[ParserType, None, typing.Any]: """ Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME COLLATE "fr" or TYPE_NAME(32) COLLATE "fr" @@ -67,9 +71,9 @@ def st_parser(): return st_parser -def time_type(type_, pg_type): +def time_type(type_: type, pg_type: type) -> ParserType: @parsy.generate(type_.__name__) - def pgt_parser(): + def pgt_parser() -> typing.Generator[ParserType, None, typing.Any]: """ Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME WITH TIME ZONE or TYPE_NAME(32) WITH TIME ZONE @@ -125,7 +129,7 @@ def pgt_parser(): @parsy.generate -def type_parser(): +def type_parser() -> ParserType: base = yield SIMPLE_TYPE_PARSER dimensions = yield parsy.string("[]").many().map(len) if dimensions == 0: @@ -133,7 +137,7 @@ def type_parser(): return postgresql.ARRAY(base, dimensions=dimensions) -def column_to_dict(column: Column, dialect: Dialect) -> str: +def column_to_dict(column: Column, dialect: Dialect) -> dict[str, typing.Any]: type_ = column.type if isinstance(type_, postgresql.DOMAIN): # Instead of creating a restricted type, we'll just use the base type. @@ -156,8 +160,8 @@ def column_to_dict(column: Column, dialect: Dialect) -> str: def dict_to_column( - table_name, - col_name, + table_name: str, + col_name: str, rep: dict, ignore_fk: Callable[[str], bool], ) -> Column: @@ -236,7 +240,7 @@ def dict_to_table( def metadata_to_dict( meta: MetaData, schema_name: str | None, engine: Engine -) -> dict[str, table_t]: +) -> dict[str, typing.Any]: """ Converts a SQL Alchemy MetaData object into a Python object ready for conversion to YAML. @@ -251,7 +255,7 @@ def metadata_to_dict( } -def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]): +def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]) -> bool: """ Tell if this foreign key should be ignored because it points to an ignored table. @@ -261,10 +265,10 @@ def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]): return True if fk_bits[0] not in tables_dict: return False - return tables_dict[fk_bits[0]].get("ignore", False) + return bool(tables_dict[fk_bits[0]].get("ignore", False)) -def dict_to_metadata(obj: dict, config_for_output: dict = None) -> MetaData: +def dict_to_metadata(obj: dict, config_for_output: dict | None=None) -> MetaData: """ Converts a dict to a SQL Alchemy MetaData object. 
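Patch 10's dump.py change is the standard idiom for names that exist only for the type checker: guard the import behind TYPE_CHECKING and quote the annotation, so nothing extra is imported at runtime. Condensed into a standalone sketch:

import csv
import io
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # False at runtime, so _csv is never imported when the program runs;
    # type checkers treat the branch as taken.
    from _csv import Writer

def make_writer(file: io.TextIOBase) -> "Writer":
    # The quoted annotation is resolved by the type checker, never at runtime.
    return csv.writer(file, quoting=csv.QUOTE_MINIMAL)

buf = io.StringIO()
make_writer(buf).writerow(["a", "b"])
print(buf.getvalue())  # a,b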
From 3a5527b5007dd06ece3140c26e4dd97d15ec058b Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 18:29:07 +0100 Subject: [PATCH 11/35] mypy clean in datafaker dir --- datafaker/make.py | 127 ++++++++++++++++++++++++++------------------- datafaker/utils.py | 100 +++++++++++++++++++++++++---------- 2 files changed, 149 insertions(+), 78 deletions(-) diff --git a/datafaker/make.py b/datafaker/make.py index 17e0d3b..67f42f6 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -5,7 +5,9 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple +from types import TracebackType +from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple, Type +from typing_extensions import Self import pandas as pd import snsql @@ -13,15 +15,17 @@ from black import FileMode, format_str from jinja2 import Environment, FileSystemLoader, Template from mimesis.providers.base import BaseProvider -from sqlalchemy import Engine, MetaData, UniqueConstraint, text +from sqlalchemy import CursorResult, Engine, MetaData, UniqueConstraint, text from sqlalchemy.dialects import postgresql -from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from sqlalchemy.schema import Column, Table -from sqlalchemy.sql import sqltypes, type_api +from sqlalchemy.sql import Executable, sqltypes, type_api from datafaker import providers from datafaker.settings import get_settings from datafaker.utils import ( + MaybeAsyncEngine, create_db_engine, download_table, get_flag, @@ -90,6 +94,18 @@ def make_column_choices( ] +@dataclass +class _PrimaryConstraint: + """ + Describes a Uniqueness constraint for when multiple + columns in a table comprise the primary key. Not a + real constraint, but enough to write df.py. 
+ """ + + columns: list[Column] + name: str + + @dataclass class TableGeneratorInfo: """Contains the df.py content related to regular tables.""" @@ -100,7 +116,9 @@ class TableGeneratorInfo: column_choices: list[ColumnChoice] rows_per_pass: int row_gens: list[RowGeneratorInfo] = field(default_factory=list) - unique_constraints: list[UniqueConstraint] = field(default_factory=list) + unique_constraints: Sequence[UniqueConstraint | _PrimaryConstraint] = field( + default_factory=list + ) @dataclass @@ -112,7 +130,7 @@ class StoryGeneratorInfo: num_stories_per_pass: int -def _render_value(v) -> str: +def _render_value(v: Any) -> str: if type(v) is list: return "[" + ", ".join(_render_value(x) for x in v) + "]" if type(v) is set: @@ -150,7 +168,7 @@ def _get_row_generator( ) -> tuple[list[RowGeneratorInfo], list[str]]: """Get the row generators information, for the given table.""" row_gen_info: list[RowGeneratorInfo] = [] - config: list[dict[str, Any]] = get_property(table_config, "row_generators", []) + config: list[Mapping[str, Any]] = get_property(table_config, "row_generators", []) columns_covered = [] for gen_conf in config: name: str = gen_conf["name"] @@ -220,14 +238,14 @@ def _get_default_generator(column: Column) -> RowGeneratorInfo: ( variable_names, generator_function, - generator_arguments, + generator_kwargs, ) = _get_provider_for_column(column) return RowGeneratorInfo( primary_key=column.primary_key, variable_names=variable_names, function_call=_get_function_call( - function_name=generator_function, keyword_arguments=generator_arguments + function_name=generator_function, keyword_arguments=generator_kwargs ), ) @@ -238,13 +256,14 @@ def _numeric_generator(column: Column) -> tuple[str, dict[str, str]]: that limit its range to the permitted scale. """ column_type = column.type - if column_type.scale is None: + scale = getattr(column_type, "scale", None) + if scale is None: return ("generic.numeric.float_number", {}) return ( "generic.numeric.float_number", { - "start": 0, - "end": 10**column_type.scale - 1, + "start": "0", + "end": str(10**scale - 1), }, ) @@ -284,7 +303,7 @@ def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: @dataclass class GeneratorInfo: # Name or function to generate random objects of this type (not using summary data) - generator: str | Callable[[Column], str] + generator: str | Callable[[Column], tuple[str, dict[str, str]]] # SQL query that gets the data to supply as arguments to the generator # ({column} and {table} will be interpolated) summary_query: str | None = None @@ -298,13 +317,18 @@ class GeneratorInfo: choice: bool = False -def get_result_mappings(info: GeneratorInfo, results) -> dict[str, Any]: +def get_result_mappings( + info: GeneratorInfo, results: CursorResult +) -> dict[str, Any] | None: """ Gets a mapping from the results of a database query as a Python dictionary converted according to the GeneratorInfo provided. """ - kw = {} - for k, v in results.mappings().first().items(): + kw: dict[str, Any] = {} + mapping = results.mappings().first() + if mapping is None: + return kw + for k, v in mapping.items(): if v is None: return None conv_fn = info.arg_types.get(k, float) @@ -374,7 +398,7 @@ def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: def _get_generator_for_column( column_t: type, -) -> str | Callable[[type_api.TypeEngine], tuple[str, dict[str, str]]]: +) -> str | Callable[[Column], tuple[str, dict[str, str]]] | None: """ Gets a generator from a column type. 
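The get_result_mappings change above encodes a three-way rule that is easy to miss in the diff: a query with no rows produces empty kwargs, any NULL value invalidates the whole mapping, and every other value is converted with the column's arg_types entry, defaulting to float. A rough standalone restatement of that rule on plain dictionaries; the arg_types and sample rows here are invented for illustration:

from typing import Any, Callable


def convert_row(
    arg_types: dict[str, Callable[[Any], Any]], row: dict[str, Any] | None
) -> dict[str, Any] | None:
    kw: dict[str, Any] = {}
    if row is None:  # the query returned no rows: empty kwargs
        return kw
    for k, v in row.items():
        if v is None:  # a NULL anywhere makes the mapping unusable
            return None
        kw[k] = arg_types.get(k, float)(v)  # default conversion is float
    return kw


print(convert_row({"n": int}, {"n": "3", "mean": "2.5"}))  # {'n': 3, 'mean': 2.5}
print(convert_row({"n": int}, {"n": None}))  # None
print(convert_row({"n": int}, None))  # {}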
@@ -386,7 +410,7 @@ def _get_generator_for_column( return None if info is None else info.generator -def _get_generator_and_arguments(column: Column) -> tuple[str, dict[str, str]]: +def _get_generator_and_arguments(column: Column) -> tuple[str | None, dict[str, str]]: """ Gets the generator and its arguments from the column type, returning a tuple of a string representing the generator callable and a dict of @@ -442,18 +466,6 @@ def _constraint_sort_key(constraint: UniqueConstraint) -> str: ) -class _PrimaryConstraint: - """ - Describes a Uniqueness constraint for when multiple - columns in a table comprise the primary key. Not a - real constraint, but enough to write df.py. - """ - - def __init__(self, *columns: Column, name: str): - self.name = name - self.columns = columns - - def _get_generator_for_table( table_config: Mapping[str, Any], table: Table, @@ -468,10 +480,12 @@ def _get_generator_for_table( key=_constraint_sort_key, ) primary_keys = [c for c in table.columns if c.primary_key] + constraints: Sequence[UniqueConstraint | _PrimaryConstraint] = unique_constraints if 1 < len(primary_keys): - unique_constraints.append( - _PrimaryConstraint(*primary_keys, name=f"{table.name}_primary_key") + primary_constraint = _PrimaryConstraint( + columns=primary_keys, name=f"{table.name}_primary_key" ) + constraints = unique_constraints + [primary_constraint] column_choices = make_column_choices(table_config) if column_choices: nonnull_columns = { @@ -487,7 +501,7 @@ def _get_generator_for_table( nonnull_columns=nonnull_columns, column_choices=column_choices, rows_per_pass=get_property(table_config, "num_rows_per_pass", 1), - unique_constraints=unique_constraints, + unique_constraints=constraints, ) row_gen_info_data, columns_covered = _get_row_generator(table_config) @@ -525,7 +539,7 @@ def make_vocabulary_tables( overwrite_files: bool, compress: bool, table_names: set[str] | None = None, -): +) -> None: """ Extracts the data from the source database for each vocabulary table. @@ -660,8 +674,8 @@ def _generate_vocabulary_table( table: Table, engine: Engine, overwrite_files: bool = False, - compress=False, -): + compress: bool = False, +) -> None: """ Pulls data out of the source database to make a vocabulary YAML file """ @@ -712,33 +726,42 @@ def reflect_if(table_name: str, _: Any) -> bool: class DbConnection: - def __init__(self, engine): + def __init__(self, engine: MaybeAsyncEngine) -> None: + """ + Initialise an unopened database connection. + + Could be synchronous or asynchronous. 
+        """
         self._engine = engine
+        self._connection: Connection | AsyncConnection

-    async def __aenter__(self):
+    async def __aenter__(self) -> Self:
         if isinstance(self._engine, AsyncEngine):
             self._connection = await self._engine.connect()
         else:
             self._connection = self._engine.connect()
         return self

-    async def __aexit__(self, _type, _value, _tb):
-        if isinstance(self._engine, AsyncEngine):
+    async def __aexit__(
+        self,
+        _type: Optional[Type[BaseException]],
+        _value: Optional[BaseException],
+        _tb: Optional[TracebackType],
+    ) -> None:
+        if isinstance(self._connection, AsyncConnection):
             await self._connection.close()
         else:
             self._connection.close()

-    async def execute_raw_query(self, query):
-        if isinstance(self._engine, AsyncEngine):
+    async def execute_raw_query(self, query: Executable) -> CursorResult:
+        if isinstance(self._connection, AsyncConnection):
             return await self._connection.execute(query)
-        else:
-            return self._connection.execute(query)
+        return self._connection.execute(query)

-    async def table_row_count(self, table_name: str):
+    async def table_row_count(self, table_name: str) -> int:
         with await self.execute_raw_query(
             text(f"SELECT COUNT(*) FROM {table_name}")
         ) as result:
-            return result.scalar_one()
+            return int(result.scalar_one())

     async def execute_query(self, query_block: Mapping[str, Any]) -> Any:
         """Execute query in query_block."""
@@ -766,19 +789,19 @@ async def execute_query(self, query_block: Mapping[str, Any]) -> Any:
         return final_result


-def fix_type(value):
+def fix_type(value: Any) -> Any:
     if type(value) is decimal.Decimal:
         return float(value)
     return value


-def fix_types(dics):
+def fix_types(dics: list[dict]) -> list[dict]:
     return [{k: fix_type(v) for k, v in dic.items()} for dic in dics]


 async def make_src_stats(
     dsn: str, config: Mapping, metadata: MetaData, schema_name: Optional[str] = None
-) -> dict[str, list[dict]]:
+) -> dict[str, dict[str, Any]]:
     """Run the src-stats queries specified by the configuration.
Query the src database with the queries in the src-stats block of the `config` @@ -801,7 +824,7 @@ async def make_src_stats( async def make_src_stats_connection( config: Mapping, db_conn: DbConnection, metadata: MetaData -): +) -> dict[str, dict[str, Any]]: date_string = datetime.today().strftime("%Y-%m-%d %H:%M:%S") query_blocks = config.get("src-stats", []) results = await asyncio.gather( diff --git a/datafaker/utils.py b/datafaker/utils.py index a94edf7..950f061 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -2,12 +2,23 @@ import ast import gzip import importlib.util +import io import json import logging import sys from pathlib import Path from types import ModuleType -from typing import Any, Final, Mapping, Optional, TypeVar, Union +from typing import ( + Any, + Callable, + Final, + Generator, + Iterable, + Mapping, + Optional, + TypeVar, + Union, +) import sqlalchemy import yaml @@ -39,6 +50,7 @@ ) T = TypeVar("T") +_K = TypeVar("_K") def read_config_file(path: str) -> dict: @@ -76,16 +88,18 @@ def import_file(file_path: str) -> ModuleType: ModuleType """ spec = importlib.util.spec_from_file_location("df", file_path) + if spec is None or spec.loader is None: + raise Exception(f"No loadable module at {file_path}") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module -def open_file(file_name): +def open_file(file_name: str | Path) -> io.BufferedWriter: return Path(file_name).open("wb") -def open_compressed_file(file_name): +def open_compressed_file(file_name: str | Path) -> gzip.GzipFile: return gzip.GzipFile(file_name, "wb") @@ -211,14 +225,14 @@ class StdoutHandler(logging.Handler): We aren't using StreamHandler because that confuses typer.testing.CliRunner """ - def flush(self): + def flush(self) -> None: self.acquire() try: sys.stdout.flush() finally: self.release() - def emit(self, record): + def emit(self, record: Any) -> None: try: msg = self.format(record) sys.stdout.write(msg + "\n") @@ -235,14 +249,14 @@ class StderrHandler(logging.Handler): We aren't using StreamHandler because that confuses typer.testing.CliRunner """ - def flush(self): + def flush(self) -> None: self.acquire() try: sys.stderr.flush() finally: self.release() - def emit(self, record): + def emit(self, record: Any) -> None: try: msg = self.format(record) sys.stderr.write(msg + "\n") @@ -279,17 +293,17 @@ def conf_logger(verbose: bool) -> None: logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.WARNING) -def get_flag(maybe_dict, key): +def get_flag(maybe_dict: Any, key: Any) -> Any: """Returns maybe_dict[key] or False if that doesn't exist""" return type(maybe_dict) is dict and maybe_dict.get(key, False) -def get_property(maybe_dict, key, default): +def get_property(maybe_dict: Mapping[_K, Any], key: _K, default: T) -> T: """Returns maybe_dict[key] or default if that doesn't exist""" return maybe_dict.get(key, default) if type(maybe_dict) is dict else default -def fk_refers_to_ignored_table(fk: ForeignKey): +def fk_refers_to_ignored_table(fk: ForeignKey) -> bool: """ Does this foreign key refer to a table that is configured as ignore in config.yaml """ @@ -300,7 +314,7 @@ def fk_refers_to_ignored_table(fk: ForeignKey): return False -def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint): +def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint) -> bool: """ Does this foreign key constraint refer to a table that is configured as ignore in config.yaml """ @@ -331,7 +345,8 @@ def table_is_private(config: Mapping, table_name: str) 
-> bool:
     if type(ts) is not dict:
         return False
     t = ts.get(table_name, {})
-    return t.get("primary_private", False)
+    ret = t.get("primary_private", False)
+    return ret if type(ret) is bool else False


 def primary_private_fks(config: Mapping, table: Table) -> list[str]:
@@ -364,7 +379,11 @@ def make_foreign_key_name(table_name: str, col_name: str) -> str:
     return f"{table_name}_{col_name}_fkey"


-def remove_vocab_foreign_key_constraints(metadata, config, dst_engine):
+def remove_vocab_foreign_key_constraints(
+    metadata: MetaData,
+    config: Mapping[str, Any],
+    dst_engine: Connection | Engine,
+) -> None:
     vocab_tables = get_vocabulary_table_names(config)
     for vocab_table_name in vocab_tables:
         vocab_table = metadata.tables[vocab_table_name]
@@ -392,7 +411,20 @@ def remove_vocab_foreign_key_constraints(metadata, config, dst_engine):
             raise e


-def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_engine):
+def reinstate_vocab_foreign_key_constraints(
+    metadata: MetaData,
+    meta_dict: Mapping[str, Any],
+    config: Mapping[str, Any],
+    dst_engine: Connection | Engine,
+) -> None:
+    """
+    Put the removed foreign keys back into the destination database.
+    :param metadata: The SQLAlchemy metadata for the destination database.
+    :param meta_dict: The ``orm.yaml`` configuration that ``metadata`` was
+        created from.
+    :param config: The ``config.yaml`` data.
+    :param dst_engine: The connection to the destination database.
+    """
     vocab_tables = get_vocabulary_table_names(config)
     for vocab_table_name in vocab_tables:
         vocab_table = metadata.tables[vocab_table_name]
@@ -419,7 +451,7 @@ def reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_eng
         )


-def stream_yaml(yaml_file_handle):
+def stream_yaml(yaml_file_handle: io.TextIOBase) -> Generator[Any]:
     """
     Stream a yaml list into an iterator.

@@ -441,23 +473,24 @@ def stream_yaml(yaml_file_handle):
         buf += line


-def topological_sort(input_nodes, get_dependencies_fn):
+def topological_sort(
+    input_nodes: Iterable[T], get_dependencies_fn: Callable[[T], set[T]]
+) -> tuple[list[T], list[list[T]]]:
     """
     Topologically sort input_nodes and find any cycles.

-    Returns a pair (sorted, cycles).
+    Returns a pair ``(sorted, cycles)``.

-    'sorted' is a list of all the elements of input_nodes sorted
+    ``sorted`` is a list of all the elements of input_nodes sorted
     so that dependencies returned by get_dependencies_fn come
     after nodes that depend on them. Cycles are arbitrarily broken
     for this.

-    'cycles' is a list of lists of dependency cycles.
+    ``cycles`` is a list of lists of dependency cycles.

-    arguments:
-    input_nodes: an iterator of nodes to sort. Duplicates
+    :param input_nodes: an iterator of nodes to sort. Duplicates
         are discarded.
-    get_dependencies_fn: a function that takes an input
+    :param get_dependencies_fn: a function that takes an input
         node and returns a list of its dependencies. Any
         dependencies not in the input_nodes list are ignored.
     """
@@ -503,6 +536,21 @@ def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Ta
     return [metadata.tables[tn] for tn in sorted]


+def underline_error(e: SyntaxError) -> str:
+    """
+    Make an underline for this error.
+    :return: string beginning ``\n`` then spaces then ``^^^^``
+        underlining the error, or an empty string if this was not possible.
+ """ + start = e.offset + if start is None: + return "" + end = e.end_offset + if end is None or end <= start: + end = start + 1 + return "\n" + " " * start + "^" * (end - start) + + def generators_require_stats(config: Mapping) -> bool: """ Returns true if any of the arguments for any of the generators reference SRC_STATS. @@ -536,12 +584,12 @@ def generators_require_stats(config: Mapping) -> bool: except SyntaxError as e: errors.append( ( - "Syntax error in argument %d of %s: %s\n%s\n%s", + "Syntax error in argument %d of %s: %s\n%s%s", n + 1, where, e.msg, arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), + underline_error(e), ) ) for k, arg in call.get("kwargs", {}).items(): @@ -557,12 +605,12 @@ def generators_require_stats(config: Mapping) -> bool: except SyntaxError as e: errors.append( ( - "Syntax error in argument %s of %s: %s\n%s\n%s", + "Syntax error in argument %s of %s: %s\n%s%s", k, where, e.msg, arg, - " " * e.offset + "^" * max(1, e.end_offset - e.offset), + underline_error(e), ) ) for error in errors: From 3fffbadd926e2061ccc330081c4f90ff7f6ea586 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 18:29:34 +0100 Subject: [PATCH 12/35] pre-commit rewrites --- datafaker/base.py | 14 +++- datafaker/create.py | 4 +- datafaker/dump.py | 1 + datafaker/generators.py | 139 +++++++++++++++++++++----------- datafaker/interactive.py | 92 ++++++++++++++++----- datafaker/providers.py | 2 +- datafaker/serialize_metadata.py | 4 +- 7 files changed, 177 insertions(+), 79 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index acdc9b9..8270ccb 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -31,7 +31,9 @@ def zipf_weights(size: int) -> list[float]: return [1 / (n * total) for n in range(1, size + 1)] -def merge_with_constants(xs: list[T], constants_at: dict[int, T]) -> Generator[T, None, None]: +def merge_with_constants( + xs: list[T], constants_at: dict[int, T] +) -> Generator[T, None, None]: """ Merge a list of items with other items that must be placed at certain indices. :param constants_at: A map of indices to objects that must be placed at @@ -88,7 +90,7 @@ def choice(self, a: list[T]) -> T: c = random.choice(a) return c["value"] if type(c) is dict and "value" in c else c - def zipf_choice(self, a: list[T], n: int | None=None) -> T: + def zipf_choice(self, a: list[T], n: int | None = None) -> T: if n is None: n = len(a) c = random.choices(a, weights=zipf_weights(n))[0] @@ -226,7 +228,9 @@ def _check_generator_name(self, name: str) -> None: raise Exception("%s is not a permitted generator", name) def alternatives( - self, alternative_configs: list[dict[str, Any]], counts: list[dict[str, int]] | None + self, + alternative_configs: list[dict[str, Any]], + counts: list[dict[str, int]] | None, ) -> Any: """ A generator that picks between other generators. 
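merge_with_constants, newly typed above, is documented here only by its docstring. On my reading, the keys of constants_at are indices in the output sequence; a sketch of that behaviour under that assumption (merge_demo is illustrative, and the real function may handle edge cases such as constants past the end of xs differently):

from typing import Any, Generator, Iterator


def merge_demo(xs: list, constants_at: dict[int, Any]) -> Generator[Any, None, None]:
    it: Iterator[Any] = iter(xs)
    i = 0
    while True:
        if i in constants_at:
            yield constants_at[i]  # a fixed value claims this output index
        else:
            try:
                yield next(it)  # otherwise emit the next generated value
            except StopIteration:
                return
        i += 1


print(list(merge_demo(["a", "b", "c"], {1: "X", 3: "Y"})))
# ['a', 'X', 'b', 'Y', 'c']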
@@ -270,7 +274,9 @@ def with_constants_at( logger.debug("Merging constants %s", constants_at) return list(merge_with_constants(subout, constants_at)) - def truncated_string(self, subgen_fn: Callable[..., list[T]], params: dict, length: int) -> list[T]: + def truncated_string( + self, subgen_fn: Callable[..., list[T]], params: dict, length: int + ) -> list[T]: """Calls ``subgen_fn(**params)`` and truncates the results to ``length``.""" result = subgen_fn(**params) if result is None: diff --git a/datafaker/create.py b/datafaker/create.py index e902ec3..11b64a7 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -276,9 +276,7 @@ def populate( try: with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): - stmt = insert(table).values( - table_generator(dst_conn) - ) + stmt = insert(table).values(table_generator(dst_conn)) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 dst_conn.commit() diff --git a/datafaker/dump.py b/datafaker/dump.py index 4c30911..2ac187f 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from _csv import Writer + def _make_csv_writer(file: io.TextIOBase) -> "Writer": """Make the standard CSV file writer""" return csv.writer(file, quoting=csv.QUOTE_MINIMAL) diff --git a/datafaker/generators.py b/datafaker/generators.py index ff72858..aa0b20b 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -12,13 +12,13 @@ from functools import lru_cache from itertools import chain, combinations from typing import Any, Callable, Iterable, MutableSequence, Sequence, TypeVar, Union -from typing_extensions import Self import mimesis import mimesis.locales import sqlalchemy from sqlalchemy import Column, Connection, CursorResult, Engine, RowMapping, text from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time, TypeEngine +from typing_extensions import Self from datafaker.base import DistributionGenerator from datafaker.utils import logger @@ -123,7 +123,7 @@ def generate_data(self, count: int) -> list[Any]: Generate 'count' random data points for this column. """ - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: """ Return a value representing how well the distribution fits the real source data. @@ -249,7 +249,9 @@ class GeneratorFactory(ABC): """ @abstractmethod - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: """ Returns all the generators that might be appropriate for this column. 
""" @@ -353,7 +355,9 @@ def __init__(self, factories: list[GeneratorFactory]): super().__init__() self.factories = factories - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: return [ generator for factory in self.factories @@ -427,7 +431,7 @@ def nominal_kwargs(self) -> dict[str, Any]: def actual_kwargs(self) -> dict[str, Any]: return {} - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: return default if self._fit is None else self._fit @@ -492,7 +496,9 @@ def __init__( self._end = end @classmethod - def make_singleton(_cls, column: Column, engine: Engine, function_name: str) -> Sequence[Generator]: + def make_singleton( + _cls, column: Column, engine: Engine, function_name: str + ) -> Sequence[Generator]: extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" max_year = f"MAX({extract_year})" min_year = f"MIN({extract_year})" @@ -592,7 +598,9 @@ class MimesisStringGeneratorFactory(GeneratorFactory): "text.word", ] - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -635,7 +643,9 @@ class MimesisFloatGeneratorFactory(GeneratorFactory): All Mimesis generators that return floating point numbers. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -656,7 +666,9 @@ class MimesisDateGeneratorFactory(GeneratorFactory): All Mimesis generators that return dates. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -671,7 +683,9 @@ class MimesisDateTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return datetimes. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -688,7 +702,9 @@ class MimesisTimeGeneratorFactory(GeneratorFactory): All Mimesis generators that return times. """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -703,7 +719,9 @@ class MimesisIntegerGeneratorFactory(GeneratorFactory): All Mimesis generators that return integers. 
""" - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -756,7 +774,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -827,7 +845,9 @@ def _get_generators_from_buckets( UniformGenerator(table_name, column_name, buckets), ] - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -906,7 +926,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) @@ -974,11 +994,11 @@ class ChoiceGenerator(Generator): def __init__( self, table_name: str, - column_name : str, + column_name: str, values: list[Any], counts: list[int], - sample_count: int | None=None, - suppress_count: int=0, + sample_count: int | None = None, + suppress_count: int = 0, ) -> None: super().__init__() self.table_name = table_name @@ -1045,7 +1065,7 @@ def custom_queries(self) -> dict[str, dict[str, str]]: }, } - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: return default if self._fit is None else self._fit @@ -1080,6 +1100,7 @@ class UniformChoiceGenerator(ChoiceGenerator): """ A generator producing values, each roughly as frequently as each other. """ + def get_estimated_counts(self, counts: list[int]) -> list[int]: return list(uniform_distribution(sum(counts), len(counts))) @@ -1106,7 +1127,7 @@ def generate_data(self, count: int) -> list[Any]: class ValueGatherer: """ Gathers values from a query of values and counts. - + The query must return columns ``v`` for a value and ``f`` for the count of how many of those values there are. These values will be gathered into a number of properties: @@ -1120,15 +1141,12 @@ class ValueGatherer: :param suppress_count: value with a count of this or fewer will be excluded from the suppressed values. 
""" - def __init__(self, results: CursorResult, suppress_count: int=0) -> None: + + def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: values = [] # All values found counts = [] # The number or each value - cvs: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - values_not_suppressed = ( - [] - ) # All values found more than SUPPRESS_COUNT times + cvs: list[dict[str, Any]] = [] # list of dicts with keys "v" and "count" + values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times counts_not_suppressed = [] # The number for each value not suppressed cvs_not_suppressed: list[ dict[str, Any] @@ -1165,7 +1183,9 @@ class ChoiceGeneratorFactory(GeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1186,9 +1206,15 @@ def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Gene vg = ValueGatherer(results, self.SUPPRESS_COUNT) if vg.counts: generators += [ - ZipfChoiceGenerator(table_name, column_name, vg.values, vg.counts), - UniformChoiceGenerator(table_name, column_name, vg.values, vg.counts), - WeightedChoiceGenerator(table_name, column_name, vg.cvs, vg.counts), + ZipfChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + UniformChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + WeightedChoiceGenerator( + table_name, column_name, vg.cvs, vg.counts + ), ] results = connection.execute( text( @@ -1203,9 +1229,15 @@ def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Gene vg = ValueGatherer(results, self.SUPPRESS_COUNT) if vg.counts: generators += [ - ZipfChoiceGenerator(table_name, column_name, vg.values, vg.counts), - UniformChoiceGenerator(table_name, column_name, vg.values, vg.counts), - WeightedChoiceGenerator(table_name, column_name, vg.cvs, vg.counts), + ZipfChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + UniformChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + WeightedChoiceGenerator( + table_name, column_name, vg.cvs, vg.counts + ), ] generators += [ ZipfChoiceGenerator( @@ -1284,7 +1316,9 @@ class ConstantGeneratorFactory(GeneratorFactory): Just the null generator """ - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) != 1: return [] column = columns[0] @@ -1304,6 +1338,7 @@ class MultivariateNormalGenerator(Generator): """ Generator of multiple values drawn from a multivariate normal distribution. 
""" + def __init__( self, table_name: str, @@ -1350,7 +1385,7 @@ def generate_data(self, count: int) -> list[Any]: for _ in range(count) ] - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: return default @@ -1418,7 +1453,9 @@ def query( f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" ) - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1564,7 +1601,9 @@ def name(self) -> str: def function_name(self) -> str: return "dist_gen.alternatives" - def _nominal_kwargs_with_combinations(self, index: int, partition: RowPartition) -> dict[str, Any]: + def _nominal_kwargs_with_combinations( + self, index: int, partition: RowPartition + ) -> dict[str, Any]: count = f'sum(r["count"] for r in SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' if not partition.included_numeric and not partition.included_choice: return { @@ -1621,7 +1660,9 @@ def custom_queries(self) -> dict[str, Any]: **partitions, } - def _actual_kwargs_with_combinations(self, partition: RowPartition) -> dict[str, Any]: + def _actual_kwargs_with_combinations( + self, partition: RowPartition + ) -> dict[str, Any]: count = sum(row["count"] for row in partition.covariates) if not partition.included_numeric and not partition.included_choice: return { @@ -1668,7 +1709,7 @@ def generate_data(self, count: int) -> list[Any]: kwargs = self.actual_kwargs() return [dist_gen.alternatives(**kwargs) for _ in range(count)] - def fit(self, default: float=-1) -> float: + def fit(self, default: float = -1) -> float: return default @@ -1737,12 +1778,14 @@ def __init__( class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 - EMPTY_RESULT = [RowMapping( - parent=sqlalchemy.engine.result.ResultMetaData(), - processors=None, - key_to_index={"count": 0}, - data=(0,) - )] + EMPTY_RESULT = [ + RowMapping( + parent=sqlalchemy.engine.result.ResultMetaData(), + processors=None, + key_to_index={"count": 0}, + data=(0,), + ) + ] def function_name(self) -> str: return "grouped_multivariate_normal" @@ -1792,7 +1835,9 @@ def get_partition_count_query( return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}' - def get_generators(self, columns: list[Column], engine: Engine) -> Sequence[Generator]: + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: if len(columns) < 2: return [] nullable_columns = self.get_nullable_columns(columns) diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 111a277..3238a38 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -9,18 +9,18 @@ from pathlib import Path from types import TracebackType from typing import Any, Callable, Iterable, Optional, Type, cast -from typing_extensions import Self import sqlalchemy from prettytable import PrettyTable from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table, text +from typing_extensions import Self from datafaker.generators import Generator, PredefinedGenerator, everything_factory from datafaker.utils import ( T, create_db_engine, - get_sync_engine, fk_refers_to_ignored_table, + 
get_sync_engine, logger, primary_private_fks, table_is_private, @@ -113,7 +113,11 @@ def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | Non ... def __init__( - self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], ): super().__init__() self.config: MutableMapping[str, Any] = config @@ -192,7 +196,7 @@ def set_table_index(self, index: int) -> bool: return True return False - def next_table(self, report: str="No more tables") -> bool: + def next_table(self, report: str = "No more tables") -> bool: """ Move to the next table :return: True if there is another table to move to. @@ -203,7 +207,7 @@ def next_table(self, report: str="No more tables") -> bool: return True def table_name(self) -> str: - """ Get the name of the current table. """ + """Get the name of the current table.""" return str(self._table_entries[self.table_index].name) def table_metadata(self) -> Table: @@ -351,7 +355,9 @@ def do_peek(self, arg: str) -> None: rows = [row._tuple() for row in result.fetchmany(MAX_PEEK_ROWS)] self.print_table(list(result.keys()), rows) - def complete_peek(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_peek( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: if len(self._table_entries) <= self.table_index: return [] return [ @@ -400,7 +406,11 @@ def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | No return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE) def __init__( - self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], ) -> None: super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @@ -561,7 +571,9 @@ def do_next(self, arg: str) -> None: return self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] @@ -645,7 +657,9 @@ def do_data(self, arg: str) -> None: number = 48 self.print_column_data(column, number, min_length) - def complete_data(self, text: str, line: str, begidx: int, _endidx: int) -> list[str]: + def complete_data( + self, text: str, line: str, begidx: int, _endidx: int + ) -> list[str]: previous_parts = line[: begidx - 1].split() if len(previous_parts) != 2: return [] @@ -757,7 +771,9 @@ def find_missingness_query( return (query, src_stat.get("comment", "")) return None - def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntry | None: + def make_table_entry( + self, name: str, table: Mapping + ) -> MissingnessCmdTableEntry | None: if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -798,7 +814,11 @@ def make_table_entry(self, name: str, table: Mapping) -> MissingnessCmdTableEntr ) def __init__( - self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping, ): """ Initialise a MissingnessCmd. 
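All of the complete_* reshuffles above implement the one pattern the standard library cmd module expects: given the word under the cursor, return the candidates that extend it. A minimal runnable illustration of that pattern; the DemoCmd class and its table names are invented:

import cmd


class DemoCmd(cmd.Cmd):
    prompt = "(demo) "
    tables = ["person", "player", "string"]

    def do_next(self, arg: str) -> None:
        """Move to the named table, or the first one if none is given."""
        print(f"next -> {arg or self.tables[0]}")

    def complete_next(
        self, text: str, _line: str, _begidx: int, _endidx: int
    ) -> list[str]:
        # readline passes the word being completed; return its completions
        return [name for name in self.tables if name.startswith(text)]

    def do_EOF(self, _arg: str) -> bool:
        return True  # Ctrl-D leaves the loop


if __name__ == "__main__":
    DemoCmd().cmdloop()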
@@ -814,7 +834,9 @@ def __init__( def table_entries(self) -> list[MissingnessCmdTableEntry]: return cast(list[MissingnessCmdTableEntry], self._table_entries) - def find_entry_by_table_name(self, table_name: str) -> MissingnessCmdTableEntry | None: + def find_entry_by_table_name( + self, table_name: str + ) -> MissingnessCmdTableEntry | None: entry = super().find_entry_by_table_name(table_name) if entry is None: return None @@ -927,7 +949,9 @@ def do_next(self, arg: str) -> None: return self.next_table(self.INFO_NO_MORE_TABLES) - def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] @@ -1000,7 +1024,10 @@ def do_none(self, _arg: str) -> None: def update_missingness( - src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], ) -> Mapping[str, Any]: with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: mc.cmdloop() @@ -1012,6 +1039,7 @@ class GeneratorInfo: """ A generator and the columns it assigns to. """ + columns: list[str] gen: Generator | None @@ -1023,6 +1051,7 @@ class GeneratorCmdTableEntry(TableEntry): Includes the original setting and the currently configured generators. """ + old_generators: list[GeneratorInfo] new_generators: list[GeneratorInfo] @@ -1031,6 +1060,7 @@ class GeneratorCmd(DbCmd): """ Interactive command shell for setting generators. """ + intro = "Interactive generator configuration. Type ? for help.\n" doc_leader = """Use command 'propose' for a list of generators applicable to the current column, then command 'compare' to see how these perform @@ -1062,7 +1092,9 @@ class GeneratorCmd(DbCmd): r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' 
) - def make_table_entry(self, table_name: str, table: Mapping) -> GeneratorCmdTableEntry | None: + def make_table_entry( + self, table_name: str, table: Mapping + ) -> GeneratorCmdTableEntry | None: if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -1130,7 +1162,11 @@ def make_table_entry(self, table_name: str, table: Mapping) -> GeneratorCmdTable ) def __init__( - self, src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping[str, Any] + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], ) -> None: """ Initialise a GeneratorCmd @@ -1148,7 +1184,9 @@ def __init__( def table_entries(self) -> list[GeneratorCmdTableEntry]: return cast(list[GeneratorCmdTableEntry], self._table_entries) - def find_entry_by_table_name(self, table_name: str) -> GeneratorCmdTableEntry | None: + def find_entry_by_table_name( + self, table_name: str + ) -> GeneratorCmdTableEntry | None: entry = super().find_entry_by_table_name(table_name) if entry is None: return None @@ -1465,7 +1503,9 @@ def _go_next(self) -> None: self.generator_index = next_gi self.set_prompt() - def complete_next(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: parts = text.split(".", 1) first_part = parts[0] if 1 < len(parts): @@ -1628,7 +1668,9 @@ def _print_custom_queries(self, gen: Generator) -> None: cq_key2args[cq_key], ) - def _get_custom_queries_from(self, out: dict[str, Any], nominal: Any, actual: Any) -> None: + def _get_custom_queries_from( + self, out: dict[str, Any], nominal: Any, actual: Any + ) -> None: if type(nominal) is str: src_stat_groups = self.SRC_STAT_RE.search(nominal) if src_stat_groups: @@ -1689,7 +1731,9 @@ def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None select_q = self._get_aggregate_query([gen], table_name) self.print("{0}; providing the following values: {1}", select_q, vals) - def _get_column_data(self, count: int, to_str: Callable[[Any], str]=repr) -> list[list[str]]: + def _get_column_data( + self, count: int, to_str: Callable[[Any], str] = repr + ) -> list[list[str]]: columns = self.get_column_names() columns_string = ", ".join(columns) pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) @@ -1852,7 +1896,9 @@ def do_merge(self, arg: str) -> None: table_entry.new_generators = new_new_generators self.set_prompt() - def complete_merge(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_merge( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: last_arg = text.split()[-1] table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: @@ -1905,7 +1951,9 @@ def do_unmerge(self, arg: str) -> None: ) self.set_prompt() - def complete_unmerge(self, text: str, _line: str, _begidx: int, _endidx: int) -> list[str]: + def complete_unmerge( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: last_arg = text.split()[-1] table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: diff --git a/datafaker/providers.py b/datafaker/providers.py index 9639f1b..03e6cbe 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -29,7 +29,7 @@ def column_value( return getattr(random_row, column_name) return None - def __init__(self, *, seed: int | None=None, **kwargs: Any) -> None: + def __init__(self, *, seed: int | 
None = None, **kwargs: Any) -> None: super().__init__(seed=seed, **kwargs) self.accumulators: dict[str, int] = {} diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index 936eb9f..d407d49 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,10 +1,10 @@ +import typing from typing import Callable, Protocol import parsy from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table from sqlalchemy.dialects import oracle, postgresql from sqlalchemy.sql import schema, sqltypes -import typing from datafaker.utils import make_foreign_key_name @@ -268,7 +268,7 @@ def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]) -> bool: return bool(tables_dict[fk_bits[0]].get("ignore", False)) -def dict_to_metadata(obj: dict, config_for_output: dict | None=None) -> MetaData: +def dict_to_metadata(obj: dict, config_for_output: dict | None = None) -> MetaData: """ Converts a dict to a SQL Alchemy MetaData object. From 55acf142c01d7bf2c0002aba6e84065ac3514c54 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 18:43:30 +0100 Subject: [PATCH 13/35] test_dump is mypy clean --- tests/test_dump.py | 5 +++-- tests/utils.py | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_dump.py b/tests/test_dump.py index 2d5ed26..7033e18 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -1,4 +1,5 @@ """Tests for the base module.""" +import io from unittest.mock import MagicMock, call, patch from sqlalchemy.schema import MetaData @@ -17,9 +18,9 @@ class DumpTests(RequiresDBTestCase): @patch("datafaker.dump._make_csv_writer") def test_dump_data(self, make_csv_writer: MagicMock) -> None: """Test dump-data.""" - TEST_OUTPUT_FILE = "test_output_file_object" + TEST_OUTPUT_FILE = io.StringIO() metadata = MetaData() - metadata.reflect(self.engine) + metadata.reflect(self.sync_engine) dump_db_tables(metadata, self.dsn, self.schema_name, "player", TEST_OUTPUT_FILE) make_csv_writer.assert_called_once_with(TEST_OUTPUT_FILE) make_csv_writer.assert_has_calls( diff --git a/tests/utils.py b/tests/utils.py index a6eb593..d74851e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,7 +19,7 @@ from datafaker.create import create_db_data_into from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from -from datafaker.utils import create_db_engine, import_file, sorted_non_vocabulary_tables +from datafaker.utils import create_db_engine, get_sync_engine, import_file, sorted_non_vocabulary_tables class SysExit(Exception): @@ -115,11 +115,11 @@ class RequiresDBTestCase(DatafakerTestCase): reflected from that engine. 
""" - schema_name = None + schema_name: str | None = None use_asyncio = False examples_dir = "tests/examples" - dump_file_path = None - database_name = None + dump_file_path: str | None = None + database_name: str | None = None Postgresql = None @classmethod @@ -140,6 +140,7 @@ def setUp(self) -> None: schema_name=self.schema_name, use_asyncio=self.use_asyncio, ) + self.sync_engine = get_sync_engine(self.engine) self.metadata = MetaData() self.metadata.reflect(self.engine) From c3709b8f6cd09700ac42fe0d1eead918a086d7bb Mon Sep 17 00:00:00 2001 From: Tim Band Date: Tue, 7 Oct 2025 19:10:12 +0100 Subject: [PATCH 14/35] Some mypy cleaning of tests directory --- datafaker/utils.py | 2 +- tests/test_base.py | 2 +- tests/test_create.py | 10 ++++----- tests/test_functional.py | 5 +++-- tests/test_make.py | 2 +- tests/test_providers.py | 4 ++-- tests/test_unique_generator.py | 9 ++++---- tests/test_utils.py | 14 ++++++------ tests/utils.py | 41 +++++++++++++++++++--------------- 9 files changed, 47 insertions(+), 42 deletions(-) diff --git a/datafaker/utils.py b/datafaker/utils.py index 950f061..b34664c 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -451,7 +451,7 @@ def reinstate_vocab_foreign_key_constraints( ) -def stream_yaml(yaml_file_handle: io.TextIOBase) -> Generator[Any]: +def stream_yaml(yaml_file_handle: io.TextIOBase) -> Generator[Any, None, None]: """ Stream a yaml list into an iterator. diff --git a/tests/test_base.py b/tests/test_base.py index 3f1e8cd..411f1c0 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -46,7 +46,7 @@ def test_load(self) -> None: """Test the load method.""" vocab_gen = FileUploader(BaseTable.__table__) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: vocab_gen.load(conn) statement = select(BaseTable) rows = list(conn.execute(statement)) diff --git a/tests/test_create.py b/tests/test_create.py index 333c01a..ca8c8f8 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -4,7 +4,7 @@ import random from collections import Counter from pathlib import Path -from typing import Any, Generator, Tuple +from typing import Any, Generator, Mapping, Tuple from unittest.mock import MagicMock, call, patch from sqlalchemy import Connection, select @@ -39,11 +39,11 @@ def test_create_vocab(self) -> None: }, } self.set_configuration(config) - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) + meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.sync_engine) self.remove_data(config) remove_db_vocab(self.metadata, meta_dict, config) create_db_vocab(self.metadata, meta_dict, config, Path("./tests/examples")) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables["player"]) rows = list(conn.execute(stmt).mappings().fetchall()) self.assertEqual(len(rows), 3) @@ -60,9 +60,9 @@ def test_create_vocab(self) -> None: def test_make_table_generators(self) -> None: """Test that we can handle column defaults in stories.""" random.seed(56) - config = {} + config: Mapping[str, Any] = {} self.generate_data(config, num_passes=2) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables["string"]) rows = list(conn.execute(stmt).mappings().fetchall()) a = rows[0] diff --git a/tests/test_functional.py b/tests/test_functional.py index 418b4f9..05c3b89 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2,9 +2,10 @@ import os import shutil from pathlib import 
Path +from typing import Any, Mapping from sqlalchemy import create_engine, inspect -from typer.testing import CliRunner +from typer.testing import CliRunner, Result from datafaker.main import app from tests.utils import RequiresDBTestCase @@ -431,7 +432,7 @@ def test_workflow_maximal_args(self) -> None: completed_process.stdout, ) - def invoke(self, *args, expected_error: str = None, env={}): + def invoke(self, *args: Any, expected_error: str | None=None, env: Mapping[str, str]={}) -> Result: res = self.runner.invoke(app, args, env=env) if expected_error is None: self.assertNoException(res) diff --git a/tests/test_make.py b/tests/test_make.py index f43588a..b522778 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -45,7 +45,7 @@ def test_make_table_generators(self) -> None: }, } self.generate_data(config, num_passes=3) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables["player"]) rows = conn.execute(stmt).mappings().fetchall() for row in rows: diff --git a/tests/test_providers.py b/tests/test_providers.py index aedb693..b543783 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -49,7 +49,7 @@ def test_column_value_present(self) -> None: """Test the key method.""" # pylint: disable=invalid-name - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = insert(Person).values(sex="M") conn.execute(stmt) @@ -61,7 +61,7 @@ def test_column_value_present(self) -> None: def test_column_value_missing(self) -> None: """Test the generator when there are no values in the source table.""" - with self.engine.connect() as connection: + with self.sync_engine.connect() as connection: provider: providers.ColumnValueProvider = providers.ColumnValueProvider() generated_value: Any = provider.column_value(connection, Person, "sex") diff --git a/tests/test_unique_generator.py b/tests/test_unique_generator.py index 81e9eea..c64e281 100644 --- a/tests/test_unique_generator.py +++ b/tests/test_unique_generator.py @@ -8,7 +8,6 @@ Integer, Text, UniqueConstraint, - create_engine, insert, ) from sqlalchemy.ext.declarative import declarative_base @@ -55,7 +54,7 @@ def test_unique_generator_empty_table(self) -> None: uniq_ab = UniqueGenerator(["a", "b"], table_name) uniq_c = UniqueGenerator(["c"], table_name, max_tries=10) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: # Find a couple of different values that could be inserted, then try to do # one duplicate. 
test_ab1 = [True, False] @@ -83,7 +82,7 @@ def test_unique_generator_nonempty_table(self) -> None: uniq_ab = UniqueGenerator(["a", "b"], table_name) uniq_c = UniqueGenerator(["c"], table_name, max_tries=10) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: test_ab1 = [True, False] test_ab2 = [False, False] string1 = "String 1" @@ -109,7 +108,7 @@ def test_unique_generator_multivalue_generator(self) -> None: uniq_ab = UniqueGenerator(["a", "b"], table_name) uniq_c = UniqueGenerator(["c"], table_name, max_tries=10) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: test_val1 = (True, False, "String 1") test_val2 = (True, False, "String 2") # Conflicts on (a, b) test_val3 = (False, False, "String 1") # Conflicts on c @@ -143,7 +142,7 @@ def test_unique_generator_max_tries(self) -> None: test_val = (True, False, "String 1") mock_generator.return_value = test_val - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: self.assertEqual(uniq_ab(conn, ["a", "b", "c"], mock_generator), test_val) self.assertRaises( RuntimeError, uniq_ab, conn, ["a", "b", "c"], mock_generator diff --git a/tests/test_utils.py b/tests/test_utils.py index 2640a9e..b34e922 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -81,12 +81,12 @@ def test_download_table(self) -> None: """Test the download_table function.""" # pylint: disable=protected-access - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: conn.execute(insert(MyTable).values({"id": 1})) conn.commit() download_table( - MyTable.__table__, self.engine, self.mytable_file_path, compress=False + MyTable.__table__, self.sync_engine, self.mytable_file_path, compress=False ) # The .strip() gets rid of any possible empty lines at the end of the file. 
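The UniqueGenerator tests above all exercise the same retry-until-unique contract: draw from the wrapped generator until an unseen value appears, and raise RuntimeError once max_tries draws have failed. A toy model of that contract; UniqueRetry is illustrative only, not the real datafaker UniqueGenerator:

from typing import Any, Callable


class UniqueRetry:
    def __init__(self, max_tries: int = 10) -> None:
        self.seen: set[Any] = set()
        self.max_tries = max_tries

    def __call__(self, generate: Callable[[], Any]) -> Any:
        for _ in range(self.max_tries):
            value = generate()
            if value not in self.seen:  # the first unseen value wins
                self.seen.add(value)
                return value
        raise RuntimeError("could not generate a unique value")


values = iter([7, 7, 7, 8])
uniq = UniqueRetry()
assert uniq(lambda: next(values)) == 7  # accepted on the first try
assert uniq(lambda: next(values)) == 8  # retries past the duplicate 7s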
@@ -287,7 +287,7 @@ def test_generators_require_stats(self) -> None: ) @patch("datafaker.utils.logger") - def test_testing_generators_finds_syntax_errors(self, logger: MagicMock): + def test_testing_generators_finds_syntax_errors(self, logger: MagicMock) -> None: generators_require_stats( { "story_generators": [ @@ -309,20 +309,20 @@ def test_testing_generators_finds_syntax_errors(self, logger: MagicMock): logger.error.assert_has_calls( [ call( - "Syntax error in argument %s of %s: %s\n%s\n%s", + "Syntax error in argument %s of %s: %s\n%s%s", "b", "story_generators[0]", "unterminated string literal (detected at line 1)", "'unclosed", - " ^", + "\n ^", ), call( - "Syntax error in argument %d of %s: %s\n%s\n%s", + "Syntax error in argument %d of %s: %s\n%s%s", 1, "tables.things.row_generators[0]", "invalid syntax", "1 2", - " ^", + "\n ^", ), ] ) diff --git a/tests/utils.py b/tests/utils.py index d74851e..a75b4f0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,7 @@ from pathlib import Path from subprocess import run from tempfile import mkstemp -from typing import Any +from typing import Any, Mapping from unittest import TestCase, skipUnless import testing.postgresql @@ -19,7 +19,7 @@ from datafaker.create import create_db_data_into from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from -from datafaker.utils import create_db_engine, get_sync_engine, import_file, sorted_non_vocabulary_tables +from datafaker.utils import create_db_engine, get_sync_engine, import_file, sorted_non_vocabulary_tables, T class SysExit(Exception): @@ -47,7 +47,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.maxDiff = None # pylint: disable=invalid-name super().__init__(*args, **kwargs) - def setUp(self): + def setUp(self) -> None: settings.get_settings.cache_clear() def assertReturnCode( # pylint: disable=invalid-name @@ -74,7 +74,7 @@ def assertNoException(self, result: Any) -> None: # pylint: disable=invalid-nam return self.fail("".join(traceback.format_exception(result.exception))) - def assertSubset(self, set1, set2, msg=None): + def assertSubset(self, set1: set[T], set2: set[T], msg: str | None=None) -> None: """Assert a set is a (non-strict) subset. 
Args: @@ -117,21 +117,23 @@ class RequiresDBTestCase(DatafakerTestCase): schema_name: str | None = None use_asyncio = False - examples_dir = "tests/examples" + examples_dir = Path("tests/examples") dump_file_path: str | None = None database_name: str | None = None Postgresql = None @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.Postgresql = testing.postgresql.PostgresqlFactory(cache_initialized_db=True) @classmethod - def tearDownClass(cls): - cls.Postgresql.clear_cache() + def tearDownClass(cls) -> None: + if cls.Postgresql is not None: + cls.Postgresql.clear_cache() def setUp(self) -> None: super().setUp() + assert self.Postgresql is not None self.postgresql = self.Postgresql() if self.dump_file_path is not None: self.run_psql(Path(self.examples_dir) / Path(self.dump_file_path)) @@ -142,17 +144,20 @@ def setUp(self) -> None: ) self.sync_engine = get_sync_engine(self.engine) self.metadata = MetaData() - self.metadata.reflect(self.engine) + self.metadata.reflect(self.sync_engine) def tearDown(self) -> None: self.postgresql.stop() super().tearDown() @property - def dsn(self): + def dsn(self) -> str: if self.database_name: - return self.postgresql.url(database=self.database_name) - return self.postgresql.url() + url = self.postgresql.url(database=self.database_name) + else: + url = self.postgresql.url() + assert type(url) is str + return url def run_psql(self, dump_file: Path) -> None: """Run psql and pass dump_file_name as the --file option.""" @@ -187,7 +192,7 @@ def setUp(self) -> None: with os.fdopen(self.orm_fd, "w", encoding="utf-8") as orm_fh: orm_fh.write(make_tables_file(self.dsn, self.schema_name, {})) - def set_configuration(self, config) -> None: + def set_configuration(self, config: Mapping[str, Any]) -> None: """ Accepts a configuration file, writes it out. 
""" @@ -195,7 +200,7 @@ def set_configuration(self, config) -> None: with os.fdopen(self.config_fd, "w", encoding="utf-8") as config_fh: config_fh.write(yaml.dump(config)) - def get_src_stats(self, config) -> dict[str, any]: + def get_src_stats(self, config: Mapping[str, Any]) -> dict[str, Any]: """ Runs `make-stats` producing `src-stats.yaml` :return: Python dictionary representation of the contents of the src-stats file @@ -212,7 +217,7 @@ def get_src_stats(self, config) -> dict[str, any]: stats_fh.write(yaml.dump(src_stats)) return src_stats - def create_generators(self, config) -> None: + def create_generators(self, config: Mapping[str, Any]) -> None: """``create-generators`` with ``src-stats.yaml`` and the rest, producing ``df.py``""" datafaker_content = make_table_generators( self.metadata, @@ -225,12 +230,12 @@ def create_generators(self, config) -> None: with os.fdopen(generators_fd, "w", encoding="utf-8") as datafaker_fh: datafaker_fh.write(datafaker_content) - def remove_data(self, config): + def remove_data(self, config: Mapping[str, Any]) -> None: """Remove source data from the DB.""" # `remove-data` so we don't have to use a separate database for the destination remove_db_data_from(self.metadata, config, self.dsn, self.schema_name) - def create_data(self, config, num_passes=1): + def create_data(self, config: Mapping[str, Any], num_passes: int=1) -> None: """Create fake data in the DB.""" # `create-data` with all this stuff datafaker_module = import_file(self.generators_file_path) @@ -245,7 +250,7 @@ def create_data(self, config, num_passes=1): self.schema_name, ) - def generate_data(self, config, num_passes=1): + def generate_data(self, config: Mapping[str, Any], num_passes: int=1) -> Mapping[str, Any]: """ Replaces the DB's source data with generated data. :return: A Python dictionary representation of the src-stats.yaml file, for what it's worth. From 113c4d203c60270e758f67d15a42991dfbfa7acf Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 8 Oct 2025 12:10:27 +0100 Subject: [PATCH 15/35] Much more cleaning. 
mypy clean --- .github/workflows/pre-commit.yml | 2 +- datafaker/interactive.py | 2 +- datafaker/make.py | 2 +- tests/test_create.py | 4 +- tests/test_functional.py | 4 +- tests/test_interactive.py | 253 ++++++++++++++++++------------- tests/test_remove.py | 47 +++--- tests/test_unique_generator.py | 9 +- tests/utils.py | 34 +++-- 9 files changed, 205 insertions(+), 152 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 3e9d213..b07de4d 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -9,7 +9,7 @@ on: env: # This should be the default but we'll be explicit PRE_COMMIT_HOME: ~/.caches/pre-commit - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" jobs: the_job: runs-on: ubuntu-latest diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 3238a38..47f918c 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -1969,7 +1969,7 @@ def update_config_generators( src_dsn: str, src_schema: str | None, metadata: MetaData, - config: MutableMapping[str, Any], + config: Mapping[str, Any], spec_path: Path | None, ) -> Mapping[str, Any]: """ diff --git a/datafaker/make.py b/datafaker/make.py index 67f42f6..ac7cdfc 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -7,7 +7,6 @@ from pathlib import Path from types import TracebackType from typing import Any, Callable, Final, Mapping, Optional, Sequence, Tuple, Type -from typing_extensions import Self import pandas as pd import snsql @@ -21,6 +20,7 @@ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from sqlalchemy.schema import Column, Table from sqlalchemy.sql import Executable, sqltypes, type_api +from typing_extensions import Self from datafaker import providers from datafaker.settings import get_settings diff --git a/tests/test_create.py b/tests/test_create.py index ca8c8f8..b175f07 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -39,7 +39,9 @@ def test_create_vocab(self) -> None: }, } self.set_configuration(config) - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.sync_engine) + meta_dict = metadata_to_dict( + self.metadata, self.schema_name, self.sync_engine + ) self.remove_data(config) remove_db_vocab(self.metadata, meta_dict, config) create_db_vocab(self.metadata, meta_dict, config, Path("./tests/examples")) diff --git a/tests/test_functional.py b/tests/test_functional.py index 05c3b89..394691b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -432,7 +432,9 @@ def test_workflow_maximal_args(self) -> None: completed_process.stdout, ) - def invoke(self, *args: Any, expected_error: str | None=None, env: Mapping[str, str]={}) -> Result: + def invoke( + self, *args: Any, expected_error: str | None = None, env: Mapping[str, str] = {} + ) -> Result: res = self.runner.invoke(app, args, env=env) if expected_error is None: self.assertNoException(res) diff --git a/tests/test_interactive.py b/tests/test_interactive.py index 284e04d..e1de3b3 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -3,6 +3,7 @@ import random import re from dataclasses import dataclass +from typing import Any, Iterable, Mapping, MutableMapping from unittest.mock import MagicMock, Mock, patch from sqlalchemy import insert, select @@ -19,31 +20,39 @@ class TestDbCmdMixin(DbCmd): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize a TestDbCmdMixin""" super().__init__(*args, **kwargs) self.reset() - def reset(self): 
- self.messages: list[tuple[str, list, dict[str, any]]] = [] + def reset(self) -> None: + """Reset all the debug messages collected so far.""" + self.messages: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = [] self.headings: list[str] = [] self.rows: list[list[str]] = [] - self.column_items: list[str] = [] - self.columns: dict[str, list[str]] = {} + self.column_items: list[list[str]] = [] + self.columns: dict[str, list[Any]] = {} - def print(self, text: str, *args, **kwargs): + def print(self, text: str, *args: Any, **kwargs: Any) -> None: + """Capture the printed message.""" self.messages.append((text, args, kwargs)) - def print_table(self, headings: list[str], rows: list[list[str]]): + def print_table(self, headings: list[str], rows: list[list[str]]) -> None: + """Capture the printed table.""" self.headings = headings self.rows = rows - def print_table_by_columns(self, columns: dict[str, list[str]]): + def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: + """Capture the printed table.""" self.columns = columns - def columnize(self, items: list[str]): - self.column_items.append(items) + def columnize(self, items: list[str] | None, displaywidth: int = 80) -> None: + """Capture the printed table.""" + if items is not None: + self.column_items.append(items) def ask_save(self) -> str: + """Quitting always works without needing to ask the user.""" return "yes" @@ -54,7 +63,7 @@ class TestTableCmd(TableCmd, TestDbCmdMixin): class ConfigureTablesTests(RequiresDBTestCase): """Testing configure-tables.""" - def _get_cmd(self, config) -> TestTableCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestTableCmd: return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) @@ -67,7 +76,7 @@ class ConfigureTablesSrcTests(ConfigureTablesTests): def test_table_name_prompts(self) -> None: """Test that the prompts follow the names of the tables.""" - config = {} + config: MutableMapping[str, Any] = {} with self._get_cmd(config) as tc: table_names = list(self.metadata.tables.keys()) for t in table_names: @@ -95,7 +104,7 @@ def test_table_name_prompts(self) -> None: def test_column_display(self) -> None: """Test that we can see the names of the columns.""" - config = {} + config: MutableMapping[str, Any] = {} with self._get_cmd(config) as tc: tc.do_next("unique_constraint_test") tc.do_columns("") @@ -211,7 +220,7 @@ def test_configure_tables(self) -> None: def test_print_data(self) -> None: """Test that we can print random rows from the table and random data from columns.""" person_table = self.metadata.tables["person"] - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: person_rows = conn.execute(select(person_table)).mappings().fetchall() person_data = {row["person_id"]: row for row in person_rows} name_set = {row["name"] for row in person_rows} @@ -255,7 +264,7 @@ def test_print_data(self) -> None: set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set)) ) - def test_list_tables(self): + def test_list_tables(self) -> None: """Test that we can list the tables""" config = { "tables": { @@ -308,7 +317,10 @@ class ConfigureTablesInstrumentsTests(ConfigureTablesTests): database_name = "instrument" schema_name = "public" - def test_sanity_checks_both(self): + def test_sanity_checks_both(self) -> None: + """ + Test ``configure-tables`` sanity checks. 
+ """ config = { "tables": { "model": { @@ -349,7 +361,10 @@ def test_sanity_checks_both(self): ), ) - def test_sanity_checks_warnings_only(self): + def test_sanity_checks_warnings_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. + """ config = { "tables": { "model": { @@ -388,7 +403,10 @@ def test_sanity_checks_warnings_only(self): ), ) - def test_sanity_checks_errors_only(self): + def test_sanity_checks_errors_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. + """ config = { "tables": { "model": { @@ -431,7 +449,7 @@ def test_sanity_checks_errors_only(self): class TestGeneratorCmd(GeneratorCmd, TestDbCmdMixin): """GeneratorCmd but mocked""" - def get_proposals(self) -> dict[str, tuple[int, str, str, list[str]]]: + def get_proposals(self) -> dict[str, tuple[int, str, list[str]]]: """ Returns a dict of generator name to a tuple of (index, fit_string, [list,of,samples]) """ @@ -449,10 +467,11 @@ class ConfigureGeneratorsTests(RequiresDBTestCase): database_name = "instrument" schema_name = "public" - def _get_cmd(self, config) -> TestGeneratorCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """Get the command we are using for this test case.""" return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - def test_null_configuration(self): + def test_null_configuration(self) -> None: """Test that the tables having null configuration does not break.""" config = { "tables": None, @@ -466,7 +485,7 @@ def test_null_configuration(self): gc.do_quit("") self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) - def test_null_table_configuration(self): + def test_null_table_configuration(self) -> None: """Test that a table having null configuration does not break.""" config = { "tables": { @@ -483,7 +502,7 @@ def test_null_table_configuration(self): def test_prompts(self) -> None: """Test that the prompts follow the names of the columns and assigned generators.""" - config = {} + config: MutableMapping[str, Any] = {} with self._get_cmd(config) as gc: for table_name, table_meta in self.metadata.tables.items(): for column_name, column_meta in table_meta.columns.items(): @@ -521,7 +540,7 @@ def test_prompts(self) -> None: ) gc.reset() - def test_set_generator_mimesis(self): + def test_set_generator_mimesis(self) -> None: """Test that we can set one generator to a mimesis generator.""" with self._get_cmd({}) as gc: TABLE = "model" @@ -538,7 +557,7 @@ def test_set_generator_mimesis(self): {"name": f"generic.{GENERATOR}", "columns_assigned": [COLUMN]}, ) - def test_set_generator_distribution(self): + def test_set_generator_distribution(self) -> None: """Test that we can set one generator to gaussian.""" with self._get_cmd({}) as gc: TABLE = "string" @@ -571,7 +590,7 @@ def test_set_generator_distribution(self): f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", ) - def test_set_generator_distribution_directly(self): + def test_set_generator_distribution_directly(self) -> None: """Test that we can set one generator to gaussian without going through propose.""" with self._get_cmd({}) as gc: TABLE = "string" @@ -592,7 +611,7 @@ def test_set_generator_distribution_directly(self): f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", ) - def test_set_generator_choice(self): + def test_set_generator_choice(self) -> None: """Test that we can set one generator to uniform choice.""" with self._get_cmd({}) as gc: TABLE = "string" @@ 
-626,7 +645,7 @@ def test_set_generator_choice(self): f"SELECT {COLUMN} AS value FROM {TABLE} WHERE {COLUMN} IS NOT NULL GROUP BY value ORDER BY COUNT({COLUMN}) DESC", ) - def test_weighted_choice_generator_generates_choices(self): + def test_weighted_choice_generator_generates_choices(self) -> None: """Test that propose and compare show weighted_choice's values.""" with self._get_cmd({}) as gc: TABLE = "string" @@ -643,7 +662,7 @@ def test_weighted_choice_generator_generates_choices(self): self.assertIn(col_heading, gc.columns) self.assertSubset(set(gc.columns[col_heading]), VALUES) - def test_merge_columns(self): + def test_merge_columns(self) -> None: """Test that we can merge columns and set a multivariate generator""" TABLE = "string" COLUMN_1 = "frequency" @@ -683,7 +702,7 @@ def test_merge_columns(self): self.assertEqual(row_gen["name"], GENERATOR) self.assertListEqual(row_gen["columns_assigned"], [COLUMN_1, COLUMN_2]) - def test_unmerge_columns(self): + def test_unmerge_columns(self) -> None: """Test that we can unmerge columns and generators are removed""" TABLE = "string" COLUMN_1 = "frequency" @@ -722,7 +741,7 @@ def test_unmerge_columns(self): self.assertEqual(row_gen["name"], REMAINING_GEN) self.assertListEqual(row_gen["columns_assigned"], [COLUMN_3]) - def test_old_generators_remain(self): + def test_old_generators_remain(self) -> None: """Test that we can set one generator and keep an old one.""" config = { "tables": { @@ -782,7 +801,7 @@ def test_old_generators_remain(self): "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", ) - def test_aggregate_queries_merge(self): + def test_aggregate_queries_merge(self) -> None: """ Test that we can set a generator that requires select aggregate clauses and keep an old one, resulting in a merged query. @@ -817,7 +836,7 @@ def test_aggregate_queries_merge(self): proposals = gc.get_proposals() gc.do_set(str(proposals[f"{GENERATOR}"][0])) gc.do_quit("") - row_gens: list[dict[str, any]] = gc.config["tables"]["string"][ + row_gens: list[dict[str, Any]] = gc.config["tables"]["string"][ "row_generators" ] self.assertEqual(len(row_gens), 2) @@ -850,9 +869,9 @@ def test_aggregate_queries_merge(self): select_match = re.match( r"SELECT (.*) FROM string", gc.config["src-stats"][0]["query"] ) - self.assertIsNotNone( - select_match, "src_stats[0].query is not an aggregate select" - ) + assert ( + select_match is not None + ), "src_stats[0].query is not an aggregate select" self.assertSetEqual( set(select_match.group(1).split(", ")), { @@ -863,7 +882,7 @@ def test_aggregate_queries_merge(self): }, ) - def test_next_completion(self): + def test_next_completion(self) -> None: """Test tab completion for the next command.""" with self._get_cmd({}) as gc: self.assertSetEqual( @@ -887,7 +906,7 @@ def test_next_completion(self): ) self.assertListEqual(gc.complete_next("ww", "next ww", 5, 7), []) - def test_compare_reports_privacy(self): + def test_compare_reports_privacy(self) -> None: """ Test that compare reports whether the current table is primary private, secondary private or not private. @@ -917,11 +936,11 @@ def test_compare_reports_privacy(self): self.assertEqual(text, gc.SECONDARY_PRIVATE_TEXT) self.assertSequenceEqual(args, [["model"]]) - def test_existing_configuration_remains(self): + def test_existing_configuration_remains(self) -> None: """ Test setting a generator does not remove other information. 
""" - config = { + config: MutableMapping[str, Any] = { "tables": { "string": { "primary_private": True, @@ -946,7 +965,7 @@ def test_existing_configuration_remains(self): self.assertEqual(src_stats["kraken"], config["src-stats"][0]["query"]) self.assertTrue(gc.config["tables"]["string"]["primary_private"]) - def test_empty_tables_are_not_configured(self): + def test_empty_tables_are_not_configured(self) -> None: """Test that tables marked as empty are not configured.""" config = { "tables": { @@ -969,10 +988,10 @@ class GeneratorsOutputTests(GeneratesDBTestCase): database_name = "numbers" schema_name = "public" - def _get_cmd(self, config) -> TestGeneratorCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - def test_create_with_sampled_choice(self): + def test_create_with_sampled_choice(self) -> None: """Test that suppression works for choice and zipf_choice.""" table_name = "number_table" with self._get_cmd({}) as gc: @@ -1013,7 +1032,7 @@ def test_create_with_sampled_choice(self): gc.do_set(str(proposals["dist_gen.choice [sampled]"][0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() ones = set() @@ -1028,7 +1047,7 @@ def test_create_with_sampled_choice(self): self.assertSetEqual(twos, {2, 3}) self.assertSetEqual(threes, {1, 2, 3, 4, 5}) - def test_create_with_choice(self): + def test_create_with_choice(self) -> None: """Smoke test normal choice works.""" table_name = "number_table" with self._get_cmd({}) as gc: @@ -1044,7 +1063,7 @@ def test_create_with_choice(self): gc.do_set(str(proposals["dist_gen.zipf_choice"][0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() ones = set() @@ -1056,7 +1075,7 @@ def test_create_with_choice(self): self.assertSetEqual(ones, {1, 2, 3, 4, 5}) self.assertSetEqual(twos, {1, 2, 3, 4, 5}) - def test_create_with_weighted_choice(self): + def test_create_with_weighted_choice(self) -> None: """Smoke test weighted choice.""" table_name = "number_table" with self._get_cmd({}) as gc: @@ -1077,7 +1096,8 @@ def test_create_with_weighted_choice(self): f"{prop[0]}. dist_gen.weighted_choice [sampled and suppressed]" ) self.assertIn(col_heading, set(gc.columns.keys())) - self.assertSubset(set(gc.columns[col_heading]), {1, 4}) + col_set: set[int] = set(gc.columns[col_heading]) + self.assertSubset(col_set, {1, 4}) gc.do_set(str(prop[0])) gc.do_next("number_table.two") gc.reset() @@ -1094,7 +1114,8 @@ def test_create_with_weighted_choice(self): gc.do_compare(str(prop[0])) col_heading = f"{prop[0]}. dist_gen.weighted_choice" self.assertIn(col_heading, set(gc.columns.keys())) - self.assertSubset(set(gc.columns[col_heading]), {1, 2, 3, 4, 5}) + col_set2: set[int] = set(gc.columns[col_heading]) + self.assertSubset(col_set2, {1, 2, 3, 4, 5}) gc.do_set(str(prop[0])) gc.do_next("number_table.three") gc.reset() @@ -1110,11 +1131,12 @@ def test_create_with_weighted_choice(self): gc.do_compare(str(prop[0])) col_heading = f"{prop[0]}. 
dist_gen.weighted_choice [sampled]" self.assertIn(col_heading, set(gc.columns.keys())) - self.assertSubset(set(gc.columns[col_heading]), {1, 2, 3, 4, 5}) + col_set3: set[int] = set(gc.columns[col_heading]) + self.assertSubset(col_set3, {1, 2, 3, 4, 5}) gc.do_set(str(prop[0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() ones = set() @@ -1141,62 +1163,60 @@ class ConfigureMissingnessTests(RequiresDBTestCase): database_name = "instrument" schema_name = "public" - def _get_cmd(self, config) -> TestMissingnessCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: + """We are using configure-missingness.""" return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - def test_set_missingness_to_sampled(self): + def test_set_missingness_to_sampled(self) -> None: """Test that we can set one table to sampled missingness.""" with self._get_cmd({}) as mc: TABLE = "signature_model" mc.do_next(TABLE) mc.do_counts("") self.assertListEqual( - mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (6,), {})] + mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (10,), {})] ) - self.assertListEqual(mc.rows, [["player_id", 3], ["based_on", 2]]) + # Check the counts of NULLs in each column + self.assertListEqual(mc.rows, [["player_id", 4], ["based_on", 3]]) mc.do_sampled("") mc.do_quit("") - self.assertDictEqual( - mc.config, - { - "tables": { - TABLE: { - "missingness_generators": [ - { - "columns": ["player_id", "based_on"], - "kwargs": { - "patterns": 'SRC_STATS["missing_auto__signature_model__0"]' - }, - "name": "column_presence.sampled", - } - ] - } - }, - "src-stats": [ - { - "name": "missing_auto__signature_model__0", - "query": ( - "SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" - " (SELECT player_id IS NULL AS player_id__is_null, based_on IS NULL AS based_on__is_null FROM" - " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null" - ), - } - ], - }, + self.assertListEqual( + mc.config["tables"][TABLE]["missingness_generators"], + [ + { + "columns": ["player_id", "based_on"], + "kwargs": { + "patterns": 'SRC_STATS["missing_auto__signature_model__0"]["results"]' + }, + "name": "column_presence.sampled", + } + ], + ) + self.assertEqual( + mc.config["src-stats"][0]["name"], + "missing_auto__signature_model__0", + ) + self.assertEqual( + mc.config["src-stats"][0]["query"], + ( + "SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" + " (SELECT player_id IS NULL AS player_id__is_null, based_on IS NULL AS based_on__is_null FROM" + " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null" + ), ) -class ConfigureMissingnessTests(GeneratesDBTestCase): +class ConfigureMissingnessTestsWithGeneration(GeneratesDBTestCase): """Testing configure-missing with generation.""" dump_file_path = "instrument.sql" database_name = "instrument" schema_name = "public" - def _get_cmd(self, config) -> TestMissingnessCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - def test_create_with_missingness(self): + def test_create_with_missingness(self) -> None: """Test that we can sample real missingness and reproduce it.""" random.seed(45) # Configure the missingness @@ -1208,7 
+1228,7 @@ def test_create_with_missingness(self): config = mc.config self.generate_data(config, num_passes=100) # Test that each missingness pattern is present in the database - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).mappings().fetchall() patterns: set[int] = set() @@ -1227,10 +1247,11 @@ class GeneratorTests(GeneratesDBTestCase): database_name = "instrument" schema_name = "public" - def _get_cmd(self, config) -> TestGeneratorCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """We are using configure-generators.""" return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - def test_set_null(self): + def test_set_null(self) -> None: """Test that we can sample real missingness and reproduce it.""" with self._get_cmd({}) as gc: gc.do_next("string.position") @@ -1256,8 +1277,13 @@ def test_set_null(self): config = gc.config self.generate_data(config, num_passes=3) # Test that each missingness pattern is present in the database - with self.engine.connect() as conn: - stmt = select(self.metadata.tables["string"].c["position", "frequency"]) + with self.sync_engine.connect() as conn: + # select(self.metadata.tables["string"].c["position", "frequency"]) would be nicer + # but mypy doesn't like it + stmt = select( + self.metadata.tables["string"].c["position"], + self.metadata.tables["string"].c["frequency"], + ) rows = conn.execute(stmt).fetchall() count = 0 for row in rows: @@ -1265,7 +1291,12 @@ def test_set_null(self): self.assertEqual(row.position, 0) self.assertEqual(row.frequency, 0.0) self.assertEqual(count, 3) - stmt = select(self.metadata.tables["signature_model"].c["name", "based_on"]) + # select(self.metadata.tables["signature_model"].c["name", "based_on"]) would be nicer + # but mypy doesn't like it + stmt = select( + self.metadata.tables["signature_model"].c["name"], + self.metadata.tables["signature_model"].c["based_on"], + ) rows = conn.execute(stmt).fetchall() count = 0 for row in rows: @@ -1274,7 +1305,7 @@ def test_set_null(self): self.assertIsNone(row.based_on) self.assertEqual(count, 3) - def test_dist_gen_sampled_produces_ordered_src_stats(self): + def test_dist_gen_sampled_produces_ordered_src_stats(self) -> None: """Tests that choosing a sampled choice generator produces ordered src stats""" with self._get_cmd({}) as gc: gc.do_next("signature_model.player_id") @@ -1294,7 +1325,11 @@ def test_dist_gen_sampled_produces_ordered_src_stats(self): ] self.assertListEqual(based_ons, [1, 3, 2]) - def assertAreTruncatedTo(self, xs, length): + def assertAreTruncatedTo(self, xs: Iterable[str], length: int) -> None: + """ + Check that none of the strings are longer than ``length`` (after + removing surrounding quotes). 
+ """ maxlen = 0 for x in xs: newlen = len(x.strip("'\"")) @@ -1302,7 +1337,7 @@ def assertAreTruncatedTo(self, xs, length): maxlen = max(maxlen, newlen) self.assertEqual(maxlen, length) - def test_varchar_ns_are_truncated(self): + def test_varchar_ns_are_truncated(self) -> None: """Tests that mimesis generators for VARCHAR(N) truncate to N characters""" GENERATOR = "generic.text.quote" TABLE = "signature_model" @@ -1325,7 +1360,7 @@ def test_varchar_ns_are_truncated(self): gc.do_quit("") config = gc.config self.generate_data(config, num_passes=15) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[TABLE].c[COLUMN]) rows = conn.execute(stmt).scalars().fetchall() self.assertAreTruncatedTo(rows, 20) @@ -1359,7 +1394,7 @@ class Correlation(Stat): y2: float = 0 xy: float = 0 - def add(self, x: float, y: float) -> None: + def add2(self, x: float, y: float) -> None: self.n += 1 self.x += x self.x2 += x * x @@ -1386,14 +1421,16 @@ class NullPartitionedTests(GeneratesDBTestCase): schema_name = "public" def setUp(self) -> None: + """Set up the test with specific sample and suppress counts.""" super().setUp() NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 8 NullPartitionedNormalGeneratorFactory.SUPPRESS_COUNT = 2 - def _get_cmd(self, config) -> TestGeneratorCmd: + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """Get the configure-generators object as our command.""" return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - def test_create_with_null_partitioned_grouped_multivariate(self): + def test_create_with_null_partitioned_grouped_multivariate(self) -> None: """Test EAV for all columns.""" table_name = "measurement" generate_count = 800 @@ -1422,7 +1459,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.remove_data(gc.config) # let's add a vocab table without messing around with files table = self.metadata.tables["measurement_type"] - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: conn.execute(insert(table).values({"id": 1, "name": "agreement"})) conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) conn.execute(insert(table).values({"id": 3, "name": "velocity"})) @@ -1430,7 +1467,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self): conn.execute(insert(table).values({"id": 5, "name": "matter"})) conn.commit() self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() one_count = 0 @@ -1454,19 +1491,19 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.assertIsNotNone(row.first_value) self.assertIsNotNone(row.second_value) self.assertIsNone(row.third_value) - two.add(row.first_value, row.second_value) + two.add2(row.first_value, row.second_value) elif row.type == 3: # negative correlation around 11.8, 12.1 self.assertIsNotNone(row.first_value) self.assertIsNotNone(row.second_value) self.assertIsNone(row.third_value) - three.add(row.first_value, row.second_value) + three.add2(row.first_value, row.second_value) elif row.type == 4: # positive correlation around 21.4, 23.4 self.assertIsNotNone(row.first_value) self.assertIsNotNone(row.second_value) self.assertIsNone(row.third_value) - four.add(row.first_value, row.second_value) + four.add2(row.first_value, row.second_value) elif row.type == 5: 
self.assertIn(row.third_value, {"fish", "fowl"}) self.assertIsNotNone(row.first_value) @@ -1517,7 +1554,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self): self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0) self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1) - def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): + def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> None: """Test EAV for all columns with sampled and suppressed generation.""" table_name = "measurement" table2_name = "observation" @@ -1566,7 +1603,7 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): self.remove_data(gc.config) # let's add a vocab table without messing around with files table = self.metadata.tables["measurement_type"] - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: conn.execute(insert(table).values({"id": 1, "name": "agreement"})) conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) conn.execute(insert(table).values({"id": 3, "name": "velocity"})) @@ -1574,7 +1611,7 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self): conn.execute(insert(table).values({"id": 5, "name": "matter"})) conn.commit() self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: stmt = select(self.metadata.tables[table_name]) rows = conn.execute(stmt).fetchall() one_count = 0 @@ -1661,11 +1698,11 @@ class NonInteractiveTests(RequiresDBTestCase): ) def test_non_interactive_configure_generators( self, mock_csv_reader: MagicMock, mock_path: MagicMock - ): + ) -> None: """ test that we can set generators from a CSV file """ - config = {} + config: Mapping[str, Any] = {} spec_csv = Mock(return_value="mock spec.csv file") update_config_generators( self.dsn, self.schema_name, self.metadata, config, spec_csv diff --git a/tests/test_remove.py b/tests/test_remove.py index bfbb787..24286fb 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch from sqlalchemy import func, inspect, select +from sqlalchemy.engine import Connection from datafaker.remove import remove_db_data, remove_db_tables, remove_db_vocab from datafaker.serialize_metadata import metadata_to_dict @@ -16,17 +17,16 @@ class RemoveThingsTestCase(RequiresDBTestCase): database_name = "instrument" schema_name = "public" - def count_rows(self, connection, table_name: str) -> int | None: + def count_rows(self, connection: Connection, table_name: str) -> int | None: return connection.execute( select(func.count()).select_from(self.metadata.tables[table_name]) ).scalar() @patch("datafaker.remove.get_settings") - def test_remove_data(self, mock_get_settings: MagicMock): + def test_remove_data(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, - _env_file=None, ) remove_db_data( self.metadata, @@ -37,9 +37,9 @@ def test_remove_data(self, mock_get_settings: MagicMock): } }, ) - with self.engine.connect() as conn: - self.assertGreater(self.count_rows(conn, "manufacturer"), 0) - self.assertGreater(self.count_rows(conn, "model"), 0) + with self.sync_engine.connect() as conn: + self.assertGreaterAndNotNone(self.count_rows(conn, "manufacturer"), 0) + self.assertGreaterAndNotNone(self.count_rows(conn, "model"), 0) self.assertEqual(self.count_rows(conn, "player"), 0) self.assertEqual(self.count_rows(conn, "string"), 
0) self.assertEqual(self.count_rows(conn, "signature_model"), 0) @@ -50,7 +50,6 @@ def test_remove_data_raises(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, - _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: remove_db_data( @@ -67,13 +66,12 @@ def test_remove_data_raises(self, mock_get_settings: MagicMock) -> None: ) @patch("datafaker.remove.get_settings") - def test_remove_vocab(self, mock_get_settings: MagicMock): + def test_remove_vocab(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, - _env_file=None, ) - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) + meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.sync_engine) config = { "tables": { "manufacturer": {"vocabulary_table": True}, @@ -82,7 +80,7 @@ def test_remove_vocab(self, mock_get_settings: MagicMock): } remove_db_data(self.metadata, config) remove_db_vocab(self.metadata, meta_dict, config) - with self.engine.connect() as conn: + with self.sync_engine.connect() as conn: self.assertEqual(self.count_rows(conn, "manufacturer"), 0) self.assertEqual(self.count_rows(conn, "model"), 0) self.assertEqual(self.count_rows(conn, "player"), 0) @@ -95,10 +93,11 @@ def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, - _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: - meta_dict = metadata_to_dict(self.metadata, self.schema_name, self.engine) + meta_dict = metadata_to_dict( + self.metadata, self.schema_name, self.sync_engine + ) remove_db_vocab( self.metadata, meta_dict, @@ -114,19 +113,24 @@ def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: ) @patch("datafaker.remove.get_settings") - def test_remove_tables(self, mock_get_settings: MagicMock): + def test_remove_tables(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, - _env_file=None, ) - self.assertTrue(inspect(self.engine).has_table("player")) + engine_in = inspect(self.engine) + assert engine_in is not None + assert hasattr(engine_in, "has_table") + self.assertTrue(engine_in.has_table("player")) remove_db_tables(self.metadata) - self.assertFalse(inspect(self.engine).has_table("manufacturer")) - self.assertFalse(inspect(self.engine).has_table("model")) - self.assertFalse(inspect(self.engine).has_table("player")) - self.assertFalse(inspect(self.engine).has_table("string")) - self.assertFalse(inspect(self.engine).has_table("signature_model")) + engine_out = inspect(self.engine) + assert engine_out is not None + assert hasattr(engine_out, "has_table") + self.assertFalse(engine_out.has_table("manufacturer")) + self.assertFalse(engine_out.has_table("model")) + self.assertFalse(engine_out.has_table("player")) + self.assertFalse(engine_out.has_table("string")) + self.assertFalse(engine_out.has_table("signature_model")) @patch("datafaker.remove.get_settings") def test_remove_tables_raises(self, mock_get_settings: MagicMock) -> None: @@ -134,7 +138,6 @@ def test_remove_tables_raises(self, mock_get_settings: MagicMock) -> None: mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=None, - _env_file=None, ) with self.assertRaises(AssertionError) as context_manager: remove_db_tables(self.metadata) diff --git a/tests/test_unique_generator.py 
b/tests/test_unique_generator.py index c64e281..503a36f 100644 --- a/tests/test_unique_generator.py +++ b/tests/test_unique_generator.py @@ -2,14 +2,7 @@ from pathlib import Path from unittest.mock import MagicMock -from sqlalchemy import ( - Boolean, - Column, - Integer, - Text, - UniqueConstraint, - insert, -) +from sqlalchemy import Boolean, Column, Integer, Text, UniqueConstraint, insert from sqlalchemy.ext.declarative import declarative_base from datafaker.unique_generator import UniqueGenerator diff --git a/tests/utils.py b/tests/utils.py index a75b4f0..4e9f236 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,7 +19,13 @@ from datafaker.create import create_db_data_into from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from -from datafaker.utils import create_db_engine, get_sync_engine, import_file, sorted_non_vocabulary_tables, T +from datafaker.utils import ( + T, + create_db_engine, + get_sync_engine, + import_file, + sorted_non_vocabulary_tables, +) class SysExit(Exception): @@ -74,14 +80,22 @@ def assertNoException(self, result: Any) -> None: # pylint: disable=invalid-nam return self.fail("".join(traceback.format_exception(result.exception))) - def assertSubset(self, set1: set[T], set2: set[T], msg: str | None=None) -> None: + def assertGreaterAndNotNone(self, left: float | None, right: float) -> None: + """ + Assert left is not None and greater than right + """ + if left is None: + self.fail("first argument is None") + else: + self.assertGreater(left, right) + + def assertSubset(self, set1: set[T], set2: set[T], msg: str | None = None) -> None: """Assert a set is a (non-strict) subset. - Args: - set1: The asserted subset. - set2: The asserted superset. - msg: Optional message to use on failure instead of a list of - differences. + :param set1: The asserted subset. + :param set2: The asserted superset. + :param msg: Optional message to use on failure instead of a list of + differences. """ try: difference = set1.difference(set2) @@ -235,7 +249,7 @@ def remove_data(self, config: Mapping[str, Any]) -> None: # `remove-data` so we don't have to use a separate database for the destination remove_db_data_from(self.metadata, config, self.dsn, self.schema_name) - def create_data(self, config: Mapping[str, Any], num_passes: int=1) -> None: + def create_data(self, config: Mapping[str, Any], num_passes: int = 1) -> None: """Create fake data in the DB.""" # `create-data` with all this stuff datafaker_module = import_file(self.generators_file_path) @@ -250,7 +264,9 @@ def create_data(self, config: Mapping[str, Any], num_passes: int=1) -> None: self.schema_name, ) - def generate_data(self, config: Mapping[str, Any], num_passes: int=1) -> Mapping[str, Any]: + def generate_data( + self, config: Mapping[str, Any], num_passes: int = 1 + ) -> Mapping[str, Any]: """ Replaces the DB's source data with generated data. :return: A Python dictionary representation of the src-stats.yaml file, for what it's worth. 
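Taken together, the tests/utils.py helpers in the patch above are easiest to see in one place. The following is a minimal sketch (not part of the patch) of how the now-typed generate_data and the new assertGreaterAndNotNone compose in a test; the test class, config and table name are hypothetical, and the row count query mirrors the count_rows helper in tests/test_remove.py:

    from typing import Any, Mapping

    from sqlalchemy import func, select


    class GeneratedRowsSmokeTest(GeneratesDBTestCase):  # hypothetical test case
        dump_file_path = "instrument.sql"
        database_name = "instrument"
        schema_name = "public"

        def test_players_are_generated(self) -> None:
            config: Mapping[str, Any] = {"tables": {}}  # hypothetical minimal config
            # Replace the source data with generated rows, then count them.
            self.generate_data(config, num_passes=5)
            with self.sync_engine.connect() as conn:
                count = conn.execute(
                    select(func.count()).select_from(self.metadata.tables["player"])
                ).scalar()
            # scalar() is Optional, so the helper fails with a clear message
            # instead of a TypeError from assertGreater(None, 0).
            self.assertGreaterAndNotNone(count, 0)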
From 10b02c5f243964345d66181d8e87e5db568fc513 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 8 Oct 2025 18:14:40 +0100 Subject: [PATCH 16/35] precommit cleanup, NullPartitionedGrouped fix --- .github/workflows/tests.yml | 2 +- .pylintrc | 2 +- .readthedocs.yaml | 2 +- datafaker/create.py | 39 +++-- datafaker/dump.py | 3 +- datafaker/generators.py | 27 ++- datafaker/interactive.py | 36 ++-- datafaker/main.py | 25 ++- datafaker/make.py | 8 +- datafaker/providers.py | 1 + datafaker/serialize_metadata.py | 83 ++++++--- mypy.ini | 2 +- tests/test_functional.py | 29 +++- tests/test_interactive.py | 295 +++++++++++++++++--------------- tests/test_utils.py | 1 + 15 files changed, 334 insertions(+), 221 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 137557a..75f45f0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ on: - main env: # This should be the default but we'll be explicit - PYTHON_VERSION: "3.9" + PYTHON_VERSION: "3.10" jobs: the_job: runs-on: ubuntu-latest diff --git a/.pylintrc b/.pylintrc index d97276b..22a92bd 100644 --- a/.pylintrc +++ b/.pylintrc @@ -53,7 +53,7 @@ persistent=yes # Min Python version to use for version dependend checks. Will default to the # version used to run pylint. -py-version=3.9 +py-version=3.10 # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 91942b1..29cdf78 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -9,7 +9,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.9" + python: "3.10" # You can also specify other tool versions: # nodejs: "19" # rust: "1.64" diff --git a/datafaker/create.py b/datafaker/create.py index 11b64a7..f11a0dd 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -1,6 +1,5 @@ """Functions and classes to create and populate the target database.""" import pathlib -import random from collections import Counter from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple @@ -125,6 +124,19 @@ def create_db_data_into( db_dsn: str, schema_name: str | None, ) -> RowCounts: + """ + Populate the database. + + :param sorted_tables: The table names to populate, sorted so that foreign + keys' targets are populated before the foreign keys themselves. + :param table_generator_dict: A mapping of table names to the generators + used to make data for them. + :param story_generator_list: A list of story generators to be run after the + table generators on each pass. + :param num_passes: Number of passes to perform. + :param db_dsn: Connection string for the destination database. + :param schema_name: Destination schema name. + """ dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) row_counts: Counter[str] = Counter() @@ -140,6 +152,8 @@ def create_db_data_into( class StoryIterator: + """Iterates through all the rows produced by all the stories.""" + def __init__( self, stories: Iterable[tuple[str, Story]], @@ -147,6 +161,7 @@ def __init__( table_generator_dict: Mapping[str, TableGenerator], dst_conn: Connection, ): + """Initialise a Story Iterator.""" self._stories: Iterator[tuple[str, Story]] = iter(stories) self._table_dict: Mapping[str, Table] = table_dict self._table_generator_dict: Mapping[str, TableGenerator] = table_generator_dict @@ -162,27 +177,31 @@ def __init__( def is_ended(self) -> bool: """ - Do we have another row to process? 
+ Check whether every story row has been processed. + While this returns False, insert() can be called. """ return self._table_name is None def has_table(self, table_name: str) -> bool: - """ - Do we have a row for table table_name? - """ + """Check if we have a row for table ``table_name``.""" return table_name == self._table_name def table_name(self) -> str | None: """ - The name of the current table (or None if no more stories to process) + Get the name of the current table. + + :return: The table name, or None if there are no more stories + to process. """ return self._table_name def insert(self) -> None: """ - Perform the insert. Call this after __init__ or next, and after checking - that is_ended returns False. + Put the row in the table. + + Call this after __init__ or next, and after checking that is_ended + returns False. """ if self._table_name is None: raise StopIteration("StoryIterator.insert after is_ended") @@ -210,9 +229,7 @@ def insert(self) -> None: cursor.close() def next(self) -> None: - """ - Advance to the next table row. - """ + """Advance to the next row.""" while True: try: if self._final_values is None: diff --git a/datafaker/dump.py b/datafaker/dump.py index 2ac187f..95f251a 100644 --- a/datafaker/dump.py +++ b/datafaker/dump.py @@ -1,3 +1,4 @@ +"""Data dumping functions.""" import csv import io from typing import TYPE_CHECKING @@ -12,7 +13,7 @@ def _make_csv_writer(file: io.TextIOBase) -> "Writer": - """Make the standard CSV file writer""" + """Make the standard CSV file writer.""" return csv.writer(file, quoting=csv.QUOTE_MINIMAL) diff --git a/datafaker/generators.py b/datafaker/generators.py index aa0b20b..ee0add2 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -11,7 +11,7 @@ from dataclasses import dataclass from functools import lru_cache from itertools import chain, combinations -from typing import Any, Callable, Iterable, MutableSequence, Sequence, TypeVar, Union +from typing import Any, Callable, Iterable, Sequence, Union import mimesis import mimesis.locales @@ -21,7 +21,7 @@ from typing_extensions import Self from datafaker.base import DistributionGenerator -from datafaker.utils import logger +from datafaker.utils import logger, T numeric = Union[int, float] @@ -1670,13 +1670,14 @@ def _actual_kwargs_with_combinations( "name": "constant", "params": {"value": [None] * len(partition.excluded_columns)}, } - if not partition.excluded_columns: + covariates = { + "covs": partition.covariates, + } + if not partition.constant_outputs: return { "count": count, "name": self._function_name, - "params": { - "covs": partition.covariates, - }, + "params": covariates, } return { "count": count, @@ -1684,9 +1685,7 @@ def _actual_kwargs_with_combinations( "params": { "constants_at": partition.constant_outputs, "subgen": self._function_name, - "params": { - "covs": partition.covariates, - }, + "params": covariates, }, } @@ -1718,9 +1717,6 @@ def is_numeric(col: Column) -> bool: return (isinstance(ct, Numeric) or isinstance(ct, Integer)) and not col.foreign_keys -T = TypeVar("T") - - def powerset(input: list[T]) -> Iterable[Iterable[T]]: """Returns a list of all sublists of""" return chain.from_iterable(combinations(input, n) for n in range(len(input) + 1)) @@ -1780,7 +1776,7 @@ class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): SUPPRESS_COUNT = 5 EMPTY_RESULT = [ RowMapping( - parent=sqlalchemy.engine.result.ResultMetaData(), + parent=sqlalchemy.engine.result.SimpleResultMetaData(["count"]), processors=None, key_to_index={"count": 0}, data=(0,),
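The _actual_kwargs_with_combinations hunk above switches the wrapping decision from partition.excluded_columns to partition.constant_outputs and hoists the shared covariates dict. Concretely, each NULL-pattern partition now yields one of two alternative shapes; the values below are illustrative assumptions, not captured output:

    covariates = {"covs": partition.covariates}  # shared by both shapes

    # Partition with no constant outputs: the plain multivariate generator.
    plain = {
        "count": 17,                  # rows matching this NULL pattern (assumed)
        "name": self._function_name,  # e.g. a multivariate normal generator
        "params": covariates,
    }

    # Partition whose output positions 2 and 3 are always constant (e.g. NULL):
    # the same generator, wrapped so the constants are spliced back in.
    wrapped = {
        "count": 17,
        "name": "with_constants_at",
        "params": {
            "constants_at": [2, 3],   # partition.constant_outputs (assumed)
            "subgen": self._function_name,
            "params": covariates,
        },
    }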
@@ -1942,10 +1938,11 @@ def _execute_partition_queries( """ found_nonzero = False for rp in partitions.values(): - rp.covariates = connection.execute(text(rp.query)).mappings().fetchall() - if not rp.covariates or rp.covariates[0]["count"] is None: + covs = connection.execute(text(rp.query)).mappings().fetchall() + if not covs or covs[0]["count"] is None: rp.covariates = self.EMPTY_RESULT else: + rp.covariates = covs found_nonzero = True return found_nonzero diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 47f918c..580fa3a 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -1673,14 +1673,16 @@ def _get_custom_queries_from( ) -> None: if type(nominal) is str: src_stat_groups = self.SRC_STAT_RE.search(nominal) + # Do we have a SRC_STAT reference? if src_stat_groups: + # Get its name cq_key = src_stat_groups.group(1) - if cq_key not in out: - out[cq_key] = [] + # Are we pulling a specific part of this result? sub = src_stat_groups.group(3) if sub: actual = {sub: actual} - out[cq_key].append(actual) + else: + out[cq_key] = actual elif type(nominal) is list and type(actual) is list: for i in range(min(len(nominal), len(actual))): self._get_custom_queries_from(out, nominal[i], actual[i]) @@ -1780,10 +1782,11 @@ def do_propose(self, _arg: str) -> None: ) def do_p(self, arg: str) -> None: - """Synonym for propose""" + """Synonym for propose.""" self.do_propose(arg) def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: + """Find a generator by name from the list of proposals.""" for gen in self._get_generator_proposals(): if gen.name() == gen_name: return gen @@ -1792,7 +1795,7 @@ def do_set(self, arg: str) -> None: """ Set one of the proposals as a generator. - Takes a single integer argument. + :param arg: A single integer (as a string). """ if arg.isdigit() and not self._generators_valid(): self.print("Please run 'propose' before 'set '") @@ -1820,9 +1823,7 @@ def do_set(self, arg: str) -> None: self._go_next() def set_generator(self, gen: Generator | None) -> None: - """ - Set the current column's generator. - """ + """Set the current column's generator.""" (table, gen_info) = self.get_table_and_generator() if table is None: self.print("Error: no table") @@ -1833,18 +1834,21 @@ def set_generator(self, gen: Generator | None) -> None: gen_info.gen = gen def do_s(self, arg: str) -> None: - """Synonym for set""" + """Synonym for set.""" self.do_set(arg) def do_unset(self, _arg: str) -> None: - """ - Remove any generator set for this column. - """ + """Remove any generator set for this column.""" self.set_generator(None) self._go_next() def do_merge(self, arg: str) -> None: - """Add this column(s) to the specified column(s), so one generator covers them all.""" + """ + Add the current column(s) to the specified column(s). + + After this, one generator will cover them all. + :param arg: space-separated list of column names to merge.
+ """ cols = arg.split() if not cols: self.print("Error: merge requires a column argument") @@ -1899,6 +1903,7 @@ def do_merge(self, arg: str) -> None: def complete_merge( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Complete column names.""" last_arg = text.split()[-1] table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: @@ -1954,6 +1959,7 @@ def do_unmerge(self, arg: str) -> None: def complete_unmerge( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Complete column names to unmerge.""" last_arg = text.split()[-1] table_entry: GeneratorCmdTableEntry | None = self.get_table() if table_entry is None: @@ -1969,7 +1975,7 @@ def update_config_generators( src_dsn: str, src_schema: str | None, metadata: MetaData, - config: Mapping[str, Any], + config: MutableMapping[str, Any], spec_path: Path | None, ) -> Mapping[str, Any]: """ @@ -1981,7 +1987,7 @@ def update_config_generators( :param src_dsn: Address of the source database :param src_schema: Name of the source database schema to read from :param metadata: SQLAlchemy representation of the source database - :param config: Existing configuration (will not be destructively updated) + :param config: Existing configuration (will be destructively updated) :param spec_path: The path of the CSV file containing the specification :return: Updated configuration. """ diff --git a/datafaker/main.py b/datafaker/main.py index 65b0102..454cf44 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -1,16 +1,17 @@ """Entrypoint for the datafaker package.""" import asyncio +import importlib import io import json import sys from enum import Enum -from importlib import metadata from pathlib import Path from typing import Any, Final, Optional import yaml from jsonschema.exceptions import ValidationError from jsonschema.validators import validate +from sqlalchemy import MetaData from typer import Argument, Exit, Option, Typer from datafaker.create import create_db_data, create_db_tables, create_db_vocab @@ -81,7 +82,13 @@ def load_metadata_config(orm_file_name: str, config: dict | None = None) -> Any: return meta_dict -def load_metadata(orm_file_name: str, config: dict | None = None) -> Any: +def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: + """ + Load metadata from ``orm.yaml`` + :param orm_file_name: ``orm.yaml`` or alternative name to load metadata from. + :param config: Used to exclude tables that are marked as ``ignore: true``. + :return: SQLAlchemy MetaData object representing the database described by the loaded file. 
+ """ meta_dict = load_metadata_config(orm_file_name, config) return dict_to_metadata(meta_dict, None) @@ -548,16 +555,16 @@ def remove_tables( class TableType(str, Enum): - all = "all" - vocab = "vocab" - generated = "generated" + ALL = "all" + VOCAB = "vocab" + GENERATED = "generated" @app.command() def list_tables( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), - tables: TableType = Option(TableType.generated, help="Which tables to list"), + tables: TableType = Option(TableType.GENERATED, help="Which tables to list"), ) -> None: """List the names of tables described in the metadata file.""" config = read_config_file(config_file) if config_file is not None else {} @@ -568,9 +575,9 @@ def list_tables( for (table_name, table_config) in config.get("tables", {}).items() if get_flag(table_config, "vocabulary_table") } - if tables == TableType.all: + if tables == TableType.ALL: names = all_table_names - elif tables == TableType.generated: + elif tables == TableType.GENERATED: names = all_table_names - vocab_table_names else: names = vocab_table_names @@ -584,7 +591,7 @@ def version() -> None: logger.info( "%s version %s", __package__, - metadata.version(__package__), + importlib.metadata.version(__package__), ) diff --git a/datafaker/make.py b/datafaker/make.py index ac7cdfc..e9db863 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -19,7 +19,7 @@ from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from sqlalchemy.schema import Column, Table -from sqlalchemy.sql import Executable, sqltypes, type_api +from sqlalchemy.sql import Executable, sqltypes from typing_extensions import Self from datafaker import providers @@ -825,6 +825,12 @@ async def make_src_stats( async def make_src_stats_connection( config: Mapping, db_conn: DbConnection, metadata: MetaData ) -> dict[str, dict[str, Any]]: + """ + Make the ``src-stats.yaml`` file given the database connection to read from. + :param config: configuration from ``config.yaml``. + :param db_conn: Source database connection. + :param metadata: Source database metadata from ``orm.yaml``. + """ date_string = datetime.today().strftime("%Y-%m-%d %H:%M:%S") query_blocks = config.get("src-stats", []) results = await asyncio.gather( diff --git a/datafaker/providers.py b/datafaker/providers.py index 03e6cbe..65abf06 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -30,6 +30,7 @@ def column_value( return None def __init__(self, *, seed: int | None = None, **kwargs: Any) -> None: + """Initialise the column value provider.""" super().__init__(seed=seed, **kwargs) self.accumulators: dict[str, int] = {} diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py index d407d49..0b96e34 100644 --- a/datafaker/serialize_metadata.py +++ b/datafaker/serialize_metadata.py @@ -1,5 +1,6 @@ +"""Convert between a Python dict describing a database schema and a SQLAlchemy MetaData.""" import typing -from typing import Callable, Protocol +from typing import Callable import parsy from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table @@ -8,7 +9,7 @@ from datafaker.utils import make_foreign_key_name -table_t = dict[str, typing.Any] +TableT = dict[str, typing.Any] # We will change this to parsy.Parser when parsy exports its types properly @@ -17,7 +18,8 @@ def simple(type_: type) -> ParserType: """ - Parses a simple sqltypes type. 
+ Get a parser for a simple sqltypes type. + For example, simple(sqltypes.UUID) takes the string "UUID" and outputs a UUID class, or fails with any other string. """ @@ -26,14 +28,15 @@ def simple(type_: type) -> ParserType: def integer() -> ParserType: """ - Parses an integer, outputting that integer. + Get a parser for an integer, outputting that integer. """ return parsy.regex(r"-?[0-9]+").map(int) def integer_arguments() -> ParserType: """ - Parses a list of integers. + Get a parser for a list of integers. + The integers are surrounded by brackets and separated by a comma and space. """ @@ -44,6 +47,8 @@ def integer_arguments() -> ParserType: def numeric_type(type_: type) -> ParserType: """ + Make a parser for a SQL numeric type. + Parses TYPE_NAME, TYPE_NAME(2) or TYPE_NAME(2,3) passing any arguments to the TYPE_NAME constructor. """ @@ -53,12 +58,16 @@ def numeric_type(type_: type) -> ParserType: def string_type(type_: type) -> ParserType: + """ + Make a parser for a SQL string type. + + Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME COLLATE "fr" + or TYPE_NAME(32) COLLATE "fr" + """ + @parsy.generate(type_.__name__) def st_parser() -> typing.Generator[ParserType, None, typing.Any]: - """ - Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME COLLATE "fr" - or TYPE_NAME(32) COLLATE "fr" - """ + """Parse the specific type.""" yield parsy.string(type_.__name__) length: int | None = yield ( parsy.string("(") >> integer() << parsy.string(")") @@ -72,12 +81,22 @@ def st_parser() -> typing.Generator[ParserType, None, typing.Any]: def time_type(type_: type, pg_type: type) -> ParserType: + """ + Make a parser for a SQL date/time type. + + Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME WITH TIME ZONE + or TYPE_NAME(32) WITH TIME ZONE + + :param type_: The SQLAlchemy type we would like to parse. + :param pg_type: The PostgreSQL type we would like to parse if precision + or timezone is provided. + :return: ``type_`` if neither precision nor timezone are provided in the + parsed text, ``pg_type(precision, timezone)`` otherwise. + """ + @parsy.generate(type_.__name__) def pgt_parser() -> typing.Generator[ParserType, None, typing.Any]: - """ - Parses TYPE_NAME, TYPE_NAME(32), TYPE_NAME WITH TIME ZONE - or TYPE_NAME(32) WITH TIME ZONE - """ + """Parse the actual type.""" yield parsy.string(type_.__name__) precision: int | None = yield ( parsy.string("(") >> integer() << parsy.string(")") @@ -130,6 +149,11 @@ def pgt_parser() -> typing.Generator[ParserType, None, typing.Any]: @parsy.generate def type_parser() -> ParserType: + """ + Make a parser for a simple type or an array. + + Arrays produce a PostgreSQL-specific type. + """ base = yield SIMPLE_TYPE_PARSER dimensions = yield parsy.string("[]").many().map(len) if dimensions == 0: @@ -138,6 +162,11 @@ def type_parser() -> ParserType: def column_to_dict(column: Column, dialect: Dialect) -> dict[str, typing.Any]: + """ + Produce a dict description of a column. + :param column: The SQLAlchemy column to translate. + :param dialect: The SQL dialect in which to render the type name. + """ type_ = column.type if isinstance(type_, postgresql.DOMAIN): # Instead of creating a restricted type, we'll just use the base type. @@ -165,6 +194,20 @@ def dict_to_column( rep: dict, ignore_fk: Callable[[str], bool], ) -> Column: + """ + Produce column from aspects of its dict description. + :param table_name: The name of the table the column appears in. + :param col_name: The name of the column. + :param rep: The dict description of the column. 
+    :param ignore_fk: A predicate, called with the name of any foreign key target
+        (in other words, the name of any table referred to by this column). If it
+        returns True, this foreign key constraint will not be applied to the
+        returned column. This is useful in a situation where we want a foreign
+        key constraint to be present when we are determining what generators
+        might be appropriate for it, but we don't want the foreign key constraint
+        actually applied to the destination database because (for example) the
+        target table will be ignored.
+    """
     type_sql = rep["type"]
     try:
         type_ = type_parser.parse(type_sql)
@@ -193,21 +236,20 @@ def dict_to_column(
 
 
 def dict_to_unique(rep: dict) -> schema.UniqueConstraint:
+    """Make a uniqueness constraint from its dict representation."""
     return schema.UniqueConstraint(*rep.get("columns", []), name=rep.get("name", None))
 
 
 def unique_to_dict(constraint: schema.UniqueConstraint) -> dict:
+    """Render a dict representation of a uniqueness constraint."""
     return {
         "name": constraint.name,
         "columns": [str(col.name) for col in constraint.columns],
     }
 
 
-def table_to_dict(table: Table, dialect: Dialect) -> table_t:
-    """
-    Converts a SQL Alchemy Table object into a
-    Python object ready for conversion to YAML.
-    """
+def table_to_dict(table: Table, dialect: Dialect) -> TableT:
+    """Convert a SQLAlchemy Table object into a Python dict."""
     return {
         "columns": {
             str(column.key): column_to_dict(column, dialect)
@@ -224,9 +266,10 @@ def table_to_dict(table: Table, dialect: Dialect) -> table_t:
 def dict_to_table(
     name: str,
     meta: MetaData,
-    table_dict: table_t,
+    table_dict: TableT,
     ignore_fk: Callable[[str], bool],
 ) -> Table:
+    """Create a Table from its description."""
     return Table(
         name,
         meta,
@@ -255,7 +298,7 @@ def metadata_to_dict(
     }
 
 
-def should_ignore_fk(fk: str, tables_dict: dict[str, table_t]) -> bool:
+def should_ignore_fk(fk: str, tables_dict: dict[str, TableT]) -> bool:
     """
     Tell if this foreign key should be ignored because it points to
     an ignored table.
diff --git a/mypy.ini b/mypy.ini
index 86ff2fb..c2ea784 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,7 +1,7 @@
 # Global options:
 
 [mypy]
-python_version = 3.9
+python_version = 3.10
 disallow_untyped_defs = True
 disallow_any_unimported = True
 no_implicit_optional = True
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 394691b..e60baa1 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -117,8 +117,18 @@ def test_workflow_minimal_args(self) -> None:
         self.assertNoException(completed_process)
         self.assertEqual(
             {
-                "Unsupported SQLAlchemy type CIDR for column column_with_unusual_type. Setting this column to NULL always, you may want to configure a row generator for it instead.",
-                "Unsupported SQLAlchemy type BIT for column column_with_unusual_type_and_length. Setting this column to NULL always, you may want to configure a row generator for it instead.",
+                (
+                    "Unsupported SQLAlchemy type CIDR for column "
+                    "column_with_unusual_type. Setting this column to NULL "
+                    "always, you may want to configure a row generator for "
+                    "it instead."
+                ),
+                (
+                    "Unsupported SQLAlchemy type BIT for column "
+                    "column_with_unusual_type_and_length. Setting this column "
+                    "to NULL always, you may want to configure a row generator "
+                    "for it instead."
+ ), }, set(completed_process.stderr.split("\n")) - {""}, ) @@ -309,7 +319,10 @@ def test_workflow_maximal_args(self) -> None: self.assertSetEqual( { "Dropping constraint concept_concept_type_id_fkey from table concept", - "Dropping constraint ref_to_unignorable_table_ref_fkey from table ref_to_unignorable_table", + ( + "Dropping constraint ref_to_unignorable_table_ref_fkey from " + "table ref_to_unignorable_table" + ), "Dropping constraint concept_type_mitigation_type_id_fkey from table concept_type", "Restoring foreign key constraint concept_concept_type_id_fkey", "Restoring foreign key constraint ref_to_unignorable_table_ref_fkey", @@ -408,8 +421,14 @@ def test_workflow_maximal_args(self) -> None: 'Truncating vocabulary table "mitigation_type".', 'Truncating vocabulary table "empty_vocabulary".', "Vocabulary tables truncated.", - "Dropping constraint concept_type_mitigation_type_id_fkey from table concept_type", - "Dropping constraint ref_to_unignorable_table_ref_fkey from table ref_to_unignorable_table", + ( + "Dropping constraint concept_type_mitigation_type_id_fkey " + "from table concept_type" + ), + ( + "Dropping constraint ref_to_unignorable_table_ref_fkey from " + "table ref_to_unignorable_table" + ), "Dropping constraint concept_concept_type_id_fkey from table concept", "Restoring foreign key constraint concept_type_mitigation_type_id_fkey", "Restoring foreign key constraint ref_to_unignorable_table_ref_fkey", diff --git a/tests/test_interactive.py b/tests/test_interactive.py index e1de3b3..a7803f0 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -286,7 +286,7 @@ def test_list_tables(self) -> None: person_listed = False unique_constraint_test_listed = False no_pk_test_listed = False - for text, args, kwargs in tc.messages: + for _text, args, _kwargs in tc.messages: if args[2] == "person": self.assertFalse(person_listed) person_listed = True @@ -477,13 +477,13 @@ def test_null_configuration(self) -> None: "tables": None, } with self._get_cmd(config) as gc: - TABLE = "model" - gc.do_next(f"{TABLE}.name") + table = "model" + gc.do_next(f"{table}.name") gc.do_propose("") gc.do_compare("") gc.do_set("1") gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) def test_null_table_configuration(self) -> None: """Test that a table having null configuration does not break.""" @@ -493,12 +493,12 @@ def test_null_table_configuration(self) -> None: } } with self._get_cmd(config) as gc: - TABLE = "model" - gc.do_next(f"{TABLE}.name") + table = "model" + gc.do_next(f"{table}.name") gc.do_propose("") gc.do_set("1") gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) def test_prompts(self) -> None: """Test that the prompts follow the names of the columns and assigned generators.""" @@ -543,94 +543,94 @@ def test_prompts(self) -> None: def test_set_generator_mimesis(self) -> None: """Test that we can set one generator to a mimesis generator.""" with self._get_cmd({}) as gc: - TABLE = "model" - COLUMN = "name" - GENERATOR = "person.first_name" - gc.do_next(f"{TABLE}.{COLUMN}") + table = "model" + column = "name" + generator = "person.first_name" + gc.do_next(f"{table}.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[f"generic.{GENERATOR}"][0])) + gc.do_set(str(proposals[f"generic.{generator}"][0])) gc.do_quit("") - 
self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) self.assertDictEqual( - gc.config["tables"][TABLE]["row_generators"][0], - {"name": f"generic.{GENERATOR}", "columns_assigned": [COLUMN]}, + gc.config["tables"][table]["row_generators"][0], + {"name": f"generic.{generator}", "columns_assigned": [column]}, ) def test_set_generator_distribution(self) -> None: """Test that we can set one generator to gaussian.""" with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "frequency" - GENERATOR = "dist_gen.normal" - gc.do_next(f"{TABLE}.{COLUMN}") + table = "string" + column = "frequency" + generator = "dist_gen.normal" + gc.do_next(f"{table}.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[GENERATOR][0])) + gc.do_set(str(proposals[generator][0])) gc.do_quit("") - row_gens = gc.config["tables"][TABLE]["row_generators"] + row_gens = gc.config["tables"][table]["row_generators"] self.assertEqual(len(row_gens), 1) row_gen = row_gens[0] - self.assertEqual(row_gen["name"], GENERATOR) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) + self.assertEqual(row_gen["name"], generator) + self.assertListEqual(row_gen["columns_assigned"], [column]) self.assertDictEqual( row_gen["kwargs"], { - "mean": f'SRC_STATS["auto__{TABLE}"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__{TABLE}"]["results"][0]["stddev__{COLUMN}"]', + "mean": f'SRC_STATS["auto__{table}"]["results"][0]["mean__{column}"]', + "sd": f'SRC_STATS["auto__{table}"]["results"][0]["stddev__{column}"]', }, ) self.assertEqual(len(gc.config["src-stats"]), 1) self.assertSetEqual( set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} ) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") + self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column}) AS stddev__{column} FROM {table}", ) def test_set_generator_distribution_directly(self) -> None: """Test that we can set one generator to gaussian without going through propose.""" with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "frequency" - GENERATOR = "dist_gen.normal" - gc.do_next(f"{TABLE}.{COLUMN}") + table = "string" + column = "frequency" + generator = "dist_gen.normal" + gc.do_next(f"{table}.{column}") gc.reset() - gc.do_set(GENERATOR) + gc.do_set(generator) self.assertListEqual(gc.messages, []) gc.do_quit("") self.assertEqual(len(gc.config["src-stats"]), 1) self.assertSetEqual( set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} ) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{TABLE}") + self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT AVG({COLUMN}) AS mean__{COLUMN}, STDDEV({COLUMN}) AS stddev__{COLUMN} FROM {TABLE}", + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column}) AS stddev__{column} FROM {table}", ) def test_set_generator_choice(self) -> None: """Test that we can set one generator to uniform choice.""" with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "frequency" - GENERATOR = "dist_gen.choice" - gc.do_next(f"{TABLE}.{COLUMN}") + table = "string" + column = "frequency" + generator = "dist_gen.choice" + gc.do_next(f"{table}.{column}") 
gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[GENERATOR][0])) + gc.do_set(str(proposals[generator][0])) gc.do_quit("") - row_gens = gc.config["tables"][TABLE]["row_generators"] + row_gens = gc.config["tables"][table]["row_generators"] self.assertEqual(len(row_gens), 1) row_gen = row_gens[0] - self.assertEqual(row_gen["name"], GENERATOR) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN]) + self.assertEqual(row_gen["name"], generator) + self.assertListEqual(row_gen["columns_assigned"], [column]) self.assertDictEqual( row_gen["kwargs"], { - "a": f'SRC_STATS["auto__{TABLE}__{COLUMN}"]["results"]', + "a": f'SRC_STATS["auto__{table}__{column}"]["results"]', }, ) self.assertEqual(len(gc.config["src-stats"]), 1) @@ -638,108 +638,108 @@ def test_set_generator_choice(self) -> None: set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} ) self.assertEqual( - gc.config["src-stats"][0]["name"], f"auto__{TABLE}__{COLUMN}" + gc.config["src-stats"][0]["name"], f"auto__{table}__{column}" ) self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT {COLUMN} AS value FROM {TABLE} WHERE {COLUMN} IS NOT NULL GROUP BY value ORDER BY COUNT({COLUMN}) DESC", + f"SELECT {column} AS value FROM {table} WHERE {column} IS NOT NULL GROUP BY value ORDER BY COUNT({column}) DESC", ) def test_weighted_choice_generator_generates_choices(self) -> None: """Test that propose and compare show weighted_choice's values.""" with self._get_cmd({}) as gc: - TABLE = "string" - COLUMN = "position" - GENERATOR = "dist_gen.weighted_choice" - VALUES = {1, 2, 3, 4, 5, 6} - gc.do_next(f"{TABLE}.{COLUMN}") + table = "string" + column = "position" + generator = "dist_gen.weighted_choice" + values = {1, 2, 3, 4, 5, 6} + gc.do_next(f"{table}.{column}") gc.do_propose("") proposals = gc.get_proposals() - gen_proposal = proposals[GENERATOR] - self.assertSubset(set(gen_proposal[2]), {str(v) for v in VALUES}) + gen_proposal = proposals[generator] + self.assertSubset(set(gen_proposal[2]), {str(v) for v in values}) gc.do_compare(str(gen_proposal[0])) - col_heading = f"{gen_proposal[0]}. {GENERATOR}" + col_heading = f"{gen_proposal[0]}. 
{generator}" self.assertIn(col_heading, gc.columns) - self.assertSubset(set(gc.columns[col_heading]), VALUES) + self.assertSubset(set(gc.columns[col_heading]), values) def test_merge_columns(self) -> None: """Test that we can merge columns and set a multivariate generator""" - TABLE = "string" - COLUMN_1 = "frequency" - COLUMN_2 = "position" - GENERATOR_TO_DISCARD = "dist_gen.choice" - GENERATOR = "dist_gen.multivariate_normal" + table = "string" + column_1 = "frequency" + column_2 = "position" + generator_to_discard = "dist_gen.choice" + generator = "dist_gen.multivariate_normal" with self._get_cmd({}) as gc: - gc.do_next(f"{TABLE}.{COLUMN_2}") + gc.do_next(f"{table}.{column_2}") gc.do_propose("") proposals = gc.get_proposals() # set a generator, but this should not exist after merging - gc.do_set(str(proposals[GENERATOR_TO_DISCARD][0])) - gc.do_next(f"{TABLE}.{COLUMN_1}") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertNotIn(COLUMN_2, gc.prompt) + gc.do_set(str(proposals[generator_to_discard][0])) + gc.do_next(f"{table}.{column_1}") + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertNotIn(column_2, gc.prompt) gc.do_propose("") proposals = gc.get_proposals() # set a generator, but this should not exist either - gc.do_set(str(proposals[GENERATOR_TO_DISCARD][0])) + gc.do_set(str(proposals[generator_to_discard][0])) gc.do_previous("") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertNotIn(COLUMN_2, gc.prompt) - gc.do_merge(COLUMN_2) - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertIn(COLUMN_2, gc.prompt) + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertNotIn(column_2, gc.prompt) + gc.do_merge(column_2) + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertIn(column_2, gc.prompt) gc.reset() gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[GENERATOR][0])) + gc.do_set(str(proposals[generator][0])) gc.do_quit("") - row_gens = gc.config["tables"][TABLE]["row_generators"] + row_gens = gc.config["tables"][table]["row_generators"] self.assertEqual(len(row_gens), 1) row_gen = row_gens[0] - self.assertEqual(row_gen["name"], GENERATOR) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN_1, COLUMN_2]) + self.assertEqual(row_gen["name"], generator) + self.assertListEqual(row_gen["columns_assigned"], [column_1, column_2]) def test_unmerge_columns(self) -> None: """Test that we can unmerge columns and generators are removed""" - TABLE = "string" - COLUMN_1 = "frequency" - COLUMN_2 = "position" - COLUMN_3 = "model_id" - REMAINING_GEN = "gen3" + table = "string" + column_1 = "frequency" + column_2 = "position" + column_3 = "model_id" + remaining_gen = "gen3" config = { "tables": { - TABLE: { + table: { "row_generators": [ - {"name": "gen1", "columns_assigned": [COLUMN_1, COLUMN_2]}, - {"name": REMAINING_GEN, "columns_assigned": [COLUMN_3]}, + {"name": "gen1", "columns_assigned": [column_1, column_2]}, + {"name": remaining_gen, "columns_assigned": [column_3]}, ] } } } with self._get_cmd(config) as gc: - gc.do_next(f"{TABLE}.{COLUMN_2}") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertIn(COLUMN_2, gc.prompt) - gc.do_unmerge(COLUMN_1) - self.assertIn(TABLE, gc.prompt) - self.assertNotIn(COLUMN_1, gc.prompt) - self.assertIn(COLUMN_2, gc.prompt) + gc.do_next(f"{table}.{column_2}") + self.assertIn(table, gc.prompt) + self.assertIn(column_1, 
gc.prompt) + self.assertIn(column_2, gc.prompt) + gc.do_unmerge(column_1) + self.assertIn(table, gc.prompt) + self.assertNotIn(column_1, gc.prompt) + self.assertIn(column_2, gc.prompt) # Next generator should be the unmerged one gc.do_next("") - self.assertIn(TABLE, gc.prompt) - self.assertIn(COLUMN_1, gc.prompt) - self.assertNotIn(COLUMN_2, gc.prompt) + self.assertIn(table, gc.prompt) + self.assertIn(column_1, gc.prompt) + self.assertNotIn(column_2, gc.prompt) gc.do_quit("") # Both generators should have disappeared - row_gens = gc.config["tables"][TABLE]["row_generators"] + row_gens = gc.config["tables"][table]["row_generators"] self.assertEqual(len(row_gens), 1) row_gen = row_gens[0] - self.assertEqual(row_gen["name"], REMAINING_GEN) - self.assertListEqual(row_gen["columns_assigned"], [COLUMN_3]) + self.assertEqual(row_gen["name"], remaining_gen) + self.assertListEqual(row_gen["columns_assigned"], [column_3]) def test_old_generators_remain(self) -> None: """Test that we can set one generator and keep an old one.""" @@ -766,18 +766,18 @@ def test_old_generators_remain(self) -> None: ], } with self._get_cmd(config) as gc: - TABLE = "model" - COLUMN = "name" - GENERATOR = "person.first_name" - gc.do_next(f"{TABLE}.{COLUMN}") + table = "model" + column = "name" + generator = "person.first_name" + gc.do_next(f"{table}.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[f"generic.{GENERATOR}"][0])) + gc.do_set(str(proposals[f"generic.{generator}"][0])) gc.do_quit("") - self.assertEqual(len(gc.config["tables"][TABLE]["row_generators"]), 1) + self.assertEqual(len(gc.config["tables"][table]["row_generators"]), 1) self.assertDictEqual( - gc.config["tables"][TABLE]["row_generators"][0], - {"name": f"generic.{GENERATOR}", "columns_assigned": [COLUMN]}, + gc.config["tables"][table]["row_generators"][0], + {"name": f"generic.{generator}", "columns_assigned": [column]}, ) row_gens = gc.config["tables"]["string"]["row_generators"] self.assertEqual(len(row_gens), 1) @@ -795,7 +795,7 @@ def test_old_generators_remain(self) -> None: self.assertSetEqual( set(gc.config["src-stats"][0].keys()), {"comments", "name", "query"} ) - self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__string") + self.assertEqual(gc.config["src-stats"][0]["name"], "auto__string") self.assertEqual( gc.config["src-stats"][0]["query"], "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", @@ -829,31 +829,31 @@ def test_aggregate_queries_merge(self) -> None: ], } with self._get_cmd(copy.deepcopy(config)) as gc: - COLUMN = "position" - GENERATOR = "dist_gen.uniform_ms" - gc.do_next(f"string.{COLUMN}") + column = "position" + generator = "dist_gen.uniform_ms" + gc.do_next(f"string.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[f"{GENERATOR}"][0])) + gc.do_set(str(proposals[f"{generator}"][0])) gc.do_quit("") row_gens: list[dict[str, Any]] = gc.config["tables"]["string"][ "row_generators" ] self.assertEqual(len(row_gens), 2) - if row_gens[0]["name"] == GENERATOR: + if row_gens[0]["name"] == generator: row_gen0 = row_gens[0] row_gen1 = row_gens[1] else: row_gen0 = row_gens[1] row_gen1 = row_gens[0] - self.assertEqual(row_gen0["name"], GENERATOR) + self.assertEqual(row_gen0["name"], generator) self.assertEqual(row_gen1["name"], "dist_gen.normal") - self.assertListEqual(row_gen0["columns_assigned"], [COLUMN]) + self.assertListEqual(row_gen0["columns_assigned"], [column]) self.assertDictEqual( row_gen0["kwargs"], { - 
"mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{COLUMN}"]', - "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{COLUMN}"]', + "mean": f'SRC_STATS["auto__string"]["results"][0]["mean__{column}"]', + "sd": f'SRC_STATS["auto__string"]["results"][0]["stddev__{column}"]', }, ) self.assertListEqual(row_gen1["columns_assigned"], ["frequency"]) @@ -877,8 +877,8 @@ def test_aggregate_queries_merge(self) -> None: { "AVG(frequency) AS mean__frequency", "STDDEV(frequency) AS stddev__frequency", - f"AVG({COLUMN}) AS mean__{COLUMN}", - f"STDDEV({COLUMN}) AS stddev__{COLUMN}", + f"AVG({column}) AS mean__{column}", + f"STDDEV({column}) AS stddev__{column}", }, ) @@ -954,12 +954,12 @@ def test_existing_configuration_remains(self) -> None: ], } with self._get_cmd(config) as gc: - COLUMN = "position" - GENERATOR = "dist_gen.uniform_ms" - gc.do_next(f"string.{COLUMN}") + column = "position" + generator = "dist_gen.uniform_ms" + gc.do_next(f"string.{column}") gc.do_propose("") proposals = gc.get_proposals() - gc.do_set(str(proposals[f"{GENERATOR}"][0])) + gc.do_set(str(proposals[f"{generator}"][0])) gc.do_quit("") src_stats = {stat["name"]: stat["query"] for stat in gc.config["src-stats"]} self.assertEqual(src_stats["kraken"], config["src-stats"][0]["query"]) @@ -1170,8 +1170,8 @@ def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: def test_set_missingness_to_sampled(self) -> None: """Test that we can set one table to sampled missingness.""" with self._get_cmd({}) as mc: - TABLE = "signature_model" - mc.do_next(TABLE) + table = "signature_model" + mc.do_next(table) mc.do_counts("") self.assertListEqual( mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (10,), {})] @@ -1181,7 +1181,7 @@ def test_set_missingness_to_sampled(self) -> None: mc.do_sampled("") mc.do_quit("") self.assertListEqual( - mc.config["tables"][TABLE]["missingness_generators"], + mc.config["tables"][table]["missingness_generators"], [ { "columns": ["player_id", "based_on"], @@ -1199,9 +1199,12 @@ def test_set_missingness_to_sampled(self) -> None: self.assertEqual( mc.config["src-stats"][0]["query"], ( - "SELECT COUNT(*) AS row_count, player_id__is_null, based_on__is_null FROM" - " (SELECT player_id IS NULL AS player_id__is_null, based_on IS NULL AS based_on__is_null FROM" - " signature_model ORDER BY RANDOM() LIMIT 1000) AS __t GROUP BY player_id__is_null, based_on__is_null" + "SELECT COUNT(*) AS row_count," + " player_id__is_null, based_on__is_null FROM" + " (SELECT player_id IS NULL AS player_id__is_null," + " based_on IS NULL AS based_on__is_null FROM" + " signature_model ORDER BY RANDOM() LIMIT 1000)" + " AS __t GROUP BY player_id__is_null, based_on__is_null" ), ) @@ -1325,7 +1328,7 @@ def test_dist_gen_sampled_produces_ordered_src_stats(self) -> None: ] self.assertListEqual(based_ons, [1, 3, 2]) - def assertAreTruncatedTo(self, xs: Iterable[str], length: int) -> None: + def assert_are_truncated_to(self, xs: Iterable[str], length: int) -> None: """ Check that none of the strings are longer than ``length`` (after removing surrounding quotes). 
@@ -1339,62 +1342,71 @@ def assertAreTruncatedTo(self, xs: Iterable[str], length: int) -> None: def test_varchar_ns_are_truncated(self) -> None: """Tests that mimesis generators for VARCHAR(N) truncate to N characters""" - GENERATOR = "generic.text.quote" - TABLE = "signature_model" - COLUMN = "name" + generator = "generic.text.quote" + table = "signature_model" + column = "name" with self._get_cmd({}) as gc: - gc.do_next(f"{TABLE}.{COLUMN}") + gc.do_next(f"{table}.{column}") gc.reset() gc.do_propose("") proposals = gc.get_proposals() - quotes = [k for k in proposals.keys() if k.startswith(GENERATOR)] + quotes = [k for k in proposals.keys() if k.startswith(generator)] self.assertEqual(len(quotes), 1) prop = proposals[quotes[0]] - self.assertAreTruncatedTo(prop[2], 20) + self.assert_are_truncated_to(prop[2], 20) gc.reset() gc.do_compare(str(prop[0])) col_heading = f"{prop[0]}. {quotes[0]}" gc.do_set(str(prop[0])) self.assertIn(col_heading, gc.columns) - self.assertAreTruncatedTo(gc.columns[col_heading], 20) + self.assert_are_truncated_to(gc.columns[col_heading], 20) gc.do_quit("") config = gc.config self.generate_data(config, num_passes=15) with self.sync_engine.connect() as conn: - stmt = select(self.metadata.tables[TABLE].c[COLUMN]) + stmt = select(self.metadata.tables[table].c[column]) rows = conn.execute(stmt).scalars().fetchall() - self.assertAreTruncatedTo(rows, 20) + self.assert_are_truncated_to(rows, 20) @dataclass class Stat: + """Mean and variance calculator.""" + n: int = 0 x: float = 0 x2: float = 0 def add(self, x: float) -> None: + """Add one datum.""" self.n += 1 self.x += x self.x2 += x * x def count(self) -> int: + """Get the number of data added.""" return self.n def x_mean(self) -> float: + """Get the mean of the added data.""" return self.x / self.n def x_var(self) -> float: + """Get the variance of the added data.""" x = self.x return (self.x2 - x * x / self.n) / (self.n - 1) @dataclass class Correlation(Stat): + """Mean, variance and covariance.""" + y: float = 0 y2: float = 0 xy: float = 0 def add2(self, x: float, y: float) -> None: + """Add a 2D data point.""" self.n += 1 self.x += x self.x2 += x * x @@ -1403,13 +1415,16 @@ def add2(self, x: float, y: float) -> None: self.xy += x * y def y_mean(self) -> float: + """Get the mean of the second parts of the added points.""" return self.y / self.n def y_var(self) -> float: + """Get the variance of the second parts of the added points.""" y = self.y return (self.y2 - y * y / self.n) / (self.n - 1) def covar(self) -> float: + """Get the covariance of the two parts of the added points.""" return (self.xy - self.x * self.y / self.n) / (self.n - 1) @@ -1702,7 +1717,7 @@ def test_non_interactive_configure_generators( """ test that we can set generators from a CSV file """ - config: Mapping[str, Any] = {} + config: MutableMapping[str, Any] = {} spec_csv = Mock(return_value="mock spec.csv file") update_config_generators( self.dsn, self.schema_name, self.metadata, config, spec_csv diff --git a/tests/test_utils.py b/tests/test_utils.py index b34e922..54c167a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -288,6 +288,7 @@ def test_generators_require_stats(self) -> None: @patch("datafaker.utils.logger") def test_testing_generators_finds_syntax_errors(self, logger: MagicMock) -> None: + """Test that looking for ``SRC_STATS`` references finds Python syntax errors.""" generators_require_stats( { "story_generators": [ From 42fb24a0420b5f2bd83ee5996ebdec706a2fdc0b Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 9 Oct 
2025 18:58:11 +0100
Subject: [PATCH 17/35] Many, many cleanups.

---
 CONTRIBUTING.md                 |   3 +-
 datafaker/base.py               | 107 +++++++++--
 datafaker/dump.py               |   4 +-
 datafaker/generators.py         | 299 +++++++++++++++++++----------
 datafaker/interactive.py        | 327 +++++++++++++++++++++++---------
 datafaker/main.py               |  48 +++--
 datafaker/make.py               |  76 +++++---
 datafaker/providers.py          |   2 +-
 datafaker/serialize_metadata.py |  39 ++--
 datafaker/utils.py              | 109 +++++++++--
 tests/test_dump.py              |   6 +-
 tests/test_functional.py        |  85 ++++++---
 tests/test_interactive.py       |  37 +++-
 tests/test_main.py              |   2 +-
 tests/test_remove.py            |   4 +
 tests/test_rst.py               |  16 +-
 16 files changed, 836 insertions(+), 328 deletions(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 259ebe8..8d7b879 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -55,7 +55,8 @@ These tests do not currently work, and will be replaced by unit tests.
 
 Functional tests require PostgreSQL to be installed.
 
-*WARNING: Some MacOS systems [do not recognise the 'en_US.utf8' locale](https://apple.stackexchange.com/questions/206495/load-a-locale-from-usr-local-share-locale-in-os-x). As a workaround, replace `en_US.utf8` with `en_US.UTF-8` on every `*.dump` file.*
+.. warning::
+  Some MacOS systems [do not recognise the 'en_US.utf8' locale](https://apple.stackexchange.com/questions/206495/load-a-locale-from-usr-local-share-locale-in-os-x). As a workaround, replace `en_US.utf8` with `en_US.UTF-8` on every `*.dump` file.
 
 ## Building documentation locally
 
diff --git a/datafaker/base.py b/datafaker/base.py
index 8270ccb..4ceb2af 100644
--- a/datafaker/base.py
+++ b/datafaker/base.py
@@ -27,6 +27,7 @@
 
 @functools.cache
 def zipf_weights(size: int) -> list[float]:
+    """Get the weights of a Zipf distribution of a given size."""
     total = sum(map(lambda n: 1 / n, range(1, size + 1)))
     return [1 / (n * total) for n in range(1, size + 1)]
 
@@ -36,6 +37,7 @@ def merge_with_constants(
 ) -> Generator[T, None, None]:
     """
     Merge a list of items with other items that must be placed at certain indices.
+
     :param constants_at: A map of indices to objects that must be placed at
         those indices.
     :param xs: Items that fill in the gaps left by ``constants_at``.
@@ -62,35 +64,88 @@ def merge_with_constants(
 
 
 class NothingToGenerateException(Exception):
+    """Exception thrown when no value can be generated."""
+
     def __init__(self, message: str):
+        """Initialise the exception with a human-readable message."""
         super().__init__(message)
 
 
 class DistributionGenerator:
+    """An object that can produce values from various distributions."""
+
     root3 = math.sqrt(3)
 
     def __init__(self) -> None:
+        """Initialise the DistributionGenerator."""
         self.np_gen = np.random.default_rng()
 
     def uniform(self, low: float, high: float) -> float:
+        """
+        Choose a value according to a uniform distribution.
+
+        :param low: The lowest value that can be chosen.
+        :param high: The highest value that can be chosen.
+        :return: The output value.
+        """
         return random.uniform(float(low), float(high))
 
     def uniform_ms(self, mean: float, sd: float) -> float:
+        """
+        Choose a value according to a uniform distribution.
+
+        :param mean: The mean of the output values.
+        :param sd: The standard deviation of the output values.
+        :return: The output value.
+        """
         m = float(mean)
         h = self.root3 * float(sd)
         return random.uniform(m - h, m + h)
 
     def normal(self, mean: float, sd: float) -> float:
+        """
+        Choose a value according to a Gaussian (normal) distribution.
+
+        :param mean: The mean of the output values.
+        :param sd: The standard deviation of the output values.
+        :return: The output value.
+        """
         return random.normalvariate(float(mean), float(sd))
 
     def lognormal(self, logmean: float, logsd: float) -> float:
+        """
+        Choose a value according to a lognormal distribution.
+
+        :param logmean: The mean of the logs of the output values.
+        :param logsd: The standard deviation of the logs of the output values.
+        :return: The output value.
+        """
         return random.lognormvariate(float(logmean), float(logsd))
 
     def choice(self, a: list[T]) -> T:
+        """
+        Choose a value with equal probability.
+
+        :param a: The list of values to output. Each element is either
+        the value itself, or a mapping with a ``value`` key holding the
+        value to return.
+        :return: The chosen value.
+        """
         c = random.choice(a)
         return c["value"] if type(c) is dict and "value" in c else c
 
     def zipf_choice(self, a: list[T], n: int | None = None) -> T:
+        """
+        Choose a value according to the Zipf distribution.
+
+        The nth value (starting from 1) is chosen 1/n times as
+        frequently as the first value.
+
+        :param a: The list of values to output, most frequent first.
+        Each element is either the value itself, or a mapping with
+        a ``value`` key holding the value to return.
+        :return: The chosen value.
+        """
         if n is None:
             n = len(a)
         c = random.choices(a, weights=zipf_weights(n))[0]
@@ -99,9 +154,11 @@ def weighted_choice(self, a: list[dict[str, Any]]) -> Any:
         """
         Choice weighted by the count in the original dataset.
+
         :param a: a list of dicts, each with a ``value`` key holding the
             value to be returned and a ``count`` key holding the number of
             that value found in the original dataset
+        :return: The chosen ``value``.
         """
         vs = []
         counts = []
@@ -114,9 +171,19 @@ def weighted_choice(self, a: list[dict[str, Any]]) -> Any:
         return c
 
     def constant(self, value: T) -> T:
+        """Return the same value always."""
         return value
 
     def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray:
+        """
+        Return an array of values chosen from the given covariates.
+
+        :param cov: Keys are ``rank``: The number of values to output;
+        ``mN``: The mean of variable ``N`` (where ``N`` is between 0 and
+        one less than ``rank``); ``cN_M`` (where 0 <= ``N`` <= ``M`` < ``rank``):
+        the covariance between the ``N``th and the ``M``th variables.
+        :return: A numpy array of results.
+        """
         rank = int(cov["rank"])
         if rank == 0:
             return np.empty(shape=(0,))
@@ -131,9 +198,7 @@ def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray:
         return self.np_gen.multivariate_normal(mean, covs)
 
     def _select_group(self, alts: list[dict[str, Any]]) -> Any:
-        """
-        Choose one of the ``alts`` weighted by their ``"count"`` elements.
-        """
+        """Choose one of the ``alts`` weighted by their ``"count"`` elements."""
         total = 0
         for alt in alts:
             if alt["count"] < 0:
@@ -204,9 +269,7 @@ def multivariate_lognormal(self, cov: dict[str, Any]) -> list[float]:
         return out
 
     def grouped_multivariate_normal(self, covs: list[dict[str, Any]]) -> list[Any]:
-        """
-        Produce a list of values pulled from a set of multivariate distributions.
-        """
+        """Produce a list of values pulled from a set of multivariate distributions."""
         cov = self._select_group(covs)
         logger.debug("Multivariate normal group selected: %s", cov)
         constants = self._find_constants(cov)
@@ -214,9 +277,7 @@ def grouped_multivariate_normal(self, covs: list[dict[str, Any]]) -> list[Any]:
         return list(merge_with_constants(nums, constants))
 
     def grouped_multivariate_lognormal(self, covs: list[dict[str, Any]]) -> list[Any]:
-        """
-        Produce a list of values pulled from a set of multivariate distributions.
-        """
+        """Produce a list of values pulled from a set of multivariate distributions."""
         cov = self._select_group(covs)
         logger.debug("Multivariate lognormal group selected: %s", cov)
         constants = self._find_constants(cov)
@@ -233,7 +294,7 @@ def alternatives(
         counts: list[dict[str, int]] | None,
     ) -> Any:
         """
-        A generator that picks between other generators.
+        Pick between other generators.
 
         :param alternative_configs: List of alternative generators.
         Each alternative has the following keys: "count" -- a weight for
@@ -264,6 +325,17 @@ def alternatives(
     def with_constants_at(
         self, constants_at: dict[int, T], subgen: str, params: dict[str, T]
     ) -> list[T]:
+        """
+        Insert constants into the results of a different generator.
+
+        :param constants_at: A dictionary of positions and objects to insert
+        into the return list at those positions.
+        :param subgen: The name of the function to call to get the results
+        into which the constants will be inserted.
+        :param params: Keyword arguments to the ``subgen`` function.
+        :return: A list of results from calling ``subgen(**params)``
+        with ``constants_at`` inserted at the appropriate indices.
+        """
         if subgen not in self.PERMITTED_SUBGENS:
             logger.error(
                 "subgenerator %s is not a valid name. Valid names are %s.",
@@ -277,7 +349,7 @@ def with_constants_at(
 
     def truncated_string(
         self, subgen_fn: Callable[..., list[T]], params: dict, length: int
     ) -> list[T]:
-        """Calls ``subgen_fn(**params)`` and truncates the results to ``length``."""
+        """Call ``subgen_fn(**params)`` and truncate the results to ``length``."""
         result = subgen_fn(**params)
         if result is None:
             return None
@@ -358,7 +430,18 @@ def load(self, connection: Connection, base_path: Path = Path(".")) -> None:
 
 
 class ColumnPresence:
-    def sampled(self, patterns: list[dict[str, Any]]) -> set[Any]:
+    """Object for generators to use for missingness completely at random."""
+
+    def sampled(self, patterns: list[dict[str, Any]]) -> set[str]:
+        """
+        Select a random pattern and output the non-null columns.
+
+        :param patterns: List of outputs from missingness SQL queries.
+        Columns in each output: ``row_count`` is the number of rows
+        with this missingness pattern, then for each column ``col``
+        there is a boolean called ``col__is_null``.
+        :return: The names of all the columns to make non-null.
+        """
         total = 0
         for pattern in patterns:
             total += pattern.get("row_count", 0)
diff --git a/datafaker/dump.py b/datafaker/dump.py
index 95f251a..2307ba4 100644
--- a/datafaker/dump.py
+++ b/datafaker/dump.py
@@ -1,4 +1,4 @@
-""" Data dumping functions.
""" +"""Data dumping functions.""" import csv import io from typing import TYPE_CHECKING @@ -35,4 +35,4 @@ def dump_db_tables( with engine.connect() as connection: result = connection.execute(sqlalchemy.select(table)) for row in result: - csv_out.writerow(row._tuple()) + csv_out.writerow(row) diff --git a/datafaker/generators.py b/datafaker/generators.py index ee0add2..4e421af 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -1,6 +1,4 @@ -""" -Generator factories for making generators for single columns. -""" +"""Generator factories for making generators for single columns.""" import decimal import math @@ -21,7 +19,7 @@ from typing_extensions import Self from datafaker.base import DistributionGenerator -from datafaker.utils import logger, T +from datafaker.utils import T, logger numeric = Union[int, float] @@ -49,11 +47,11 @@ class Generator(ABC): @abstractmethod def function_name(self) -> str: - """The name of the generator function to put into df.py.""" + """Get the name of the generator function to put into df.py.""" def name(self) -> str: """ - The name of the generator. + Get the name of the generator. Usually the same as the function name, but can be different to distinguish between generators that have the same function but different queries. @@ -63,7 +61,8 @@ def name(self) -> str: @abstractmethod def nominal_kwargs(self) -> dict[str, str]: """ - The kwargs the generator wants to be called with. + Get the kwargs the generator wants to be called with. + The values will tend to be references to something in the src-stats.yaml file. For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will @@ -74,7 +73,7 @@ def nominal_kwargs(self) -> dict[str, str]: def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: """ - SQL clauses to add to a SELECT ... FROM {table} query. + Get the SQL clauses to add to a SELECT ... FROM {table} query. Will add to SRC_STATS["auto__{table}"] For example { @@ -94,7 +93,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: def custom_queries(self) -> dict[str, dict[str, str]]: """ - SQL queries to add to SRC_STATS. + Get the SQL queries to add to SRC_STATS. Should be used for queries that do not follow the SELECT ... FROM table format using aggregate queries, because these should use select_aggregate_clauses. @@ -114,14 +113,14 @@ def custom_queries(self) -> dict[str, dict[str, str]]: @abstractmethod def actual_kwargs(self) -> dict[str, Any]: """ - The kwargs (summary statistics) this generator is instantiated with. + Get the kwargs (summary statistics) this generator is instantiated with. + + This must match `nominal_kwargs` in structure. """ @abstractmethod def generate_data(self, count: int) -> list[Any]: - """ - Generate 'count' random data points for this column. - """ + """Generate ``count`` random data points for this column.""" def fit(self, default: float = -1) -> float: """ @@ -134,9 +133,7 @@ def fit(self, default: float = -1) -> float: class PredefinedGenerator(Generator): - """ - Generator built from an existing config.yaml. - """ + """Generator built from an existing config.yaml.""" SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") @@ -168,6 +165,7 @@ def __init__( ): """ Initialise a generator from a config.yaml. + :param config: The entire configuration. 
:param generator_object: The part of the configuration at tables.*.row_generators """ @@ -219,24 +217,30 @@ def __init__( } def function_name(self) -> str: + """Get the name of the generator function to call.""" return self._name def nominal_kwargs(self) -> dict[str, str]: + """Get the arguments to be entered into ``config.yaml``.""" return self._kwn def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" return self._select_aggregate_clauses def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" return self._custom_queries def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" # Run the queries from nominal_kwargs # ... logger.error("PredefinedGenerator.actual_kwargs not implemented yet") return {} def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" # Call the function if we can. This could be tricky... # ... logger.error("PredefinedGenerator.generate_data not implemented yet") @@ -244,21 +248,19 @@ def generate_data(self, count: int) -> list[Any]: class GeneratorFactory(ABC): - """ - A factory for making generators appropriate for a database column. - """ + """A factory for making generators appropriate for a database column.""" @abstractmethod def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: - """ - Returns all the generators that might be appropriate for this column. - """ + """Get the generators appropriate to these columns.""" class Buckets: """ + Measured buckets for a real distribution. + Finds the real distribution of continuous data so that we can measure the fit of generators against it. """ @@ -272,6 +274,7 @@ def __init__( stddev: float, count: int, ): + """Initialise a Buckets object.""" with engine.connect() as connection: raw_buckets = connection.execute( text( @@ -330,15 +333,11 @@ def make_buckets( return buckets def fit_from_counts(self, bucket_counts: Sequence[float]) -> float: - """ - Figure out the fit from bucket counts from the generator distribution. - """ + """Figure out the fit from bucket counts from the generator distribution.""" return fit_from_buckets(self.buckets, bucket_counts) def fit_from_values(self, values: list[float]) -> float: - """ - Figure out the fit from samples from the generator distribution. - """ + """Figure out the fit from samples from the generator distribution.""" buckets = [0] * 10 x = self.mean - 2 * self.stddev w = self.stddev / 2 @@ -352,12 +351,14 @@ class MultiGeneratorFactory(GeneratorFactory): """A composite factory.""" def __init__(self, factories: list[GeneratorFactory]): + """Initialise a MultiGeneratorFactory.""" super().__init__() self.factories = factories def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" return [ generator for factory in self.factories @@ -366,12 +367,14 @@ def get_generators( class MimesisGeneratorBase(Generator): + """Base class for a generator using Mimesis.""" + def __init__( self, function_name: str, ): """ - Generator from Mimesis. + Initialise a generator that uses Mimesis. :param function_name: is relative to 'generic', for example 'person.name'. 
""" @@ -391,13 +394,17 @@ def __init__( self._generator_function = f def function_name(self) -> str: + """Get the name of the generator function to call.""" return self._name def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [self._generator_function() for _ in range(count)] class MimesisGenerator(MimesisGeneratorBase): + """A generator using Mimesis.""" + def __init__( self, function_name: str, @@ -405,7 +412,7 @@ def __init__( buckets: Buckets | None = None, ): """ - Generator from Mimesis. + Initialise a generator using Mimesis. :param function_name: is relative to 'generic', for example 'person.name'. :param value_fn: Function to convert generator output to floats, if needed. The values @@ -423,19 +430,25 @@ def __init__( self._fit = buckets.fit_from_values(samples) def function_name(self) -> str: + """Get the name of the generator function to call.""" return self._name def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return {} def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return {} def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" return default if self._fit is None else self._fit class MimesisGeneratorTruncated(MimesisGenerator): + """A string generator using Mimesis that must fit within a certain number of characters.""" + def __init__( self, function_name: str, @@ -443,16 +456,20 @@ def __init__( value_fn: Callable[[Any], float] | None = None, buckets: Buckets | None = None, ): + """Initialise a MimesisGeneratorTruncated.""" self._length = length super().__init__(function_name, value_fn, buckets) def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.truncated_string" def name(self) -> str: + """Get the name of the generator.""" return f"{self._name} [truncated to {self._length}]" def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "subgen_fn": self._name, "params": {}, @@ -460,6 +477,7 @@ def nominal_kwargs(self) -> dict[str, Any]: } def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "subgen_fn": self._name, "params": {}, @@ -467,10 +485,13 @@ def actual_kwargs(self) -> dict[str, Any]: } def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [self._generator_function()[: self._length] for _ in range(count)] class MimesisDateTimeGenerator(MimesisGeneratorBase): + """DateTime generator using Mimesis.""" + def __init__( self, column: Column, @@ -481,6 +502,8 @@ def __init__( end: int, ) -> None: """ + Initialise a MimesisDateTimeGenerator. 
+ :param column: The column to generate into :param function_name: The name of the mimesis function :param min_year: SQL expression extracting the minimum year @@ -499,6 +522,7 @@ def __init__( def make_singleton( _cls, column: Column, engine: Engine, function_name: str ) -> Sequence[Generator]: + """Make the appropriate generation configuration for this column.""" extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" max_year = f"MAX({extract_year})" min_year = f"MIN({extract_year})" @@ -522,18 +546,21 @@ def make_singleton( ] def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "start": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__start"]', "end": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__end"]', } def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "start": self._start, "end": self._end, } def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" return { f"{self._column.name}__start": { "clause": self._min_year, @@ -546,6 +573,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: } def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [ self._generator_function(start=self._start, end=self._end) for _ in range(count) @@ -553,6 +581,7 @@ def generate_data(self, count: int) -> list[Any]: def get_column_type(column: Column) -> TypeEngine: + """Get the type of the column, generic if possible.""" try: return column.type.as_generic() except NotImplementedError: @@ -560,9 +589,7 @@ def get_column_type(column: Column) -> TypeEngine: class MimesisStringGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return strings. - """ + """All Mimesis generators that return strings.""" GENERATOR_NAMES = [ "address.calling_code", @@ -601,6 +628,7 @@ class MimesisStringGeneratorFactory(GeneratorFactory): def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -639,13 +667,12 @@ def get_generators( class MimesisFloatGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return floating point numbers. - """ + """All Mimesis generators that return floating point numbers.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -662,13 +689,12 @@ def get_generators( class MimesisDateGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return dates. - """ + """All Mimesis generators that return dates.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -679,13 +705,12 @@ def get_generators( class MimesisDateTimeGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return datetimes. 
- """ + """All Mimesis generators that return datetimes.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -698,13 +723,12 @@ def get_generators( class MimesisTimeGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return times. - """ + """All Mimesis generators that return times.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -715,13 +739,12 @@ def get_generators( class MimesisIntegerGeneratorFactory(GeneratorFactory): - """ - All Mimesis generators that return integers. - """ + """All Mimesis generators that return integers.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -732,27 +755,33 @@ def get_generators( def fit_from_buckets(xs: Sequence[numeric], ys: Sequence[numeric]) -> float: + """Calculate the fit by comparing a pair of lists of buckets.""" sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) return sum_diff_squared / (count * count) class ContinuousDistributionGenerator(Generator): + """Base class for generators producing continuous distributions.""" + expected_buckets: Sequence[numeric] = [] def __init__(self, table_name: str, column_name: str, buckets: Buckets): + """Initialise a ContinuousDistributionGenerator.""" super().__init__() self.table_name = table_name self.column_name = column_name self.buckets = buckets def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "mean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["mean__{self.column_name}"]', "sd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["stddev__{self.column_name}"]', } def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" if self.buckets is None: return {} return { @@ -761,6 +790,7 @@ def actual_kwargs(self) -> dict[str, Any]: } def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" clauses = super().select_aggregate_clauses() return { **clauses, @@ -775,12 +805,15 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: } def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) class GaussianGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a Gaussian (normal) distribution.""" + expected_buckets = [ 0.0227, 0.0441, @@ -795,9 +828,11 @@ class GaussianGenerator(ContinuousDistributionGenerator): ] def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.normal" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [ dist_gen.normal(self.buckets.mean, self.buckets.stddev) for _ in range(count) @@ -805,6 +840,8 @@ def generate_data(self, count: int) -> list[Any]: class UniformGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a uniform distribution.""" + 
expected_buckets = [ 0, 0.06698, @@ -819,9 +856,11 @@ class UniformGenerator(ContinuousDistributionGenerator): ] def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.uniform_ms" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [ dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) for _ in range(count) @@ -829,9 +868,7 @@ def generate_data(self, count: int) -> list[Any]: class ContinuousDistributionGeneratorFactory(GeneratorFactory): - """ - All generators that want an average and standard deviation. - """ + """All generators that want an average and standard deviation.""" def _get_generators_from_buckets( self, @@ -848,6 +885,7 @@ def _get_generators_from_buckets( def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -865,6 +903,8 @@ def get_generators( class LogNormalGenerator(Generator): + """Generator producing numbers in a log-normal distribution.""" + # TODO: figure out the real buckets here (this was from a random sample in R) expected_buckets = [ 0, @@ -887,6 +927,7 @@ def __init__( logmean: float, logstddev: float, ): + """Initialise a LogNormalGenerator.""" super().__init__() self.table_name = table_name self.column_name = column_name @@ -895,24 +936,29 @@ def __init__( self.logstddev = logstddev def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.lognormal" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "logmean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logmean__{self.column_name}"]', "logsd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logstddev__{self.column_name}"]', } def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "logmean": self.logmean, "logsd": self.logstddev, } def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" clauses = super().select_aggregate_clauses() return { **clauses, @@ -927,15 +973,14 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: } def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" if self.buckets is None: return default return self.buckets.fit_from_counts(self.expected_buckets) class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorFactory): - """ - All generators that want an average and standard deviation of log data. - """ + """All generators that want an average and standard deviation of log data.""" def _get_generators_from_buckets( self, @@ -968,8 +1013,8 @@ def _get_generators_from_buckets( def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: """ - Get a zipf distribution for a certain number of items distributed - in a certain number of bins. + Get a zipf distribution for a certain number of items. + :param total: The total number of items to be distributed. :param bins: The total number of bins to distribute the items into. 
:return: A generator of the number of items in each bin, from the @@ -989,6 +1034,8 @@ def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None class ChoiceGenerator(Generator): + """Base generator for all generators producing choices of items.""" + STORE_COUNTS = False def __init__( @@ -1000,6 +1047,7 @@ def __init__( sample_count: int | None = None, suppress_count: int = 0, ) -> None: + """Initialise a ChoiceGenerator.""" super().__init__() self.table_name = table_name self.column_name = column_name @@ -1035,27 +1083,29 @@ def __init__( @abstractmethod def get_estimated_counts(self, counts: list[int]) -> list[int]: - """ - The counts that we would expect if this distribution was the correct one. - """ + """Get the counts that we would expect if this distribution was the correct one.""" def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', } def name(self) -> str: + """Get the name of the generator.""" n = super().name() if self._annotation is None: return n return f"{n} [{self._annotation}]" def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "a": self.values, } def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" qs = super().custom_queries() return { **qs, @@ -1066,17 +1116,23 @@ def custom_queries(self) -> dict[str, dict[str, str]]: } def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" return default if self._fit is None else self._fit class ZipfChoiceGenerator(ChoiceGenerator): + """Generator producing items in a Zipf distribution.""" + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" return list(zipf_distribution(sum(counts), len(counts))) def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.zipf_choice" def generate_data(self, count: int) -> list[float]: + """Generate ``count`` random data points for this column.""" return [ dist_gen.zipf_choice(self.values, len(self.values)) for _ in range(count) ] @@ -1084,7 +1140,8 @@ def generate_data(self, count: int) -> list[float]: def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: """ - A generator putting ``total`` items uniformly into ``bins`` bins. + Construct a distribution putting ``total`` items uniformly into ``bins`` bins. + If they don't fit exactly evenly, the earlier bins will have one more item than the later bins so the total is as required. """ @@ -1097,30 +1154,36 @@ def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, N class UniformChoiceGenerator(ChoiceGenerator): - """ - A generator producing values, each roughly as frequently as each other. 
- """ + """A generator producing values, each roughly as frequently as each other.""" def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" return list(uniform_distribution(sum(counts), len(counts))) def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.choice" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [dist_gen.choice(self.values) for _ in range(count)] class WeightedChoiceGenerator(ChoiceGenerator): + """Choice generator that matches the source data's frequency.""" + STORE_COUNTS = True def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" return counts def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.weighted_choice" def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [dist_gen.weighted_choice(self.values) for _ in range(count)] @@ -1143,6 +1206,7 @@ class ValueGatherer: """ def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: + """Initialise a ValueGatherer.""" values = [] # All values found counts = [] # The number or each value cvs: list[dict[str, Any]] = [] # list of dicts with keys "v" and "count" @@ -1176,9 +1240,7 @@ def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: class ChoiceGeneratorFactory(GeneratorFactory): - """ - All generators that want an average and standard deviation. - """ + """All generators that want an average and standard deviation.""" SAMPLE_COUNT = MAXIMUM_CHOICES SUPPRESS_COUNT = 5 @@ -1186,6 +1248,7 @@ class ChoiceGeneratorFactory(GeneratorFactory): def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -1293,32 +1356,38 @@ def get_generators( class ConstantGenerator(Generator): + """Generator that always produces the same value.""" + def __init__(self, value: Any) -> None: + """Initialise the ConstantGenerator.""" super().__init__() self.value = value self.repr = repr(value) def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.constant" def nominal_kwargs(self) -> dict[str, str]: + """Get the arguments to be entered into ``config.yaml``.""" return {"value": self.repr} def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" return {"value": self.value} def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" return [self.value for _ in range(count)] class ConstantGeneratorFactory(GeneratorFactory): - """ - Just the null generator - """ + """Just the null generator.""" def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators appropriate for these columns.""" if len(columns) != 1: return [] column = columns[0] @@ -1335,9 +1404,7 @@ def get_generators( class MultivariateNormalGenerator(Generator): - """ - Generator of multiple values drawn from a multivariate normal distribution. 
- """ + """Generator of multiple values drawn from a multivariate normal distribution.""" def __init__( self, @@ -1347,6 +1414,7 @@ def __init__( covariates: RowMapping, function_name: str, ) -> None: + """Initialise a MultivariateNormalGenerator.""" self._table = table_name self._columns = column_names self._query = query @@ -1354,14 +1422,17 @@ def __init__( self._function_name = function_name def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen." + self._function_name def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', } def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" cols = ", ".join(self._columns) return { f"auto__cov__{self._table}": { @@ -1371,32 +1442,34 @@ def custom_queries(self) -> dict[str, Any]: } def actual_kwargs(self) -> dict[str, Any]: - """ - The kwargs (summary statistics) this generator is instantiated with. - """ + """Get the kwargs (summary statistics) this generator was instantiated with.""" return {"cov": self._covariates} def generate_data(self, count: int) -> list[Any]: - """ - Generate 'count' random data points for this column. - """ + """Generate 'count' random data points for this column.""" return [ getattr(dist_gen, self._function_name)(self._covariates) for _ in range(count) ] def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" return default class MultivariateNormalGeneratorFactory(GeneratorFactory): + """Normal distribution generator factory.""" + def function_name(self) -> str: + """Get the name of the generator function to call.""" return "multivariate_normal" def query_predicate(self, column: Column) -> str: + """Get the SQL expression for whether this column should be queried.""" return column.name + " IS NOT NULL" def query_var(self, column: str) -> str: + """Get the SQL expression of the value to query for this column.""" return column def query( @@ -1411,7 +1484,8 @@ def query( sample_count: int | None = None, ) -> str: """ - Gets a query for the basics for multivariate normal/lognormal parameters. + Get a query for the basics for multivariate normal/lognormal parameters. + :param table: The name of the table to be queried. :param columns: The columns in the multivariate distribution. :param and_where: Additional where clause. If not ``""`` should begin with ``" AND "``. @@ -1456,6 +1530,7 @@ def query( def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: + """Get the generators for these columns.""" # For the case of one column we'll use GaussianGenerator if len(columns) < 2: return [] @@ -1487,20 +1562,23 @@ def get_generators( class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): + """Multivariate lognormal generator factory.""" + def function_name(self) -> str: + """Get the name of the generator function to call.""" return "multivariate_lognormal" def query_predicate(self, column: Column) -> str: + """Get the SQL expression for whether this column should be queried.""" return f"COALESCE(0 < {column.name}, FALSE)" def query_var(self, column: str) -> str: + """Get the expression to query for, for this column.""" return f"LN({column})" def text_list(items: Iterable[str]) -> str: - """ - Concatenate the items with commas and one "and". 
- """ + """Concatenate the items with commas and one "and".""" item_i = iter(items) try: last_item = next(item_i) @@ -1518,6 +1596,8 @@ def text_list(items: Iterable[str]) -> str: @dataclass class RowPartition: + """A partition where all the rows have the same pattern of NULLs.""" + query: str # list of numeric columns included_numeric: list[Column] @@ -1537,6 +1617,7 @@ class RowPartition: covariates: Sequence[RowMapping] def comment(self) -> str: + """Make an appropriate comment for this partition.""" caveat = "" if self.included_choice: caveat = f" (for each possible value of {text_list(self.included_choice.values())})" @@ -1584,6 +1665,7 @@ def __init__( partition_counts: Iterable[RowMapping] = [], partition_count_comment: str | None = None, ): + """Initialise a NullPartitionedNormalGenerator.""" self._query_name = query_name self._partitions = partitions self._function_name = function_name @@ -1596,14 +1678,17 @@ def __init__( self._name = f"null-partitioned {function_name}" def name(self) -> str: + """Get the name of the generator.""" return self._name def function_name(self) -> str: + """Get the name of the generator function to call.""" return "dist_gen.alternatives" def _nominal_kwargs_with_combinations( self, index: int, partition: RowPartition ) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml`` for a single partition.""" count = f'sum(r["count"] for r in SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' if not partition.included_numeric and not partition.included_choice: return { @@ -1634,6 +1719,7 @@ def _count_query_name(self) -> str: return f"auto__cov__{self._query_name}__counts" def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" return { "alternative_configs": [ self._nominal_kwargs_with_combinations(index, self._partitions[index]) @@ -1643,6 +1729,7 @@ def nominal_kwargs(self) -> dict[str, Any]: } def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" partitions = { f"auto__cov__{self._query_name}__alt_{index}": { "comment": partition.comment(), @@ -1690,9 +1777,7 @@ def _actual_kwargs_with_combinations( } def actual_kwargs(self) -> dict[str, Any]: - """ - The kwargs (summary statistics) this generator is instantiated with. - """ + """Get the kwargs (summary statistics) this generator was instantiated with.""" return { "alternative_configs": [ self._actual_kwargs_with_combinations(self._partitions[index]) @@ -1702,31 +1787,29 @@ def actual_kwargs(self) -> dict[str, Any]: } def generate_data(self, count: int) -> list[Any]: - """ - Generate 'count' random data points for this column. - """ + """Generate 'count' random data points for this column.""" kwargs = self.actual_kwargs() return [dist_gen.alternatives(**kwargs) for _ in range(count)] def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" return default def is_numeric(col: Column) -> bool: + """Test if this column stores a numeric value.""" ct = get_column_type(col) return (isinstance(ct, Numeric) or isinstance(ct, Integer)) and not col.foreign_keys def powerset(input: list[T]) -> Iterable[Iterable[T]]: - """Returns a list of all sublists of""" + """Get a list of all sublists of ``input``.""" return chain.from_iterable(combinations(input, n) for n in range(len(input) + 1)) @dataclass class NullableColumn: - """ - A reference to a nullable column whose nullability is part of a partitioning. 
-    """
+    """A reference to a nullable column whose nullability is part of a partitioning."""

     column: Column
     # The bit (power of two) of the number of the partition in the partition sizes list
@@ -1734,13 +1817,12 @@ class NullableColumn:


 class NullPatternPartition:
-    """
-    The definition of a partition (in other words, what makes it not another partition)
-    """
+    """The definition of a partition (in other words, what makes it not another partition)."""

     def __init__(
         self, columns: Iterable[Column], partition_nonnulls: Iterable[NullableColumn]
     ):
+        """Initialise a pattern of nulls which can be queried for."""
         self.index = sum(nc.bitmask for nc in partition_nonnulls)
         nonnull_columns = {nc.column.name for nc in partition_nonnulls}
         self.included_numeric: list[Column] = []
@@ -1772,6 +1854,8 @@ def __init__(


 class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory):
+    """Produces null-partitioned generators for complex interdependent data."""
+
     SAMPLE_COUNT = MAXIMUM_CHOICES
     SUPPRESS_COUNT = 5
     EMPTY_RESULT = [
@@ -1784,24 +1868,22 @@ class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory):
     ]

     def function_name(self) -> str:
+        """Get the name of the generator function to call."""
         return "grouped_multivariate_normal"

     def query_predicate(self, column: Column) -> str:
-        """
-        Returns a SQL expression that is true when ``column`` is available for analysis.
-        """
+        """Get a SQL expression that is true when ``column`` is available for analysis."""
         if is_numeric(column):
             # x <> x + 1 ensures that x is not infinity or NaN
             return f"COALESCE({column.name} <> {column.name} + 1, FALSE)"
         return f"{column.name} IS NOT NULL"

     def query_var(self, column: str) -> str:
+        """Return the expression we are querying for in this column."""
         return column

     def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]:
-        """
-        Gets a list of nullable columns together with bitmasks.
-        """
+        """Get a list of nullable columns together with bitmasks."""
         out: list[NullableColumn] = []
         for col in columns:
             if col.nullable:
@@ -1817,7 +1899,7 @@ def get_partition_count_query(
         self, ncs: list[NullableColumn], table: str, where: str | None = None
     ) -> str:
         """
-        Returns a SQL expression returning columns ``count`` and ``index``.
+        Get a SQL expression returning columns ``count`` and ``index``.

         Each row returned represents one of the null pattern partitions.
         ``index`` is the bitmask of all those nullable columns that are not null for
@@ -1834,6 +1916,7 @@ def get_partition_count_query(
     def get_generators(
         self, columns: list[Column], engine: Engine
     ) -> Sequence[Generator]:
+        """Get any appropriate generators for these columns."""
         if len(columns) < 2:
             return []
         nullable_columns = self.get_nullable_columns(columns)
@@ -1934,6 +2017,7 @@ def _execute_partition_queries(
     ) -> bool:
         """
         Execute the query in each partition, filling in the covariates.
+
         :return: True if all the partitions work, False if any of them fail.
         """
         found_nonzero = False
@@ -1948,21 +2032,32 @@ def _execute_partition_queries(

 class NullPartitionedLogNormalGeneratorFactory(NullPartitionedNormalGeneratorFactory):
+    """
+    A factory of generators for numeric and non-numeric columns.
+
+    Any of the values may be null; the distributions of the nonnull numeric
+    columns depend on each other and on the other non-numeric column values.
+ """ + def function_name(self) -> str: + """Get the name of the generator function to call.""" return "grouped_multivariate_lognormal" def query_predicate(self, column: Column) -> str: + """Get the SQL expression testing if the value in this column should be used.""" if is_numeric(column): # x <> x + 1 ensures that x is not infinity or NaN return f"COALESCE({column.name} <> {column.name} + 1 AND 0 < {column.name}, FALSE)" return f"{column.name} IS NOT NULL" def query_var(self, column: str) -> str: + """Get the variable or expression we are querying for this column.""" return f"LN({column})" @lru_cache(1) def everything_factory() -> GeneratorFactory: + """Get a factory that encapsulates all the other factories.""" return MultiGeneratorFactory( [ MimesisStringGeneratorFactory(), diff --git a/datafaker/interactive.py b/datafaker/interactive.py index 580fa3a..d35806d 100644 --- a/datafaker/interactive.py +++ b/datafaker/interactive.py @@ -1,3 +1,4 @@ +"""Interactive configuration commands.""" import cmd import csv import functools @@ -39,11 +40,13 @@ def or_default(v: T | None, d: T) -> T: - """Returns v if it isn't None, otherwise d.""" + """Return v if it isn't None, otherwise d.""" return d if v is None else v class TableType(Enum): + """Types of table to be configured.""" + GENERATE = "generate" IGNORE = "ignore" VOCABULARY = "vocabulary" @@ -70,38 +73,49 @@ class TableType(Enum): @dataclass class TableEntry: + """Base class for table entries for interactive commands.""" + name: str # name of the table class AskSaveCmd(cmd.Cmd): + """Interactive shell for whether to save and quit.""" + intro = "Do you want to save this configuration?" prompt = "(yes/no/cancel) " file = None def __init__(self) -> None: + """Initialise a save command.""" super().__init__() self.result = "" def do_yes(self, _arg: str) -> bool: + """Save the new config.yaml.""" self.result = "yes" return True def do_no(self, _arg: str) -> bool: + """Exit without saving.""" self.result = "no" return True def do_cancel(self, _arg: str) -> bool: + """Do not exit.""" self.result = "cancel" return True def fk_column_name(fk: ForeignKey) -> str: + """Display name for a foreign key.""" if fk_refers_to_ignored_table(fk): return f"{fk.target_fullname} (ignored)" return str(fk.target_fullname) class DbCmd(ABC, cmd.Cmd): + """Base class for interactive configuration commands.""" + INFO_NO_MORE_TABLES = "There are no more tables" ERROR_ALREADY_AT_START = "Error: Already at the start" ERROR_NO_SUCH_TABLE = "Error: '{0}' is not the name of a table in this database" @@ -110,6 +124,13 @@ class DbCmd(ABC, cmd.Cmd): @abstractmethod def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | None: + """ + Make a table entry suitable for this interactive command. + + :param name: The name of the table to make an entry for. + :param table_config: The part of the ``config.yaml`` referring to this table. + :return: The table entry or None if this table should not be interacted with. + """ ... 
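For illustration, a minimal sketch of a concrete ``make_table_entry`` (the subclass name ``ExampleCmd`` is hypothetical; the real implementations are ``TableCmd``, ``MissingnessCmd`` and ``GeneratorCmd`` below):

    class ExampleCmd(DbCmd):
        def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | None:
            # Returning None excludes the table from the interactive
            # session entirely; here we skip tables marked "ignore".
            if table_config.get("ignore", False):
                return None
            return TableEntry(name)

        def set_prompt(self) -> None:
            # DbCmd's other abstract method; a real subclass would
            # reflect the current table and its state here.
            self.prompt = "(example) "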
def __init__( @@ -119,6 +140,7 @@ def __init__( metadata: MetaData, config: MutableMapping[str, Any], ): + """Initialise a DbCmd.""" super().__init__() self.config: MutableMapping[str, Any] = config self.metadata = metadata @@ -138,9 +160,11 @@ def __init__( @property def sync_engine(self) -> Engine: + """Get the synchronous version of the engine.""" return get_sync_engine(self.engine) def __enter__(self) -> Self: + """Enter a ``with`` statement.""" return self def __exit__( @@ -149,12 +173,20 @@ def __exit__( _exc_val: Optional[BaseException], _exc_tb: Optional[TracebackType], ) -> None: + """Dispose of this object.""" self.engine.dispose() def print(self, text: str, *args: Any, **kwargs: Any) -> None: + """Print text, formatted with positional and keyword arguments.""" print(text.format(*args, **kwargs)) def print_table(self, headings: list[str], rows: list[list[Any]]) -> None: + """ + Print a table. + + :param headings: List of headings for the table. + :param rows: List of rows of values. + """ output = PrettyTable() output.field_names = headings for row in rows: @@ -162,6 +194,11 @@ def print_table(self, headings: list[str], rows: list[list[Any]]) -> None: print(output) def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: + """ + Print a table. + + :param columns: Dict of column names to the values in the column. + """ output = PrettyTable() row_count = max([len(col) for col in columns.values()]) for field_name, data in columns.items(): @@ -169,11 +206,13 @@ def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: print(output) def print_results(self, result: sqlalchemy.CursorResult) -> None: + """Print the rows resulting from a database query.""" self.print_table(list(result.keys()), [list(row) for row in result.all()]) def ask_save(self) -> str: """ Ask the user if they want to save. + :return: ``yes``, ``no`` or ``cancel``. """ ask = AskSaveCmd() @@ -182,11 +221,13 @@ def ask_save(self) -> str: @abstractmethod def set_prompt(self) -> None: + """Set the prompt according to the current state.""" ... def set_table_index(self, index: int) -> bool: """ Move to a different table. + :param index: Index of the table to move to. :return: True if there is a table with such an index to move to. """ @@ -198,7 +239,9 @@ def set_table_index(self, index: int) -> bool: def next_table(self, report: str = "No more tables") -> bool: """ - Move to the next table + Move to the next table. + + :param report: The text to print if there is no next table. :return: True if there is another table to move to. 
""" if not self.set_table_index(self.table_index + 1): @@ -211,12 +254,15 @@ def table_name(self) -> str: return str(self._table_entries[self.table_index].name) def table_metadata(self) -> Table: + """Get the metadata of the current table.""" return self.metadata.tables[self.table_name()] def get_column_names(self) -> list[str]: + """Get the names of the current columns.""" return [col.name for col in self.table_metadata().columns] def report_columns(self) -> None: + """Print information about the current columns.""" self.print_table( ["name", "type", "primary", "nullable", "foreign key"], [ @@ -232,6 +278,7 @@ def report_columns(self) -> None: ) def get_table_config(self, table_name: str) -> dict[str, Any]: + """Get the configuration of the named table.""" ts = self.config.get("tables", None) if type(ts) is not dict: return {} @@ -239,6 +286,7 @@ def get_table_config(self, table_name: str) -> dict[str, Any]: return t if type(t) is dict else {} def set_table_config(self, table_name: str, config: dict[str, Any]) -> None: + """Set the configuration of the named table.""" ts = self.config.get("tables", None) if type(ts) is not dict: self.config["tables"] = {table_name: config} @@ -246,6 +294,7 @@ def set_table_config(self, table_name: str, config: dict[str, Any]) -> None: ts[table_name] = config def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, Any]]: + """Remove all source stats with the given prefix from the configuration.""" src_stats = self.config.get("src-stats", []) new_src_stats = [] for stat in src_stats: @@ -255,6 +304,7 @@ def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, Any]]: return new_src_stats def get_nonnull_columns(self, table_name: str) -> list[str]: + """Get the names of the nullable columns in the named table.""" metadata_table = self.metadata.tables[table_name] return [ str(name) @@ -263,6 +313,7 @@ def get_nonnull_columns(self, table_name: str) -> list[str]: ] def find_entry_index_by_table_name(self, table_name: str) -> int | None: + """Get the index of the table entry of the named table.""" return next( ( i @@ -273,13 +324,14 @@ def find_entry_index_by_table_name(self, table_name: str) -> int | None: ) def find_entry_by_table_name(self, table_name: str) -> TableEntry | None: + """Get the table entry of the named table.""" for e in self._table_entries: if e.name == table_name: return e return None def do_counts(self, _arg: str) -> None: - "Report the column names with the counts of nulls in them" + """Report the column names with the counts of nulls in them.""" if len(self._table_entries) <= self.table_index: return table_name = self.table_name() @@ -309,7 +361,7 @@ def do_counts(self, _arg: str) -> None: ) def do_select(self, arg: str) -> None: - "Run a select query over the database and show the first 50 results" + """Run a select query over the database and show the first 50 results.""" MAX_SELECT_ROWS = 50 with self.sync_engine.connect() as connection: try: @@ -327,6 +379,8 @@ def do_select(self, arg: str) -> None: def do_peek(self, arg: str) -> None: """ + View some data from the current table. + Use 'peek col1 col2 col3' to see a sample of values from columns col1, col2 and col3 in the current table. Use 'peek' to see a sample of the current column(s). Rows that are enitrely null are suppressed. 
@@ -358,6 +412,7 @@ def do_peek(self, arg: str) -> None: def complete_peek( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Completions for the ``peek`` command.""" if len(self._table_entries) <= self.table_index: return [] return [ @@ -367,11 +422,15 @@ def complete_peek( @dataclass class TableCmdTableEntry(TableEntry): + """Table entry for the table command shell.""" + old_type: TableType new_type: TableType class TableCmd(DbCmd): + """Command shell allowing the user to set the type of each table.""" + intro = "Interactive table configuration (ignore, vocabulary, private, generate or empty). Type ? for help.\n" doc_leader = """Use the commands 'ignore', 'vocabulary', 'private', 'empty' or 'generate' to set the table's type. Use 'next' or @@ -395,6 +454,13 @@ class TableCmd(DbCmd): NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | None: + """ + Make a table entry for the named table. + + :param name: The name of the table. + :param table: The part of ``config.yaml`` corresponding to this table. + :return: The newly-constructed table entry. + """ if table.get("ignore", False): return TableCmdTableEntry(name, TableType.IGNORE, TableType.IGNORE) if table.get("vocabulary_table", False): @@ -412,20 +478,24 @@ def __init__( metadata: MetaData, config: MutableMapping[str, Any], ) -> None: + """Initialise a TableCmd.""" super().__init__(src_dsn, src_schema, metadata, config) self.set_prompt() @property def table_entries(self) -> list[TableCmdTableEntry]: + """Get the list of table entries.""" return cast(list[TableCmdTableEntry], self._table_entries) def find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: + """Get the table entry of the table with the given name.""" entry = super().find_entry_by_table_name(table_name) if entry is None: return None return cast(TableCmdTableEntry, entry) def set_prompt(self) -> None: + """Set the prompt according to the current table and its type.""" if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name) @@ -433,11 +503,13 @@ def set_prompt(self) -> None: self.prompt = "(table) " def set_type(self, t_type: TableType) -> None: + """Set the type of the current table.""" if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type def _copy_entries(self) -> None: + """Alter the configuration to match the new table entries.""" for entry in self.table_entries: if entry.old_type != entry.new_type: table = self.get_table_config(entry.name) @@ -470,6 +542,7 @@ def _copy_entries(self) -> None: self.set_table_config(entry.name, table) def _get_referenced_tables(self, from_table_name: str) -> set[str]: + """Get all the tables referenced by this table's foreign keys.""" from_meta = self.metadata.tables[from_table_name] return { fk.column.table.name for col in from_meta.columns for fk in col.foreign_keys @@ -520,7 +593,7 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: return warnings def do_quit(self, _arg: str) -> bool: - "Check the updates, save them if desired and quit the configurer." 
+ """Check the updates, save them if desired and quit the configurer.""" count = 0 for entry in self.table_entries: if entry.old_type != entry.new_type: @@ -552,7 +625,7 @@ def do_quit(self, _arg: str) -> bool: return False def do_tables(self, _arg: str) -> None: - "list the tables with their types" + """List the tables with their types.""" for entry in self.table_entries: old = entry.old_type new = entry.new_type @@ -560,7 +633,7 @@ def do_tables(self, _arg: str) -> None: self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) def do_next(self, arg: str) -> None: - "'next' = go to the next table, 'next tablename' = go to table 'tablename'" + """'next' = go to the next table, 'next tablename' = go to table 'tablename'.""" if arg: # Find the index of the table called _arg, if any index = self.find_entry_index_by_table_name(arg) @@ -574,52 +647,54 @@ def do_next(self, arg: str) -> None: def complete_next( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Get the completions for tables and columns.""" return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] def do_previous(self, _arg: str) -> None: - "Go to the previous table" + """Go to the previous table.""" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) def do_ignore(self, _arg: str) -> None: - "Set the current table as ignored, and go to the next table" + """Set the current table as ignored, and go to the next table.""" self.set_type(TableType.IGNORE) self.print("Table {} set as ignored", self.table_name()) self.next_table() def do_vocabulary(self, _arg: str) -> None: - "Set the current table as a vocabulary table, and go to the next table" + """Set the current table as a vocabulary table, and go to the next table.""" self.set_type(TableType.VOCABULARY) self.print("Table {} set to be a vocabulary table", self.table_name()) self.next_table() def do_private(self, _arg: str) -> None: - "Set the current table as a primary private table (such as the table of patients)" + """Set the current table as a primary private table (such as the table of patients).""" self.set_type(TableType.PRIVATE) self.print("Table {} set to be a primary private table", self.table_name()) self.next_table() def do_generate(self, _arg: str) -> None: - "Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table" + """Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table.""" self.set_type(TableType.GENERATE) self.print("Table {} generate", self.table_name()) self.next_table() def do_empty(self, _arg: str) -> None: - "Set the current table as empty; no generators will be run for it" + """Set the current table as empty; no generators will be run for it.""" self.set_type(TableType.EMPTY) self.print("Table {} empty", self.table_name()) self.next_table() def do_columns(self, _arg: str) -> None: - "Report the column names and metadata" + """Report the column names and metadata.""" self.report_columns() def do_data(self, arg: str) -> None: """ Report some data. 
+
         'data' = report a random ten lines,
         'data 20' = report a random 20 lines,
         'data 20 ColumnName' = report a random twenty entries from ColumnName,
@@ -660,6 +735,7 @@ def do_data(self, arg: str) -> None:
     def complete_data(
         self, text: str, line: str, begidx: int, _endidx: int
     ) -> list[str]:
+        """Get completions for arguments to ``data``."""
         previous_parts = line[: begidx - 1].split()
         if len(previous_parts) != 2:
             return []
@@ -667,6 +743,13 @@ def complete_data(
         return [k for k in table_metadata.columns.keys() if k.startswith(text)]

     def print_column_data(self, column: str, count: int, min_length: int) -> None:
+        """
+        Print a sample of data from a certain column of the current table.
+
+        :param column: The name of the column to report on.
+        :param count: The number of rows to sample.
+        :param min_length: The minimum length of text to choose from (0 for any text).
+        """
         where = f"WHERE {column} IS NOT NULL"
         if 0 < min_length:
             where = "WHERE LENGTH({column}) >= {len}".format(
@@ -687,6 +770,11 @@ def print_column_data(self, column: str, count: int, min_length: int) -> None:
         self.columnize([str(x[0]) for x in result.all()])

     def print_row_data(self, count: int) -> None:
+        """
+        Print a sample of rows from the current table.
+
+        :param count: The number of rows to report.
+        """
         with self.sync_engine.connect() as connection:
             result = connection.execute(
                 text(
@@ -705,6 +793,7 @@ def print_row_data(self, count: int) -> None:
 def update_config_tables(
     src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping
 ) -> Mapping[str, Any]:
+    """Ask the user to specify what should happen to each table."""
     with TableCmd(src_dsn, src_schema, metadata, config) as tc:
         tc.cmdloop()
         return tc.config
@@ -712,6 +801,8 @@ def update_config_tables(

 @dataclass
 class MissingnessType:
+    """The functions required for applying missingness."""
+
     SAMPLED = "column_presence.sampled"
     SAMPLED_QUERY = (
         "SELECT COUNT(*) AS row_count, {result_names} FROM "
@@ -725,6 +816,14 @@ class MissingnessType:

     @classmethod
     def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str:
+        """
+        Construct a query sampling the null patterns of the named columns of the table.
+
+        :param table: The name of the table to sample.
+        :param count: The number of samples to get.
+        :param column_names: The columns to fetch.
+        :return: The SQL query to do the sampling.
+        """
         result_names = ", ".join(["{0}__is_null".format(c) for c in column_names])
         column_is_nulls = ", ".join(
             ["{0} IS NULL AS {0}__is_null".format(c) for c in column_names]
@@ -739,11 +838,19 @@ def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> s

 @dataclass
 class MissingnessCmdTableEntry(TableEntry):
+    """Table entry for the missingness command shell."""
+
     old_type: MissingnessType
     new_type: MissingnessType | None


 class MissingnessCmd(DbCmd):
+    """
+    Interactive shell for the user to set missingness.
+
+    Can only be used for Missingness Completely At Random.
+    """
+
     intro = "Interactive missingness configuration. Type ? for help.\n"
     doc_leader = """Use commands 'sampled' and 'none' to choose the missingness
 style for the current table. Use commands 'next' and 'previous' to change the
@@ -774,6 +881,13 @@ def find_missingness_query(
     def make_table_entry(
         self, name: str, table: Mapping
     ) -> MissingnessCmdTableEntry | None:
+        """
+        Make a table entry for a particular table.
+
+        :param name: The name of the table to make an entry for.
+        :param table: The part of ``config.yaml`` relating to this table.
+ :return: The newly-constructed table entry. + """ if table.get("ignore", False): return None if table.get("vocabulary_table", False): @@ -822,6 +936,7 @@ def __init__( ): """ Initialise a MissingnessCmd. + :param src_dsn: connection string for the source database. :param src_schema: schema name for the source database. :param metadata: SQLAlchemy metadata for the source database. @@ -832,20 +947,20 @@ def __init__( @property def table_entries(self) -> list[MissingnessCmdTableEntry]: + """Get the table entries list.""" return cast(list[MissingnessCmdTableEntry], self._table_entries) def find_entry_by_table_name( self, table_name: str ) -> MissingnessCmdTableEntry | None: + """Find the table entry given the table name.""" entry = super().find_entry_by_table_name(table_name) if entry is None: return None return cast(MissingnessCmdTableEntry, entry) def set_prompt(self) -> None: - """ - Sets the prompt according to the current table and missingness. - """ + """Set the prompt according to the current table and missingness.""" if self.table_index < len(self.table_entries): entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] nt = entry.new_type @@ -857,11 +972,13 @@ def set_prompt(self) -> None: self.prompt = "(missingness) " def set_type(self, t_type: MissingnessType) -> None: + """Set the missingness of the current table.""" if self.table_index < len(self.table_entries): entry = self.table_entries[self.table_index] entry.new_type = t_type def _copy_entries(self) -> None: + """Set the new missingness into the configuration.""" src_stats = self._remove_prefix_src_stats("missing_auto__") for entry in self.table_entries: table = self.get_table_config(entry.name) @@ -892,7 +1009,7 @@ def _copy_entries(self) -> None: self.set_table_config(entry.name, table) def do_quit(self, _arg: str) -> bool: - "Check the updates, save them if desired and quit the configurer." + """Check the updates, save them if desired and quit the configurer.""" count = 0 for entry in self.table_entries: if entry.old_type != entry.new_type: @@ -927,7 +1044,7 @@ def do_quit(self, _arg: str) -> bool: return False def do_tables(self, _arg: str) -> None: - "list the tables with their types" + """List the tables with their types.""" for entry in self.table_entries: old = "-" if entry.old_type is None else entry.old_type.name new = "-" if entry.new_type is None else entry.new_type.name @@ -935,7 +1052,11 @@ def do_tables(self, _arg: str) -> None: self.print("{0} {1}", entry.name, desc) def do_next(self, arg: str) -> None: - "'next' = go to the next table, 'next tablename' = go to table 'tablename'" + """ + Go to the next table, or a specified table. + + 'next' = go to the next table, 'next tablename' = go to table 'tablename' + """ if arg: # Find the index of the table called _arg, if any index = next( @@ -952,19 +1073,18 @@ def do_next(self, arg: str) -> None: def complete_next( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Get completions for tables and columns.""" return [ entry.name for entry in self.table_entries if entry.name.startswith(text) ] def do_previous(self, _arg: str) -> None: - "Go to the previous table" + """Go to the previous table.""" if not self.set_table_index(self.table_index - 1): self.print(self.ERROR_ALREADY_AT_START) def _set_type(self, name: str, query: str, comment: str) -> None: - """ - Set the current table entry's query. 
-        """
+        """Set the current table entry's query."""
         if len(self.table_entries) <= self.table_index:
             return
         entry: MissingnessCmdTableEntry = self.table_entries[self.table_index]
@@ -976,9 +1096,7 @@ def _set_type(self, name: str, query: str, comment: str) -> None:
         )

     def _set_none(self) -> None:
-        """
-        Sets the current table to have no missingness applied.
-        """
+        """Set the current table to have no missingness applied."""
         if len(self.table_entries) <= self.table_index:
             return
         self.table_entries[self.table_index].new_type = None
@@ -986,9 +1104,10 @@ def _set_none(self) -> None:
     def do_sampled(self, arg: str) -> None:
         """
         Set the current table missingness as 'sampled', and go to the next table.
-        "sampled 3000" means sample 3000 rows at random and choose the missingness
-        to be the same as one of those 3000 at random.
-        "sampled" means the same, but with a default number of rows sampled (1000).
+
+        'sampled 3000' means sample 3000 rows at random and choose the
+        missingness to be the same as one of those 3000 at random.
+        'sampled' means the same, but with a default number of rows sampled (1000).
         """
         if len(self.table_entries) <= self.table_index:
             self.print("Error! not on a table")
@@ -1017,7 +1136,7 @@ def do_sampled(self, arg: str) -> None:
             self.next_table()

     def do_none(self, _arg: str) -> None:
-        "Set the current table to have no missingness, and go to the next table"
+        """Set the current table to have no missingness, and go to the next table."""
         self._set_none()
         self.print("Table {} set to have no missingness", self.table_name())
         self.next_table()
@@ -1029,6 +1148,16 @@ def update_missingness(
     metadata: MetaData,
     config: MutableMapping[str, Any],
 ) -> Mapping[str, Any]:
+    """
+    Ask the user to update the missingness information in ``config.yaml``.
+
+    :param src_dsn: The connection string for the source database.
+    :param src_schema: The name of the source database schema (or None
+        for the default).
+    :param metadata: The SQLAlchemy metadata object from ``orm.yaml``.
+    :param config: The starting configuration.
+    :return: The updated configuration.
+    """
     with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc:
         mc.cmdloop()
         return mc.config
@@ -1036,9 +1165,7 @@ def update_missingness(

 @dataclass
 class GeneratorInfo:
-    """
-    A generator and the columns it assigns to.
-    """
+    """A generator and the columns it assigns to."""

     columns: list[str]
     gen: Generator | None
@@ -1048,6 +1175,7 @@ class GeneratorInfo:
 class GeneratorCmdTableEntry(TableEntry):
     """
     List of generators set for a table.
+
     Includes the original setting and the currently configured generators.
     """

@@ -1057,9 +1185,7 @@ class GeneratorCmdTableEntry(TableEntry):

 class GeneratorCmd(DbCmd):
-    """
-    Interactive command shell for setting generators.
-    """
+    """Interactive command shell for setting generators."""

     intro = "Interactive generator configuration. Type ? for help.\n"
     doc_leader = """Use command 'propose' for a list of generators applicable to the
@@ -1095,6 +1221,13 @@ class GeneratorCmd(DbCmd):
     def make_table_entry(
         self, table_name: str, table: Mapping
     ) -> GeneratorCmdTableEntry | None:
+        """
+        Make a table entry.
+
+        :param table_name: The name of the table.
+        :param table: The portion of the ``config.yaml`` file describing this table.
+        :return: The newly constructed table entry, or None if this table is to be ignored.
+        """
         if table.get("ignore", False):
             return None
         if table.get("vocabulary_table", False):
@@ -1169,7 +1302,8 @@ def __init__(
         config: MutableMapping[str, Any],
     ) -> None:
         """
-        Initialise a GeneratorCmd
+        Initialise a ``GeneratorCmd``.
+
         :param src_dsn: connection address for source database
         :param src_schema: database schema name
         :param metadata: SQLAlchemy metadata for the source database
@@ -1182,11 +1316,18 @@ def __init__(

     @property
     def table_entries(self) -> list[GeneratorCmdTableEntry]:
+        """Get the table entries, cast to ``GeneratorCmdTableEntry``."""
         return cast(list[GeneratorCmdTableEntry], self._table_entries)

     def find_entry_by_table_name(
         self, table_name: str
     ) -> GeneratorCmdTableEntry | None:
+        """
+        Find the table entry by name.
+
+        :param table_name: The name of the table to find.
+        :return: The table entry, or None if no such table name exists.
+        """
         entry = super().find_entry_by_table_name(table_name)
         if entry is None:
             return None
@@ -1194,7 +1335,8 @@ def find_entry_by_table_name(

     def set_table_index(self, index: int) -> bool:
         """
-        Moves to a new table.
+        Move to a new table.
+
         :param index: table index to move to.
         """
         ret = super().set_table_index(index)
@@ -1206,6 +1348,7 @@ def set_table_index(self, index: int) -> bool:
     def previous_table(self) -> bool:
         """
         Move to the table before the current one.
+
         :return: True if there is a previous table to go to.
         """
         ret = self.set_table_index(self.table_index - 1)
@@ -1223,17 +1366,13 @@ def previous_table(self) -> bool:
         return ret

     def get_table(self) -> GeneratorCmdTableEntry | None:
-        """
-        Get the current table entry.
-        """
+        """Get the current table entry."""
        if self.table_index < len(self.table_entries):
            return self.table_entries[self.table_index]
        return None

     def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]:
-        """
-        Gets a pair; the table name then the generator information.
-        """
+        """Get a pair; the table name then the generator information."""
         if self.table_index < len(self.table_entries):
             entry = self.table_entries[self.table_index]
             if self.generator_index < len(entry.new_generators):
@@ -1242,25 +1381,19 @@ def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]:
         return (None, None)

     def get_column_names(self) -> list[str]:
-        """
-        Gets the (unqualified) names for all the current columns.
-        """
+        """Get the (unqualified) names for all the current columns."""
         (_, generator_info) = self.get_table_and_generator()
         return generator_info.columns if generator_info else []

     def column_metadata(self) -> list[Column]:
-        """
-        Gets the metadata for all the current columns.
-        """
+        """Get the metadata for all the current columns."""
         table = self.table_metadata()
         if table is None:
             return []
         return [table.columns[name] for name in self.get_column_names()]

     def set_prompt(self) -> None:
-        """
-        Set the prompt according to the current table, column and generator.
-        """
+        """Set the prompt according to the current table, column and generator."""
         (table_name, gen_info) = self.get_table_and_generator()
         if table_name is None:
             self.prompt = "(generators) "
@@ -1277,14 +1410,15 @@ def set_prompt(self) -> None:

     def _remove_auto_src_stats(self) -> list[dict[str, Any]]:
         """
-        Remove all automatic source stats (which we assume is
-        every source stats query whose name begins with ``auto__`)"""
+        Remove all automatic source stats.
+
+        We assume these are exactly the source stats queries whose names
+        begin with ``auto__``.
+
+        :return: The new ``src_stats`` configuration.
+        """
         return self._remove_prefix_src_stats("auto__")

     def _copy_entries(self) -> None:
-        """
-        Set generator and query information in the configuration.
-        """
+        """Set generator and query information in the configuration."""
         src_stats = self._remove_auto_src_stats()
         for entry in self.table_entries:
             rgs = []
@@ -1344,7 +1478,7 @@ def _find_old_generator(
         return None

     def do_quit(self, arg: str) -> bool:
-        "Check the updates, save them if desired and quit the configurer."
+        """Check the updates, save them if desired and quit the configurer."""
         count = 0
         for entry in self.table_entries:
             header_shown = False
@@ -1377,7 +1511,7 @@ def do_quit(self, arg: str) -> bool:
         return False

     def do_tables(self, arg: str) -> None:
-        "list the tables"
+        """List the tables."""
         for t_entry in self.table_entries:
             entry = cast(GeneratorCmdTableEntry, t_entry)
             gen_count = len(entry.new_generators)
@@ -1385,7 +1519,7 @@ def do_tables(self, arg: str) -> None:
             self.print("{0} ({1})", entry.name, how_many)

     def do_list(self, arg: str) -> None:
-        "list the generators in the current table"
+        """List the generators in the current table."""
         if len(self.table_entries) <= self.table_index:
             self.print("Error: no table {0}", self.table_index)
             return
@@ -1408,11 +1542,11 @@ def do_list(self, arg: str) -> None:
             self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns)

     def do_columns(self, _arg: str) -> None:
-        "Report the column names and metadata"
+        """Report the column names and metadata."""
         self.report_columns()

     def do_info(self, _arg: str) -> None:
-        "Show information about the current column"
+        """Show information about the current column."""
         for cm in self.column_metadata():
             self.print(
                 "Column {0} in table {1} has type {2} ({3}).",
@@ -1436,12 +1570,21 @@ def do_info(self, _arg: str) -> None:
             )

     def _get_table_index(self, table_name: str) -> int | None:
+        """Get the index of the named table in the table entries list."""
         for n, entry in enumerate(self.table_entries):
             if entry.name == table_name:
                 return n
         return None

     def _get_generator_index(self, table_index: int, column_name: str) -> int | None:
+        """
+        Get the index number of a column within the list of generators in this table.
+
+        :param table_index: The index of the table in which to search.
+        :param column_name: The name of the column to search for.
+        :return: The index in the ``new_generators`` attribute of the table entry
+            containing the specified column, or None if this does not exist.
+        """
         entry = self.table_entries[table_index]
         for n, gen in enumerate(entry.new_generators):
             if column_name in gen.columns:
@@ -1449,6 +1592,11 @@ def _get_generator_index(self, table_index: int, column_name: str) -> int | None
         return None

     def go_to(self, target: str) -> bool:
+        """
+        Go to a particular column.
+
+        :return: True on success.
+        """
         parts = target.split(".", 1)
         table_index = self._get_table_index(parts[0])
         if table_index is None:
@@ -1474,10 +1622,11 @@ def go_to(self, target: str) -> bool:

     def do_next(self, arg: str) -> None:
         """
-        Go to the next generator.
-        Or go to a named table: 'next tablename'.
-        Or go to a column: 'next tablename.columnname'.
-        Or go to a column within this table: 'next columnname'.
+        Go to the next generator, or a specified generator.
+
+        Go to a named table: 'next tablename',
+        go to a column: 'next tablename.columnname',
+        or go to a column within this table: 'next columnname'.
""" if arg: self.go_to(arg) @@ -1485,13 +1634,15 @@ def do_next(self, arg: str) -> None: self._go_next() def do_n(self, arg: str) -> None: - """Synonym for next""" + """Go to the next generator, or a specified generator.""" self.do_next(arg) def complete_n(self, text: str, line: str, begidx: int, endidx: int) -> list[str]: + """Complete the ``n`` command's arguments.""" return self.complete_next(text, line, begidx, endidx) def _go_next(self) -> None: + """Go to the next column.""" table = self.get_table() if table is None: self.print("No more tables") @@ -1506,6 +1657,7 @@ def _go_next(self) -> None: def complete_next( self, text: str, _line: str, _begidx: int, _endidx: int ) -> list[str]: + """Completions for the arguments of the ``next`` command.""" parts = text.split(".", 1) first_part = parts[0] if 1 < len(parts): @@ -1540,7 +1692,7 @@ def complete_next( return table_names + column_names def do_previous(self, _arg: str) -> None: - """Go to the previous generator""" + """Go to the previous generator.""" if self.generator_index == 0: self.previous_table() else: @@ -1548,23 +1700,18 @@ def do_previous(self, _arg: str) -> None: self.set_prompt() def do_b(self, arg: str) -> None: - """Synonym for previous""" + """Synonym for previous.""" self.do_previous(arg) def _generators_valid(self) -> bool: - """ - Return True if the self.generators property is still correct for the - table and columns currently being examined. - """ + """Test if ``self.generators`` is still correct for the current columns.""" return self.generators_valid_columns == ( self.table_index, self.get_column_names(), ) def _get_generator_proposals(self) -> list[Generator]: - """ - Get a list of acceptable generators, sorted by decreasing fit to the actual data. - """ + """Get a list of acceptable generators, sorted by decreasing fit to the actual data.""" if not self._generators_valid(): self.generators = None if self.generators is None: @@ -1579,9 +1726,7 @@ def _get_generator_proposals(self) -> list[Generator]: return self.generators def _print_privacy(self) -> None: - """ - Print the privacy status of the current table. - """ + """Print the privacy status of the current table.""" table = self.table_metadata() if table is None: return @@ -1630,6 +1775,10 @@ def do_c(self, arg: str) -> None: def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None: """ Print the values queried from the database for this generator. + + :param table_name: The name of the table the generator applies to. + :param n: A number to print at the start of the output. + :param gen: The generator to report. """ if not gen.select_aggregate_clauses() and not gen.custom_queries(): self.print( @@ -1649,6 +1798,8 @@ def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None def _print_custom_queries(self, gen: Generator) -> None: """ Print all the custom queries and all the values they get in this case. + + :param gen: The generator to print the custom queries for. """ cqs = gen.custom_queries() if not cqs: @@ -1705,7 +1856,12 @@ def _get_aggregate_query( def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None: """ - Prints the select aggregate query and all the values it gets in this case. + Print the select aggregate query and all the values it gets in this case. + + This is not the entire query that will be executed, but only the part of it + that is required by a certain generator. + :param table_name: The table name. + :param gen: The generator to limit the aggregate query to. 
""" sacs = gen.select_aggregate_clauses() if not sacs: @@ -1793,10 +1949,7 @@ def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: return None def do_set(self, arg: str) -> None: - """ - Set one of the proposals as a generator. - :param arg: A single integer (as a string). - """ + """Set one of the proposals as a generator.""" if arg.isdigit() and not self._generators_valid(): self.print("Please run 'propose' before 'set '") return @@ -1847,7 +2000,6 @@ def do_merge(self, arg: str) -> None: Add this column(s) to the specified column(s). After this, one generator will cover them all. - :param arg: space separated list of column names to merge. """ cols = arg.split() if not cols: @@ -1980,6 +2132,7 @@ def update_config_generators( ) -> Mapping[str, Any]: """ Update configuration with the specification from a CSV file. + The specification is a headerless CSV file with columns: Table name, Column name (or space-separated list of column names), Generator name required, Second choice generator name, Third choice generator diff --git a/datafaker/main.py b/datafaker/main.py index 454cf44..42f6bde 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -70,9 +70,22 @@ def _require_src_db_dsn(settings: Settings) -> str: return src_dsn -def load_metadata_config(orm_file_name: str, config: dict | None = None) -> Any: - with open(orm_file_name) as orm_fh: +def load_metadata_config( + orm_file_name: str, config: dict | None = None +) -> dict[str, Any]: + """ + Load the ``orm.yaml`` file, returning a dict representation. + + :param orm_file_name: The name of the file to load. + :param config: The ``config.yaml`` file object. Ignored tables will be + excluded from the output. + :return: A dict representing the ``orm.yaml`` file, with the tables + the ``config`` says to ignore removed. + """ + with open(orm_file_name, encoding="utf-8") as orm_fh: meta_dict = yaml.load(orm_fh, yaml.Loader) + if not isinstance(meta_dict, dict): + return {} tables_dict = meta_dict.get("tables", {}) if config is not None and "tables" in config: # Remove ignored tables @@ -84,7 +97,8 @@ def load_metadata_config(orm_file_name: str, config: dict | None = None) -> Any: def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: """ - Load metadata from ``orm.yaml`` + Load metadata from ``orm.yaml``. + :param orm_file_name: ``orm.yaml`` or alternative name to load metadata from. :param config: Used to exclude tables that are marked as ``ignore: true``. :return: SQLAlchemy MetaData object representing the database described by the loaded file. @@ -94,9 +108,7 @@ def load_metadata(orm_file_name: str, config: dict | None = None) -> MetaData: def load_metadata_for_output(orm_file_name: str, config: dict | None = None) -> Any: - """ - Load metadata excluding any foreign keys pointing to ignored tables. 
- """ + """Load metadata excluding any foreign keys pointing to ignored tables.""" meta_dict = load_metadata_config(orm_file_name, config) return dict_to_metadata(meta_dict, config) @@ -105,6 +117,7 @@ def load_metadata_for_output(orm_file_name: str, config: dict | None = None) -> def main( verbose: bool = Option(False, "--verbose", "-v", help="Print more information.") ) -> None: + """Set the global parameters.""" conf_logger(verbose) @@ -327,7 +340,10 @@ def make_stats( def make_tables( config_file: Optional[str] = Option( None, - help="The configuration file, used if you want an orm.yaml lacking data for the ignored tables", + help=( + "The configuration file, used if you want" + " an orm.yaml lacking data for the ignored tables" + ), ), orm_file: str = Option(ORM_FILENAME, help="Path to write the ORM yaml file to"), force: bool = Option( @@ -361,9 +377,7 @@ def configure_tables( ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ) -> None: - """ - Interactively set tables to ignored, vocabulary or primary private. - """ + """Interactively set tables to ignored, vocabulary or primary private.""" logger.debug("Configuring tables in %s.", config_file) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) @@ -393,9 +407,7 @@ def configure_missing( ), orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), ) -> None: - """ - Interactively set the missingness of the generated data. - """ + """Interactively set the missingness of the generated data.""" logger.debug("Configuring missingness in %s.", config_file) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) @@ -405,7 +417,7 @@ def configure_missing( config_any = yaml.load( config_file_path.read_text(encoding="UTF-8"), Loader=yaml.SafeLoader ) - if type(config_any) is dict: + if isinstance(config_any, dict): config = config_any metadata = load_metadata(orm_file, config) config_updated = update_missingness(src_dsn, settings.src_schema, metadata, config) @@ -428,9 +440,7 @@ def configure_generators( help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively", ), ) -> None: - """ - Interactively set generators for column data. - """ + """Interactively set generators for column data.""" logger.debug("Configuring generators in %s.", config_file) settings = get_settings() src_dsn: str = _require_src_db_dsn(settings) @@ -468,7 +478,7 @@ def dump_data( schema_name = settings.dst_schema config = read_config_file(config_file) if config_file is not None else {} metadata = load_metadata_for_output(orm_file, config) - if output == None: + if output is None: if isinstance(sys.stdout, io.TextIOBase): dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) return @@ -555,6 +565,8 @@ def remove_tables( class TableType(str, Enum): + """Types of tables for the ``list-tables`` command.""" + ALL = "all" VOCAB = "vocab" GENERATED = "generated" diff --git a/datafaker/make.py b/datafaker/make.py index e9db863..096ee8b 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -75,7 +75,7 @@ class RowGeneratorInfo: @dataclass class ColumnChoice: - """Chooses columns based on a random number in [0,1)""" + """Choose columns based on a random number in [0,1).""" function_name: str argument_values: list[str] @@ -84,6 +84,14 @@ class ColumnChoice: def make_column_choices( table_config: Mapping[str, Any], ) -> list[ColumnChoice]: + """ + Convert ``missingness_generators`` from ``config.yaml`` into functions to call. 
+ + :param table_config: The ``tables`` part of ``config.yaml``. + :return: A list of ``ColumnChoice`` objects; that is, descriptions of + functions and their arguments to call to reveal a list of columns that + should have values generated for them. + """ return [ ColumnChoice( function_name=mg["name"], @@ -97,9 +105,9 @@ def make_column_choices( @dataclass class _PrimaryConstraint: """ - Describes a Uniqueness constraint for when multiple - columns in a table comprise the primary key. Not a - real constraint, but enough to write df.py. + Describes a Uniqueness constraint for a multi-column primary key. + + Not a real constraint, but enough to write df.py. """ columns: list[Column] @@ -252,8 +260,10 @@ def _get_default_generator(column: Column) -> RowGeneratorInfo: def _numeric_generator(column: Column) -> tuple[str, dict[str, str]]: """ - Returns the name of a generator and maybe arguments - that limit its range to the permitted scale. + Get the default generator name and arguments. + + :param column: The column to get the generator for. + :return: The name of a generator and its arguments. """ column_type = column.type scale = getattr(column_type, "scale", None) @@ -270,8 +280,10 @@ def _numeric_generator(column: Column) -> tuple[str, dict[str, str]]: def _string_generator(column: Column) -> tuple[str, dict[str, str]]: """ - Returns the name of a string generator and maybe arguments - that limit its length. + Get the name of the default string generator for a column. + + :param column: The column to get the generator for. + :return: The name of the generator and its arguments. """ column_size: Optional[int] = getattr(column.type, "length", None) if column_size is None: @@ -281,7 +293,11 @@ def _string_generator(column: Column) -> tuple[str, dict[str, str]]: def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: """ - Returns the name of an integer generator. + Get the name of the default integer generator. + + :param column: The column to get the generator for. + :return: A pair consisting of the name of a generator and its + arguments. """ if not column.primary_key: return ("generic.numeric.integer_number", {}) @@ -302,6 +318,8 @@ def _integer_generator(column: Column) -> tuple[str, dict[str, str]]: @dataclass class GeneratorInfo: + """Description of a generator.""" + # Name or function to generate random objects of this type (not using summary data) generator: str | Callable[[Column], tuple[str, dict[str, str]]] # SQL query that gets the data to supply as arguments to the generator @@ -321,8 +339,9 @@ def get_result_mappings( info: GeneratorInfo, results: CursorResult ) -> dict[str, Any] | None: """ - Gets a mapping from the results of a database query as a Python - dictionary converted according to the GeneratorInfo provided. + Get a mapping from the results of a database query. + + :return: A Python dictionary converted according to the GeneratorInfo provided. """ kw: dict[str, Any] = {} mapping = results.mappings().first() @@ -379,7 +398,7 @@ def get_result_mappings( def _get_info_for_column_type(column_t: type) -> GeneratorInfo | None: """ - Gets a generator from a column type. + Get a generator from a column type. Returns either a string representing the callable, or a callable that, given the column.type will return a tuple (string representing generator @@ -400,7 +419,7 @@ def _get_generator_for_column( column_t: type, ) -> str | Callable[[Column], tuple[str, dict[str, str]]] | None: """ - Gets a generator from a column type. + Get a generator from a column type. 
Returns either a string representing the callable, or a callable that, given the column.type will return a tuple (string representing generator @@ -412,8 +431,9 @@ def _get_generator_for_column( def _get_generator_and_arguments(column: Column) -> tuple[str | None, dict[str, str]]: """ - Gets the generator and its arguments from the column type, returning - a tuple of a string representing the generator callable and a dict of + Get the generator and its arguments from the column type. + + :return: A tuple of a string representing the generator callable and a dict of keyword arguments to supply to it. """ generator_function = _get_generator_for_column(type(column.type)) @@ -540,10 +560,7 @@ def make_vocabulary_tables( compress: bool, table_names: set[str] | None = None, ) -> None: - """ - Extracts the data from the source database for each - vocabulary table. - """ + """Extract the data from the source database for each vocabulary table.""" settings = get_settings() src_dsn: str = settings.src_dsn or "" assert src_dsn != "", "Missing SRC_DSN setting." @@ -660,9 +677,7 @@ def generate_df_content(template_context: Mapping[str, Any]) -> str: def _get_generator_for_existing_vocabulary_table( table: Table, ) -> VocabularyTableGeneratorInfo: - """ - Turns an existing vocabulary YAML file into a VocabularyTableGeneratorInfo. - """ + """Turn an existing vocabulary YAML file into a VocabularyTableGeneratorInfo.""" return VocabularyTableGeneratorInfo( dictionary_entry=table.name, variable_name=f"{table.name.lower()}_vocab", @@ -676,9 +691,7 @@ def _generate_vocabulary_table( overwrite_files: bool = False, compress: bool = False, ) -> None: - """ - Pulls data out of the source database to make a vocabulary YAML file - """ + """Pull data out of the source database to make a vocabulary YAML file.""" yaml_file_name: str = table.fullname + ".yaml" if compress: yaml_file_name += ".gz" @@ -692,9 +705,7 @@ def _generate_vocabulary_table( def make_tables_file( db_dsn: str, schema_name: Optional[str], config: Mapping[str, Any] ) -> str: - """ - Construct the YAML file representing the schema. - """ + """Construct the YAML file representing the schema.""" tables_config = config.get("tables", {}) engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) @@ -726,6 +737,8 @@ def reflect_if(table_name: str, _: Any) -> bool: class DbConnection: + """A connection to a database.""" + def __init__(self, engine: MaybeAsyncEngine) -> None: """ Initialise an unopened database connection. 
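DbConnection presents one awaitable interface over either a synchronous or an asynchronous SQLAlchemy engine. A minimal sketch of the consumption pattern follows; it uses a synchronous-only stand-in against an in-memory SQLite database so it runs standalone, and is not datafaker's actual class:

    import asyncio

    from sqlalchemy import create_engine, text


    class SyncOnlyConn:
        """Illustrative stand-in for DbConnection, synchronous engines only."""

        def __init__(self, engine):
            self._engine = engine

        async def __aenter__(self):
            self._connection = self._engine.connect()
            return self

        async def __aexit__(self, _type, _value, _tb):
            self._connection.close()

        async def execute_raw_query(self, query):
            return self._connection.execute(query)


    async def main():
        async with SyncOnlyConn(create_engine("sqlite://")) as conn:
            result = await conn.execute_raw_query(text("SELECT 1 + 1"))
            print(result.scalar_one())  # 2

    asyncio.run(main())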
@@ -736,6 +749,7 @@ def __init__(self, engine: MaybeAsyncEngine) -> None:
         self._connection: Connection | AsyncConnection
 
     async def __aenter__(self) -> Self:
+        """Enter the ``with`` section, opening a connection."""
         if isinstance(self._engine, AsyncEngine):
             self._connection = await self._engine.connect()
         else:
@@ -748,16 +762,20 @@ async def __aexit__(
         _value: Optional[BaseException],
         _tb: Optional[TracebackType],
     ) -> None:
+        """Exit the ``with`` section, closing the connection."""
         if isinstance(self._connection, AsyncConnection):
             await self._connection.close()
-        self._connection.close()
+        else:
+            self._connection.close()
 
     async def execute_raw_query(self, query: Executable) -> CursorResult:
+        """Execute the query on the owned connection."""
         if isinstance(self._connection, AsyncConnection):
             return await self._connection.execute(query)
         return self._connection.execute(query)
 
     async def table_row_count(self, table_name: str) -> int:
+        """Count the number of rows in the named table."""
         with await self.execute_raw_query(
             text(f"SELECT COUNT(*) FROM {table_name}")
         ) as result:
@@ -790,12 +807,14 @@ async def execute_query(self, query_block: Mapping[str, Any]) -> Any:
 
 
 def fix_type(value: Any) -> Any:
+    """Make this value suitable for yaml output."""
     if type(value) is decimal.Decimal:
         return float(value)
     return value
 
 
 def fix_types(dics: list[dict]) -> list[dict]:
+    """Make all the items in this list suitable for yaml output."""
     return [{k: fix_type(v) for k, v in dic.items()} for dic in dics]
 
 
@@ -827,6 +846,7 @@ async def make_src_stats_connection(
 ) -> dict[str, dict[str, Any]]:
     """
     Make the ``src-stats.yaml`` file given the database connection to read from.
+
     :param config: configuration from ``config.yaml``.
     :param db_conn: Source database connection.
     :param metadata: Source database metadata from ``orm.yaml``.
diff --git a/datafaker/providers.py b/datafaker/providers.py
index 65abf06..75006c7 100644
--- a/datafaker/providers.py
+++ b/datafaker/providers.py
@@ -38,7 +38,7 @@ def increment(self, db_connection: Connection, column: Column) -> int:
         """Return incrementing value for the column specified."""
         name = f"{column.table.name}.{column.name}"
         result = self.accumulators.get(name, None)
-        if result == None:
+        if result is None:
             row = db_connection.execute(select(func.max(column))).first()
             result = 0 if row is None or row[0] is None else row[0]
         value = result + 1
diff --git a/datafaker/serialize_metadata.py b/datafaker/serialize_metadata.py
index 0b96e34..b7f5cf2 100644
--- a/datafaker/serialize_metadata.py
+++ b/datafaker/serialize_metadata.py
@@ -1,6 +1,6 @@
 """Convert between a Python dict describing a database schema and a SQLAlchemy MetaData."""
 import typing
-from typing import Callable
+from functools import partial
 
 import parsy
 from sqlalchemy import Column, Dialect, Engine, ForeignKey, MetaData, Table
@@ -27,9 +27,7 @@ def simple(type_: type) -> ParserType:
 
 
 def integer() -> ParserType:
-    """
-    Get a parser for an integer, outputting that integer.
-    """
+    """Get a parser for an integer, outputting that integer."""
     return parsy.regex(r"-?[0-9]+").map(int)
 
 
@@ -164,6 +162,7 @@ def type_parser() -> ParserType:
 def column_to_dict(column: Column, dialect: Dialect) -> dict[str, typing.Any]:
     """
     Produce a dict description of a column.
+
     :param column: The SQLAlchemy column to translate.
     :param dialect: The SQL dialect in which to render the type name.
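As a round-trip sketch for this module's public pair (it assumes datafaker and SQLAlchemy are importable; reflecting an empty in-memory database keeps the example trivial):

    from sqlalchemy import MetaData, create_engine

    from datafaker.serialize_metadata import dict_to_metadata, metadata_to_dict

    engine = create_engine("sqlite://")
    meta = MetaData()
    meta.reflect(bind=engine)  # empty database, so no tables are found

    as_dict = metadata_to_dict(meta, None, engine)  # {"tables": {}}
    restored = dict_to_metadata(as_dict)
    print(len(restored.tables))  # 0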
""" @@ -192,10 +191,11 @@ def dict_to_column( table_name: str, col_name: str, rep: dict, - ignore_fk: Callable[[str], bool], + ignore_fk: typing.Callable[[str], bool], ) -> Column: """ Produce column from aspects of its dict description. + :param table_name: The name of the table the column appears in. :param col_name: The name of the column. :param rep: The dict description of the column. @@ -249,7 +249,7 @@ def unique_to_dict(constraint: schema.UniqueConstraint) -> dict: def table_to_dict(table: Table, dialect: Dialect) -> TableT: - """Converts a SQL Alchemy Table object into a Python dict.""" + """Convert a SQL Alchemy Table object into a Python dict.""" return { "columns": { str(column.key): column_to_dict(column, dialect) @@ -267,7 +267,7 @@ def dict_to_table( name: str, meta: MetaData, table_dict: TableT, - ignore_fk: Callable[[str], bool], + ignore_fk: typing.Callable[[str], bool], ) -> Table: """Create a Table from its description.""" return Table( @@ -285,8 +285,9 @@ def metadata_to_dict( meta: MetaData, schema_name: str | None, engine: Engine ) -> dict[str, typing.Any]: """ - Converts a SQL Alchemy MetaData object into - a Python object ready for conversion to YAML. + Convert a metadata object into a Python dict. + + The output will be ready for output to ``orm.yaml``. """ return { "tables": { @@ -298,10 +299,13 @@ def metadata_to_dict( } -def should_ignore_fk(fk: str, tables_dict: dict[str, TableT]) -> bool: +def should_ignore_fk(tables_dict: dict[str, TableT], fk: str) -> bool: """ - Tell if this foreign key should be ignored because it points to an - ignored table. + Test if this foreign key points to an ignored table. + + If so, this foreign key should be ignored. + :param tables_dict: The ``tables`` value from ``config.yaml``. + :param fk: The name of the foreign key. """ fk_bits = fk.split(".", 2) if len(fk_bits) != 2: @@ -311,9 +315,13 @@ def should_ignore_fk(fk: str, tables_dict: dict[str, TableT]) -> bool: return bool(tables_dict[fk_bits[0]].get("ignore", False)) +def _always_false(_: str) -> bool: + return False + + def dict_to_metadata(obj: dict, config_for_output: dict | None = None) -> MetaData: """ - Converts a dict to a SQL Alchemy MetaData object. + Convert a dict to a SQL Alchemy MetaData object. :param config_for_output: The configuration object. Should be None if the metadata object is being used for connecting to the source database. @@ -322,11 +330,12 @@ def dict_to_metadata(obj: dict, config_for_output: dict | None = None) -> MetaDa constraint to an ignored table. """ tables_dict = obj.get("tables", {}) + ignore_fk: typing.Callable[[str], bool] if config_for_output and "tables" in config_for_output: tables_config = config_for_output["tables"] - ignore_fk = lambda fk: should_ignore_fk(fk, tables_config) + ignore_fk = partial(should_ignore_fk, tables_config) else: - ignore_fk = lambda _: False + ignore_fk = _always_false meta = MetaData() for k, td in tables_dict.items(): dict_to_table(k, meta, td, ignore_fk) diff --git a/datafaker/utils.py b/datafaker/utils.py index b34664c..009109d 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -50,7 +50,6 @@ ) T = TypeVar("T") -_K = TypeVar("_K") def read_config_file(path: str) -> dict: @@ -96,14 +95,29 @@ def import_file(file_path: str) -> ModuleType: def open_file(file_name: str | Path) -> io.BufferedWriter: + """Open a file for writing.""" return Path(file_name).open("wb") def open_compressed_file(file_name: str | Path) -> gzip.GzipFile: + """ + Open a gzip-compressed file for writing. 
+ + :param file_name: The name of the file to open. + :return: A file object; it can be written to as a normal uncompressed + file and it will do the compression. + """ return gzip.GzipFile(file_name, "wb") def table_row_count(table: Table, conn: Connection) -> int: + """ + Count the rows in the table. + + :param table: The table to count. + :param conn: The connection to the database. + :return: The number of rows in the table. + """ return conn.execute( select(sqlalchemy.func.count()).select_from( sqlalchemy.table( @@ -222,10 +236,12 @@ def warning_or_higher(record: logging.LogRecord) -> bool: class StdoutHandler(logging.Handler): """ A handler that writes to stdout. + We aren't using StreamHandler because that confuses typer.testing.CliRunner """ def flush(self) -> None: + """Flush the buffer.""" self.acquire() try: sys.stdout.flush() @@ -233,6 +249,7 @@ def flush(self) -> None: self.release() def emit(self, record: Any) -> None: + """Write the record.""" try: msg = self.format(record) sys.stdout.write(msg + "\n") @@ -246,10 +263,12 @@ def emit(self, record: Any) -> None: class StderrHandler(logging.Handler): """ A handler that writes to stderr. + We aren't using StreamHandler because that confuses typer.testing.CliRunner """ def flush(self) -> None: + """Flush the buffer.""" self.acquire() try: sys.stderr.flush() @@ -257,6 +276,7 @@ def flush(self) -> None: self.release() def emit(self, record: Any) -> None: + """Write the record.""" try: msg = self.format(record) sys.stderr.write(msg + "\n") @@ -293,19 +313,37 @@ def conf_logger(verbose: bool) -> None: logging.getLogger("blib2to3.pgen2.driver").setLevel(logging.WARNING) -def get_flag(maybe_dict: Any, key: Any) -> Any: - """Returns maybe_dict[key] or False if that doesn't exist""" - return type(maybe_dict) is dict and maybe_dict.get(key, False) +def get_flag(maybe_dict: Any, key: Any) -> bool: + """ + Get a boolean from a mapping, or False if that does not make sense. + + :param maybe_dict: A mapping, or possibly not. + :param key: A key in ``maybe_dict``, or possibly not. + :return: True only if ``maybe_dict`` is a mapping, ``maybe_dict[key]`` + exists and ``maybe_dict[key]`` is truthy. + """ + return isinstance(maybe_dict, Mapping) and maybe_dict.get(key, False) + +def get_property(maybe_dict: Any, key: Any, default: T) -> T: + """ + Get a specific property from a dict or a default if that does not exist. -def get_property(maybe_dict: Mapping[_K, Any], key: _K, default: T) -> T: - """Returns maybe_dict[key] or default if that doesn't exist""" - return maybe_dict.get(key, default) if type(maybe_dict) is dict else default + :param maybe_dict: A mapping, or possibly not. + :param key: A key in ``maybe_dict``, or possibly not. + :param default: The return value if ``maybe_dict`` is not a mapping, + or if ``key`` is not a key of ``maybe_dict``. + :return: ``maybe_dict[key]`` if this makes sense, or ``default`` if not. + """ + return maybe_dict.get(key, default) if isinstance(maybe_dict, Mapping) else default def fk_refers_to_ignored_table(fk: ForeignKey) -> bool: """ - Does this foreign key refer to a table that is configured as ignore in config.yaml + Test if this foreign key refers to an ignored table. + + :param fk: The foreign key to test. + :return: True if the table referred to is ignored in ``config.yaml``. 
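The contract of get_flag and get_property can be shown standalone. The snippet below restates their logic so it runs without datafaker installed; the num_rows key is purely illustrative:

    from collections.abc import Mapping


    def get_flag(maybe_dict, key):
        # Truthy only when maybe_dict is a mapping holding a truthy value at key.
        return isinstance(maybe_dict, Mapping) and maybe_dict.get(key, False)


    def get_property(maybe_dict, key, default):
        # Fall back to default when maybe_dict is not a mapping or lacks key.
        return maybe_dict.get(key, default) if isinstance(maybe_dict, Mapping) else default


    print(get_flag({"vocabulary_table": True}, "vocabulary_table"))  # True
    print(get_flag(None, "vocabulary_table"))                        # False
    print(get_property({"num_rows": 7}, "num_rows", 0))              # 7
    print(get_property("not a mapping", "num_rows", 0))              # 0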
""" try: fk.column @@ -316,7 +354,10 @@ def fk_refers_to_ignored_table(fk: ForeignKey) -> bool: def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint) -> bool: """ - Does this foreign key constraint refer to a table that is configured as ignore in config.yaml + Test if the constraint refers to a table marked as ignored in ``config.yaml``. + + :param fk: The foreign key constraint. + :return: True if ``fk`` refers to an ignored table. """ try: fk.referred_table @@ -328,6 +369,10 @@ def fk_constraint_refers_to_ignored_table(fk: ForeignKeyConstraint) -> bool: def get_related_table_names(table: Table) -> set[str]: """ Get the names of all tables for which there exist foreign keys from this table. + + :param table: SQLAlchemy table object. + :return: The set of the names of the tables referred to by foreign keys + in ``table``. """ return { str(fk.referred_table.name) @@ -338,8 +383,11 @@ def get_related_table_names(table: Table) -> set[str]: def table_is_private(config: Mapping, table_name: str) -> bool: """ - Return True if the table with name table_name is a primary private table - according to config. + Test if the named table is private. + + :param config: The ``config.yaml`` object. + :param table_name: The name of the table to test. + :return: True if the table is marked as private in ``config``. """ ts = config.get("tables", {}) if type(ts) is not dict: @@ -351,10 +399,14 @@ def table_is_private(config: Mapping, table_name: str) -> bool: def primary_private_fks(config: Mapping, table: Table) -> list[str]: """ - Returns the list of columns in the table that refer to primary private tables. + Get the list of columns in the table that refer to primary private tables. A table that is not primary private but has a non-empty list of primary_private_fks is secondary private. + + :param config: The ``config.yaml`` object. + :param table: The table to examine. + :return: A list of names of columns that refer to private tables. """ return [ str(fk.referred_table.name) @@ -365,9 +417,7 @@ def primary_private_fks(config: Mapping, table: Table) -> list[str]: def get_vocabulary_table_names(config: Mapping) -> set[str]: - """ - Extract the table names with a vocabulary_table: true property. - """ + """Extract the table names with a vocabulary_table: true property.""" return { table_name for (table_name, table_config) in config.get("tables", {}).items() @@ -376,6 +426,7 @@ def get_vocabulary_table_names(config: Mapping) -> set[str]: def make_foreign_key_name(table_name: str, col_name: str) -> str: + """Make a suitable foreign key name.""" return f"{table_name}_{col_name}_fkey" @@ -384,6 +435,16 @@ def remove_vocab_foreign_key_constraints( config: Mapping[str, Any], dst_engine: Connection | Engine, ) -> None: + """ + Remove the foreign key constraints from vocabulary tables. + + This allows vocabulary tables to be loaded without worrying about + topologically sorting them or circular dependencies. + + :param metadata: The SQLAlchemy metadata from ``orm.yaml``. + :param config: The ``config.yaml`` object. + :param dst_engine: The destination database or a connection to it. + """ vocab_tables = get_vocabulary_table_names(config) for vocab_table_name in vocab_tables: vocab_table = metadata.tables[vocab_table_name] @@ -419,6 +480,7 @@ def reinstate_vocab_foreign_key_constraints( ) -> None: """ Put the removed foreign keys back into the destination database. + :param metadata: The SQLAlchemy metadata for the destination database. 
:param meta_dict: The ``orm.yaml`` configuration that ``metadata`` was created from. @@ -525,6 +587,14 @@ def topological_sort( def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Table]: + """ + Get the list of non-vocabulary tables, topologically sorted. + + :param metadata: SQLAlchemy database description. + :param config: The ``config.yaml`` object. + :return: The list of non-vocabulary tables, ordered such that the targets + of all the foreign keys come before their sources. + """ table_names = set(metadata.tables.keys()).difference( get_vocabulary_table_names(config) ) @@ -537,8 +607,9 @@ def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Ta def underline_error(e: SyntaxError) -> str: - """ + r""" Make an underline for this error. + :return: string beginning ``\n`` then spaces then ``^^^^`` underlining the error, or a null string if this was not possible. """ @@ -553,7 +624,11 @@ def underline_error(e: SyntaxError) -> str: def generators_require_stats(config: Mapping) -> bool: """ - Returns true if any of the arguments for any of the generators reference SRC_STATS. + Test if the generator references ``SRC_STATS``. + + :param config: ``config.yaml`` object. + :return: True if any of the arguments for any of the generators + reference ``SRC_STATS``. """ ois = { f"object_instantiation.{k}": call diff --git a/tests/test_dump.py b/tests/test_dump.py index 7033e18..2340a6c 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -18,11 +18,11 @@ class DumpTests(RequiresDBTestCase): @patch("datafaker.dump._make_csv_writer") def test_dump_data(self, make_csv_writer: MagicMock) -> None: """Test dump-data.""" - TEST_OUTPUT_FILE = io.StringIO() + test_output_file = io.StringIO() metadata = MetaData() metadata.reflect(self.sync_engine) - dump_db_tables(metadata, self.dsn, self.schema_name, "player", TEST_OUTPUT_FILE) - make_csv_writer.assert_called_once_with(TEST_OUTPUT_FILE) + dump_db_tables(metadata, self.dsn, self.schema_name, "player", test_output_file) + make_csv_writer.assert_called_once_with(test_output_file) make_csv_writer.assert_has_calls( [ call().writerow(["id", "given_name", "family_name"]), diff --git a/tests/test_functional.py b/tests/test_functional.py index e60baa1..792dc40 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -349,33 +349,50 @@ def test_workflow_maximal_args(self) -> None: ) self.assertEqual("", completed_process.stderr) self.assertEqual( - { - "Creating data.", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.short_story'", - "Generating data for story 'story_generators.full_row_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for story 'story_generators.long_story'", - "Generating data for table 'data_type_test'", - "Generating data for table 'no_pk_test'", - "Generating data for table 'person'", - "Generating data for table 'strange_type_table'", - "Generating data for table 'unique_constraint_test'", - "Generating data for table 'unique_constraint_test2'", - "Generating data for table 'test_entity'", - "Generating data for table 'hospital_visit'", - "Data created in 2 passes.", - f"person: {2*(3+1+2+2)} rows created.", - f"hospital_visit: {2*(2*2+3)} rows created.", - "data_type_test: 2 rows created.", - "no_pk_test: 2 rows created.", - "strange_type_table: 2 rows created.", - "unique_constraint_test: 2 rows created.", - 
"unique_constraint_test2: 2 rows created.", - "test_entity: 2 rows created.", - }, - set(completed_process.stdout.split("\n")) - {""}, + sorted( + [ + "Creating data.", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.short_story'", + "Generating data for story 'story_generators.full_row_story'", + "Generating data for story 'story_generators.full_row_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for story 'story_generators.long_story'", + "Generating data for table 'data_type_test'", + "Generating data for table 'data_type_test'", + "Generating data for table 'no_pk_test'", + "Generating data for table 'no_pk_test'", + "Generating data for table 'person'", + "Generating data for table 'person'", + "Generating data for table 'strange_type_table'", + "Generating data for table 'strange_type_table'", + "Generating data for table 'unique_constraint_test'", + "Generating data for table 'unique_constraint_test'", + "Generating data for table 'unique_constraint_test2'", + "Generating data for table 'unique_constraint_test2'", + "Generating data for table 'test_entity'", + "Generating data for table 'test_entity'", + "Generating data for table 'hospital_visit'", + "Generating data for table 'hospital_visit'", + "Data created in 2 passes.", + f"person: {2*(3+1+2+2)} rows created.", + f"hospital_visit: {2*(2*2+3)} rows created.", + "data_type_test: 2 rows created.", + "no_pk_test: 2 rows created.", + "strange_type_table: 2 rows created.", + "unique_constraint_test: 2 rows created.", + "unique_constraint_test2: 2 rows created.", + "test_entity: 2 rows created.", + "", + ] + ), + sorted(completed_process.stdout.split("\n")), ) completed_process = self.invoke( @@ -452,8 +469,20 @@ def test_workflow_maximal_args(self) -> None: ) def invoke( - self, *args: Any, expected_error: str | None = None, env: Mapping[str, str] = {} + self, + *args: Any, + expected_error: str | None = None, + env: Mapping[str, str] | None = None, ) -> Result: + """ + Run datafaker with the given arguments and environment. + + :param args: Arguments to provide to datafaker. + :param expected_error: If None, will assert that the invocation + passes successfully without throwing an exception. Otherwise, + the suggested error must be present in the standard error stream. + :param env: The environment variables to be set during invocation. 
+ """ res = self.runner.invoke(app, args, env=env) if expected_error is None: self.assertNoException(res) diff --git a/tests/test_interactive.py b/tests/test_interactive.py index a7803f0..c5e79d3 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -3,7 +3,7 @@ import random import re from dataclasses import dataclass -from typing import Any, Iterable, Mapping, MutableMapping +from typing import Any, Iterable, MutableMapping from unittest.mock import MagicMock, Mock, patch from sqlalchemy import insert, select @@ -46,7 +46,7 @@ def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: """Capture the printed table.""" self.columns = columns - def columnize(self, items: list[str] | None, displaywidth: int = 80) -> None: + def columnize(self, items: list[str] | None, _displaywidth: int = 80) -> None: """Capture the printed table.""" if items is not None: self.column_items.append(items) @@ -587,7 +587,10 @@ def test_set_generator_distribution(self) -> None: self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT AVG({column}) AS mean__{column}, STDDEV({column}) AS stddev__{column} FROM {table}", + ( + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column})" + f" AS stddev__{column} FROM {table}" + ), ) def test_set_generator_distribution_directly(self) -> None: @@ -608,7 +611,10 @@ def test_set_generator_distribution_directly(self) -> None: self.assertEqual(gc.config["src-stats"][0]["name"], f"auto__{table}") self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT AVG({column}) AS mean__{column}, STDDEV({column}) AS stddev__{column} FROM {table}", + ( + f"SELECT AVG({column}) AS mean__{column}, STDDEV({column})" + f" AS stddev__{column} FROM {table}" + ), ) def test_set_generator_choice(self) -> None: @@ -642,7 +648,11 @@ def test_set_generator_choice(self) -> None: ) self.assertEqual( gc.config["src-stats"][0]["query"], - f"SELECT {column} AS value FROM {table} WHERE {column} IS NOT NULL GROUP BY value ORDER BY COUNT({column}) DESC", + ( + f"SELECT {column} AS value FROM {table}" + f" WHERE {column} IS NOT NULL" + f" GROUP BY value ORDER BY COUNT({column}) DESC" + ), ) def test_weighted_choice_generator_generates_choices(self) -> None: @@ -761,7 +771,10 @@ def test_old_generators_remain(self) -> None: "src-stats": [ { "name": "auto__string", - "query": "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + "query": ( + "SELECT AVG(frequency) AS mean__frequency," + " STDDEV(frequency) AS stddev__frequency FROM string" + ), } ], } @@ -798,7 +811,10 @@ def test_old_generators_remain(self) -> None: self.assertEqual(gc.config["src-stats"][0]["name"], "auto__string") self.assertEqual( gc.config["src-stats"][0]["query"], - "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + ( + "SELECT AVG(frequency) AS mean__frequency," + " STDDEV(frequency) AS stddev__frequency FROM string" + ), ) def test_aggregate_queries_merge(self) -> None: @@ -824,7 +840,10 @@ def test_aggregate_queries_merge(self) -> None: "src-stats": [ { "name": "auto__string", - "query": "SELECT AVG(frequency) AS mean__frequency, STDDEV(frequency) AS stddev__frequency FROM string", + "query": ( + "SELECT AVG(frequency) AS mean__frequency," + " STDDEV(frequency) AS stddev__frequency FROM string" + ), } ], } @@ -1712,7 +1731,7 @@ class NonInteractiveTests(RequiresDBTestCase): ), ) def test_non_interactive_configure_generators( - 
self, mock_csv_reader: MagicMock, mock_path: MagicMock + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock ) -> None: """ test that we can set generators from a CSV file diff --git a/tests/test_main.py b/tests/test_main.py index a37eaf4..a570fe4 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -194,7 +194,7 @@ def test_create_generators_with_force_enabled( @patch("datafaker.main.read_config_file") @patch("datafaker.main.load_metadata_for_output") def test_create_tables( - self, mock_load_meta: MagicMock, mock_config: MagicMock, mock_create: MagicMock + self, mock_load_meta: MagicMock, _mock_config: MagicMock, mock_create: MagicMock ) -> None: """Test the create-tables sub-command.""" diff --git a/tests/test_remove.py b/tests/test_remove.py index 24286fb..a6dbb85 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -18,12 +18,14 @@ class RemoveThingsTestCase(RequiresDBTestCase): schema_name = "public" def count_rows(self, connection: Connection, table_name: str) -> int | None: + """Count the rows in a table.""" return connection.execute( select(func.count()).select_from(self.metadata.tables[table_name]) ).scalar() @patch("datafaker.remove.get_settings") def test_remove_data(self, mock_get_settings: MagicMock) -> None: + """Test that data can be removed from non-vocabulary tables.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, @@ -67,6 +69,7 @@ def test_remove_data_raises(self, mock_get_settings: MagicMock) -> None: @patch("datafaker.remove.get_settings") def test_remove_vocab(self, mock_get_settings: MagicMock) -> None: + """Test that vocabulary tables can be removed.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, @@ -114,6 +117,7 @@ def test_remove_vocab_raises(self, mock_get_settings: MagicMock) -> None: @patch("datafaker.remove.get_settings") def test_remove_tables(self, mock_get_settings: MagicMock) -> None: + """Test that destination tables can be removed.""" mock_get_settings.return_value = Settings( src_dsn=self.dsn, dst_dsn=self.dsn, diff --git a/tests/test_rst.py b/tests/test_rst.py index 090658c..78f0747 100644 --- a/tests/test_rst.py +++ b/tests/test_rst.py @@ -7,6 +7,10 @@ from restructuredtext_lint import lint_file +def _level_to_string(level: int) -> str: + return ["Severe", "Error", "Warning"][level] + + class RstTests(TestCase): """Linting for the doc .rst files.""" @@ -44,7 +48,11 @@ def test_dir(self) -> None: ] if filtered_errors: - self.fail(msg="\n".join([ - f"{err.source}({err.line}): {["Severe", "Error", "Warning"][err.level]}: {err.full_message}" - for err in filtered_errors - ])) + self.fail( + msg="\n".join( + [ + f"{err.source}({err.line}): {_level_to_string(err.level)}: {err.full_message}" + for err in filtered_errors + ] + ) + ) From b86e10604ec85b08f073867d990096867c9f5fe2 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 9 Oct 2025 19:17:30 +0100 Subject: [PATCH 18/35] More cleaning --- datafaker/generators.py | 6 +++--- datafaker/main.py | 7 +++++-- datafaker/utils.py | 12 ++++++------ tests/test_interactive.py | 6 +++--- tests/test_rst.py | 20 ++++++++++++-------- 5 files changed, 29 insertions(+), 22 deletions(-) diff --git a/datafaker/generators.py b/datafaker/generators.py index 4e421af..a925b79 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -21,7 +21,7 @@ from datafaker.base import DistributionGenerator from datafaker.utils import T, logger -numeric = Union[int, float] +NumericT = Union[int, float] # How many distinct values can we 
have before we consider a # choice distribution to be infeasible? @@ -754,7 +754,7 @@ def get_generators( return [MimesisGenerator("person.weight")] -def fit_from_buckets(xs: Sequence[numeric], ys: Sequence[numeric]) -> float: +def fit_from_buckets(xs: Sequence[NumericT], ys: Sequence[NumericT]) -> float: """Calculate the fit by comparing a pair of lists of buckets.""" sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) @@ -764,7 +764,7 @@ def fit_from_buckets(xs: Sequence[numeric], ys: Sequence[numeric]) -> float: class ContinuousDistributionGenerator(Generator): """Base class for generators producing continuous distributions.""" - expected_buckets: Sequence[numeric] = [] + expected_buckets: Sequence[NumericT] = [] def __init__(self, table_name: str, column_name: str, buckets: Buckets): """Initialise a ContinuousDistributionGenerator.""" diff --git a/datafaker/main.py b/datafaker/main.py index 42f6bde..c129983 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -437,7 +437,10 @@ def configure_generators( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), spec: Path = Option( None, - help="CSV file (headerless) with fields table-name, column-name, generator-name to set non-interactively", + help=( + "CSV file (headerless) with fields table-name," + " column-name, generator-name to set non-interactively" + ), ), ) -> None: """Interactively set generators for column data.""" @@ -482,7 +485,7 @@ def dump_data( if isinstance(sys.stdout, io.TextIOBase): dump_db_tables(metadata, dst_dsn, schema_name, table, sys.stdout) return - with open(output, "wt", newline="") as out: + with open(output, "wt", newline="", encoding="utf-8") as out: dump_db_tables(metadata, dst_dsn, schema_name, table, out) diff --git a/datafaker/utils.py b/datafaker/utils.py index 009109d..61089ab 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -390,11 +390,11 @@ def table_is_private(config: Mapping, table_name: str) -> bool: :return: True if the table is marked as private in ``config``. 
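A self-contained restatement of this lookup, simplified to plain dicts and with invented table names:

    def table_is_private(config, table_name):
        ts = config.get("tables", {})
        if not isinstance(ts, dict):
            return False
        ret = ts.get(table_name, {}).get("primary_private", False)
        return ret if isinstance(ret, bool) else False


    config = {"tables": {"patient": {"primary_private": True}, "visit": {}}}
    print(table_is_private(config, "patient"))  # True
    print(table_is_private(config, "visit"))    # False
    print(table_is_private(config, "absent"))   # False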
""" ts = config.get("tables", {}) - if type(ts) is not dict: + if not isinstance(ts, Mapping): return False t = ts.get(table_name, {}) ret = t.get("primary_private", False) - return ret if type(ret) is bool else False + return ret if isinstance(ret, bool) else False def primary_private_fks(config: Mapping, table: Table) -> list[str]: @@ -466,7 +466,7 @@ def remove_vocab_foreign_key_constraints( ) except ProgrammingError as e: session.rollback() - if type(e.orig) is UndefinedObject: + if isinstance(e.orig, UndefinedObject): logger.debug("Constraint does not exist") else: raise e @@ -501,7 +501,7 @@ def reinstate_vocab_foreign_key_constraints( name=make_foreign_key_name(vocab_table_name, column_name), refcolumns=fk_targets, ) - logger.debug(f"Restoring foreign key constraint {fk.name}") + logger.debug("Restoring foreign key constraint %s", fk.name) with Session(dst_engine) as session: session.begin() vocab_table.append_constraint(fk) @@ -598,12 +598,12 @@ def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Ta table_names = set(metadata.tables.keys()).difference( get_vocabulary_table_names(config) ) - (sorted, cycles) = topological_sort( + (sorted_tables, cycles) = topological_sort( table_names, lambda tn: get_related_table_names(metadata.tables[tn]) ) for cycle in cycles: logger.warning(f"Cycle detected between tables: {cycle}") - return [metadata.tables[tn] for tn in sorted] + return [metadata.tables[tn] for tn in sorted_tables] def underline_error(e: SyntaxError) -> str: diff --git a/tests/test_interactive.py b/tests/test_interactive.py index c5e79d3..1082c33 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive.py @@ -46,10 +46,10 @@ def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: """Capture the printed table.""" self.columns = columns - def columnize(self, items: list[str] | None, _displaywidth: int = 80) -> None: + def columnize(self, list: list[str] | None, _displaywidth: int = 80) -> None: """Capture the printed table.""" - if items is not None: - self.column_items.append(items) + if list is not None: + self.column_items.append(list) def ask_save(self) -> str: """Quitting always works without needing to ask the user.""" diff --git a/tests/test_rst.py b/tests/test_rst.py index 78f0747..29bf971 100644 --- a/tests/test_rst.py +++ b/tests/test_rst.py @@ -2,15 +2,26 @@ The CLI does not allow errors to be disabled, but we can ignore them here.""" from pathlib import Path +from typing import Any from unittest import TestCase from restructuredtext_lint import lint_file def _level_to_string(level: int) -> str: + """Get a string description of an error level.""" return ["Severe", "Error", "Warning"][level] +def _error_message(lint_error: Any) -> str: + """Turn a linting error into an error message.""" + source = getattr(lint_error, "source") + line = getattr(lint_error, "line") + level = _level_to_string(getattr(lint_error, "level")) + message = getattr(lint_error, "full_message") + return f"{source}({line}): {level}: {message}" + + class RstTests(TestCase): """Linting for the doc .rst files.""" @@ -48,11 +59,4 @@ def test_dir(self) -> None: ] if filtered_errors: - self.fail( - msg="\n".join( - [ - f"{err.source}({err.line}): {_level_to_string(err.level)}: {err.full_message}" - for err in filtered_errors - ] - ) - ) + self.fail(msg="\n".join(map(_error_message, filtered_errors))) From e1dec20231818a1af7e706f0c9fd187bd292612f Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 10 Oct 2025 18:44:37 +0100 Subject: [PATCH 19/35] Lots 
of pylint cleaning --- datafaker/base.py | 27 ++- datafaker/create.py | 9 +- datafaker/generators.py | 216 +++++++++++++++------- datafaker/interactive.py | 226 ++++++++++++----------- datafaker/main.py | 4 +- datafaker/make.py | 40 ++--- datafaker/utils.py | 24 +-- tests/test_functional.py | 36 ++-- tests/test_interactive.py | 318 +++++++++++++++------------------ tests/test_main.py | 21 +-- tests/test_make.py | 8 +- tests/test_providers.py | 3 +- tests/test_remove.py | 5 +- tests/test_rst.py | 2 +- tests/test_unique_generator.py | 1 - tests/utils.py | 39 ++-- 16 files changed, 522 insertions(+), 457 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 4ceb2af..6ff1890 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -5,10 +5,11 @@ import os import random from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Mapping from dataclasses import dataclass +from io import TextIOWrapper from pathlib import Path -from typing import Any, Callable, Generator +from typing import Any, Generator import numpy as np import yaml @@ -59,8 +60,7 @@ def merge_with_constants( yield xs[xi] xi += 1 outi += 1 - for x in xs[xi:]: - yield x + yield from xs[xi:] class NothingToGenerateException(Exception): @@ -132,7 +132,7 @@ def choice(self, a: list[T]) -> T: :return: The chosen value. """ c = random.choice(a) - return c["value"] if type(c) is dict and "value" in c else c + return c["value"] if isinstance(c, Mapping) and "value" in c else c def zipf_choice(self, a: list[T], n: int | None = None) -> T: """ @@ -149,7 +149,7 @@ def zipf_choice(self, a: list[T], n: int | None = None) -> T: if n is None: n = len(a) c = random.choices(a, weights=zipf_weights(n))[0] - return c["value"] if type(c) is dict and "value" in c else c + return c["value"] if isinstance(c, Mapping) and "value" in c else c def weighted_choice(self, a: list[dict[str, Any]]) -> Any: """ @@ -403,17 +403,26 @@ def load(self, connection: Connection, base_path: Path = Path(".")) -> None: """Load the data from file.""" yaml_file = base_path / Path(self.table.fullname + ".yaml") if yaml_file.exists(): - opener = lambda: open(yaml_file, mode="r", encoding="utf-8") + + def opener() -> TextIOWrapper: + return open(yaml_file, mode="r", encoding="utf-8") + else: yaml_file = base_path / Path(self.table.fullname + ".yaml.gz") if yaml_file.exists(): - opener = lambda: gzip.open(yaml_file, mode="rt") + + def opener() -> TextIOWrapper: + return gzip.open(yaml_file, mode="rt") + else: logger.warning("File %s not found. Skipping...", yaml_file) return if 0 < table_row_count(self.table, connection): logger.warning( - "Table %s already contains data (consider running 'datafaker remove-vocab'), skipping...", + ( + "Table %s already contains data" + " (consider running 'datafaker remove-vocab'), skipping..." 
+ ), self.table.name, ) return diff --git a/datafaker/create.py b/datafaker/create.py index f11a0dd..ce2a74e 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -235,11 +235,10 @@ def next(self) -> None: if self._final_values is None: self._table_name, self._provided_values = next(self._story) return - else: - self._table_name, self._provided_values = self._story.send( - self._final_values - ) - return + self._table_name, self._provided_values = self._story.send( + self._final_values + ) + return except StopIteration: try: name, self._story = next(self._stories) diff --git a/datafaker/generators.py b/datafaker/generators.py index a925b79..0b760e4 100644 --- a/datafaker/generators.py +++ b/datafaker/generators.py @@ -21,7 +21,7 @@ from datafaker.base import DistributionGenerator from datafaker.utils import T, logger -NumericT = Union[int, float] +NumericType = Union[int, float] # How many distinct values can we have before we consider a # choice distribution to be infeasible? @@ -102,7 +102,8 @@ def custom_queries(self) -> dict[str, dict[str, str]]: "query": "SELECT one, too AS two FROM mytable WHERE too > 1", "comment": "big enough one and two from table mytable" }} - will populate SRC_STATS["myquery"]["results"][0]["one"] and SRC_STATS["myquery"]["results"][0]["two"] + will populate SRC_STATS["myquery"]["results"][0]["one"] + and SRC_STATS["myquery"]["results"][0]["two"] in the src-stats.yaml file. Keys should be chosen to minimize the chances of clashing with other queries, @@ -142,18 +143,17 @@ class PredefinedGenerator(Generator): def _get_src_stats_mentioned(self, val: Any) -> set[str]: if not val: return set() - if type(val) is str: + if isinstance(val, str): ss = self.SRC_STAT_NAME_RE.match(val) if ss: ss_name = ss.group(1) logger.debug("Found SRC_STATS reference %s", ss_name) return set([ss_name]) - else: - logger.debug("Value %s does not seem to be a SRC_STATS reference", val) - return set() - if type(val) is list: + logger.debug("Value %s does not seem to be a SRC_STATS reference", val) + return set() + if isinstance(val, list): return set.union(*(self._get_src_stats_mentioned(v) for v in val)) - if type(val) is dict: + if isinstance(val, dict): return set.union(*(self._get_src_stats_mentioned(v) for v in val.values())) return set() @@ -278,12 +278,9 @@ def __init__( with engine.connect() as connection: raw_buckets = connection.execute( text( - "SELECT COUNT({column}) AS f, FLOOR(({column} - {x})/{w}) AS b FROM {table} GROUP BY b".format( - column=column_name, - table=table_name, - x=mean - 2 * stddev, - w=stddev / 2, - ) + f"SELECT COUNT({column_name}) AS f," + f" FLOOR(({column_name} - {mean - 2 * stddev})/{stddev / 2}) AS b" + f" FROM {table_name} GROUP BY b" ) ) self.buckets: Sequence[int] = [0] * 10 @@ -310,10 +307,9 @@ def make_buckets( with engine.connect() as connection: result = connection.execute( text( - "SELECT AVG({column}) AS mean, STDDEV({column}) AS stddev, COUNT({column}) AS count FROM {table}".format( - table=table_name, - column=column_name, - ) + f"SELECT AVG({column_name}) AS mean," + f" STDDEV({column_name}) AS stddev," + f" COUNT({column_name}) AS count FROM {table_name}" ) ).first() if result is None or result.stddev is None or getattr(result, "count") < 2: @@ -388,7 +384,8 @@ def __init__( f = getattr(f, part) if not callable(f): raise Exception( - f"Mimesis object {function_name} is not a callable, so cannot be used as a generator" + f"Mimesis object {function_name} is not a callable," + " so cannot be used as a generator" ) self._name = 
"generic." + function_name self._generator_function = f @@ -520,7 +517,7 @@ def __init__( @classmethod def make_singleton( - _cls, column: Column, engine: Engine, function_name: str + cls, column: Column, engine: Engine, function_name: str ) -> Sequence[Generator]: """Make the appropriate generation configuration for this column.""" extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" @@ -548,8 +545,14 @@ def make_singleton( def nominal_kwargs(self) -> dict[str, Any]: """Get the arguments to be entered into ``config.yaml``.""" return { - "start": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__start"]', - "end": f'SRC_STATS["auto__{self._column.table.name}"]["results"][0]["{self._column.name}__end"]', + "start": ( + f'SRC_STATS["auto__{self._column.table.name}"]["results"]' + f'[0]["{self._column.name}__start"]' + ), + "end": ( + f'SRC_STATS["auto__{self._column.table.name}"]["results"]' + f'[0]["{self._column.name}__end"]' + ), } def actual_kwargs(self) -> dict[str, Any]: @@ -564,11 +567,17 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: return { f"{self._column.name}__start": { "clause": self._min_year, - "comment": f"Earliest year found for column {self._column.name} in table {self._column.table.name}", + "comment": ( + f"Earliest year found for column {self._column.name}" + f" in table {self._column.table.name}" + ), }, f"{self._column.name}__end": { "clause": self._max_year, - "comment": f"Latest year found for column {self._column.name} in table {self._column.table.name}", + "comment": ( + f"Latest year found for column {self._column.name}" + f" in table {self._column.table.name}" + ), }, } @@ -642,7 +651,7 @@ def get_generators( f"LENGTH({column.name})", ) fitness_fn = len - except Exception as exc: + except Exception: # Some column types that appear to be strings (such as enums) # cannot have their lengths measured. In this case we cannot # detect fitness using lengths. 
@@ -754,7 +763,7 @@ def get_generators( return [MimesisGenerator("person.weight")] -def fit_from_buckets(xs: Sequence[NumericT], ys: Sequence[NumericT]) -> float: +def fit_from_buckets(xs: Sequence[NumericType], ys: Sequence[NumericType]) -> float: """Calculate the fit by comparing a pair of lists of buckets.""" sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) count = len(ys) @@ -764,7 +773,7 @@ def fit_from_buckets(xs: Sequence[NumericT], ys: Sequence[NumericT]) -> float: class ContinuousDistributionGenerator(Generator): """Base class for generators producing continuous distributions.""" - expected_buckets: Sequence[NumericT] = [] + expected_buckets: Sequence[NumericType] = [] def __init__(self, table_name: str, column_name: str, buckets: Buckets): """Initialise a ContinuousDistributionGenerator.""" @@ -776,8 +785,14 @@ def __init__(self, table_name: str, column_name: str, buckets: Buckets): def nominal_kwargs(self) -> dict[str, Any]: """Get the arguments to be entered into ``config.yaml``.""" return { - "mean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["mean__{self.column_name}"]', - "sd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["stddev__{self.column_name}"]', + "mean": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["mean__{self.column_name}"]' + ), + "sd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["stddev__{self.column_name}"]' + ), } def actual_kwargs(self) -> dict[str, Any]: @@ -946,8 +961,14 @@ def generate_data(self, count: int) -> list[Any]: def nominal_kwargs(self) -> dict[str, Any]: """Get the arguments to be entered into ``config.yaml``.""" return { - "logmean": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logmean__{self.column_name}"]', - "logsd": f'SRC_STATS["auto__{self.table_name}"]["results"][0]["logstddev__{self.column_name}"]', + "logmean": ( + f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logmean__{self.column_name}"]' + ), + "logsd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logstddev__{self.column_name}"]' + ), } def actual_kwargs(self) -> dict[str, Any]: @@ -963,12 +984,21 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: return { **clauses, f"logmean__{self.column_name}": { - "clause": f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name}) ELSE NULL END)", + "clause": ( + f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name})" + " ELSE NULL END)" + ), "comment": f"Mean of logs of {self.column_name} from table {self.table_name}", }, f"logstddev__{self.column_name}": { - "clause": f"STDDEV(CASE WHEN 0<{self.column_name} THEN LN({self.column_name}) ELSE NULL END)", - "comment": f"Standard deviation of logs of {self.column_name} from table {self.table_name}", + "clause": ( + f"STDDEV(CASE WHEN 0<{self.column_name}" + f" THEN LN({self.column_name}) ELSE NULL END)" + ), + "comment": ( + f"Standard deviation of logs of {self.column_name}" + f" from table {self.table_name}" + ), }, } @@ -992,10 +1022,10 @@ def _get_generators_from_buckets( with engine.connect() as connection: result = connection.execute( text( - "SELECT AVG(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logmean, STDDEV(CASE WHEN 0<{column} THEN LN({column}) ELSE NULL END) AS logstddev FROM {table}".format( - table=table_name, - column=column_name, - ) + f"SELECT AVG(CASE WHEN 0<{column_name} THEN LN({column_name})" + " ELSE NULL END) AS logmean," + f" STDDEV(CASE WHEN 0<{column_name} THEN LN({column_name}) ELSE NULL END)" + f" AS 
logstddev FROM {table_name}" ) ).first() if result is None or result.logstddev is None: @@ -1064,21 +1094,56 @@ def __init__( extra_comment = " and their counts" if suppress_count == 0: if sample_count is None: - self._query = f"SELECT {column_name} AS value{extra_results} FROM {table_name} WHERE {column_name} IS NOT NULL GROUP BY value ORDER BY COUNT({column_name}) DESC" - self._comment = f"All the values{extra_comment} that appear in column {column_name} of table {table_name}" + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL GROUP BY value" + f" ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name}" + ) self._annotation = None else: - self._query = f"SELECT {column_name} AS value{extra_results} FROM (SELECT {column_name} FROM {table_name} WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" - self._comment = f"The values{extra_comment} that appear in column {column_name} of a random sample of {sample_count} rows of table {table_name}" + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM" + f" (SELECT {column_name} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL" + f" ORDER BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"The values{extra_comment} that appear in column {column_name}" + f" of a random sample of {sample_count} rows of table {table_name}" + ) self._annotation = "sampled" else: if sample_count is None: - self._query = f"SELECT value{extra_expo} FROM (SELECT {column_name} AS value, COUNT({column_name}) AS count FROM {table_name} WHERE {column_name} IS NOT NULL GROUP BY value ORDER BY count DESC) AS _inner WHERE {suppress_count} < count" - self._comment = f"All the values{extra_comment} that appear in column {column_name} of table {table_name} more than {suppress_count} times" + self._query = ( + f"SELECT value{extra_expo} FROM" + f" (SELECT {column_name} AS value, COUNT({column_name}) AS count" + f" FROM {table_name} WHERE {column_name} IS NOT NULL" + f" GROUP BY value ORDER BY count DESC) AS _inner" + f" WHERE {suppress_count} < count" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name} more than {suppress_count} times" + ) self._annotation = "suppressed" else: - self._query = f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM (SELECT {column_name} AS value FROM {table_name} WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY value ORDER BY count DESC) AS _inner WHERE {suppress_count} < count" - self._comment = f"The values{extra_comment} that appear more than {suppress_count} times in column {column_name}, out of a random sample of {sample_count} rows of table {table_name}" + self._query = ( + f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM" + f" (SELECT {column_name} AS value FROM {table_name}" + f" WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY count DESC)" + f" AS _inner WHERE {suppress_count} < count" + ) + self._comment = ( + f"The values{extra_comment} that appear more than {suppress_count} times" + f" in column {column_name}, out of a random sample of {sample_count} rows" + f" of table 
{table_name}" + ) self._annotation = "sampled and suppressed" @abstractmethod @@ -1220,14 +1285,14 @@ def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: if c != 0: counts.append(c) v = result.v - if type(v) is decimal.Decimal: + if isinstance(v, decimal.Decimal): v = float(v) values.append(v) cvs.append({"value": v, "count": c}) if suppress_count < c: counts_not_suppressed.append(c) v = result.v - if type(v) is decimal.Decimal: + if isinstance(v, decimal.Decimal): v = float(v) values_not_suppressed.append(v) cvs_not_suppressed.append({"value": v, "count": c}) @@ -1258,11 +1323,9 @@ def get_generators( with engine.connect() as connection: results = connection.execute( text( - "SELECT {column} AS v, COUNT({column}) AS f FROM {table} GROUP BY v ORDER BY f DESC LIMIT {limit}".format( - table=table_name, - column=column_name, - limit=MAXIMUM_CHOICES + 1, - ) + f"SELECT {column_name} AS v, COUNT({column_name})" + f" AS f FROM {table_name} GROUP BY v" + f" ORDER BY f DESC LIMIT {MAXIMUM_CHOICES + 1}" ) ) if results is not None and results.rowcount <= MAXIMUM_CHOICES: @@ -1281,11 +1344,10 @@ def get_generators( ] results = connection.execute( text( - "SELECT v, COUNT(v) AS f FROM (SELECT {column} as v FROM {table} ORDER BY RANDOM() LIMIT {sample_count}) AS _inner GROUP BY v ORDER BY f DESC".format( - table=table_name, - column=column_name, - sample_count=self.SAMPLE_COUNT, - ) + f"SELECT v, COUNT(v) AS f FROM" + f" (SELECT {column_name} as v FROM {table_name}" + f" ORDER BY RANDOM() LIMIT {self.SAMPLE_COUNT})" + f" AS _inner GROUP BY v ORDER BY f DESC" ) ) if results is not None: @@ -1436,7 +1498,10 @@ def custom_queries(self) -> dict[str, Any]: cols = ", ".join(self._columns) return { f"auto__cov__{self._table}": { - "comment": f"Means and covariate matrix for the columns {cols}, so that we can produce the relatedness between these in the fake data.", + "comment": ( + f"Means and covariate matrix for the columns {cols}," + " so that we can produce the relatedness between these in the fake data." 
+                ),
                 "query": self._query,
             }
         }
@@ -1511,14 +1576,20 @@ def query(
         )
         means = "".join(f", _q.m{i}" for i in range(len(columns)))
         covs = "".join(
-            f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}"
+            (
+                f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})"
+                f"/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}"
+            )
             for iy in range(len(columns))
             for ix in range(iy + 1)
         )
         if sample_count is None:
             subquery = table + where
         else:
-            subquery = f"(SELECT * FROM {table}{where} ORDER BY RANDOM() LIMIT {sample_count}) AS _sampled"
+            subquery = (
+                f"(SELECT * FROM {table}{where} ORDER BY RANDOM()"
+                f" LIMIT {sample_count}) AS _sampled"
+            )
         # If there are any numeric columns we need at least two rows to make any (co)variances at all.
         suppress_clause = f" WHERE {suppress_count} < _q.count" if columns else ""
         return (
@@ -1689,7 +1760,10 @@ def _nominal_kwargs_with_combinations(
         self, index: int, partition: RowPartition
     ) -> dict[str, Any]:
         """Get the arguments to be entered into ``config.yaml`` for a single partition."""
-        count = f'sum(r["count"] for r in SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])'
+        count = (
+            'sum(r["count"] for r in'
+            f' SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])'
+        )
         if not partition.included_numeric and not partition.included_choice:
             return {
                 "count": count,
@@ -1799,12 +1873,12 @@ def fit(self, default: float = -1) -> float:
 def is_numeric(col: Column) -> bool:
     """Test if this column stores a numeric value."""
     ct = get_column_type(col)
-    return (isinstance(ct, Numeric) or isinstance(ct, Integer)) and not col.foreign_keys
+    return isinstance(ct, (Numeric, Integer)) and not col.foreign_keys
 
 
-def powerset(input: list[T]) -> Iterable[Iterable[T]]:
+def powerset(xs: list[T]) -> Iterable[Iterable[T]]:
     """Get a list of all sublists of ``xs``."""
-    return chain.from_iterable(combinations(input, n) for n in range(len(input) + 1))
+    return chain.from_iterable(combinations(xs, n) for n in range(len(xs) + 1))
 
 
 @dataclass
@@ -1911,7 +1985,11 @@ def get_partition_count_query(
         )
         if where is None:
             return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"'
-        return f'SELECT count, "index" FROM (SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index") AS _q {where}'
+        return (
+            'SELECT count, "index" FROM (SELECT COUNT(*) AS count,'
+            f' {index_exp} AS "index"'
+            f' FROM {table} GROUP BY "index") AS _q {where}'
+        )
 
     def get_generators(
         self, columns: list[Column], engine: Engine
@@ -1973,7 +2051,11 @@ def get_generators(
             partition_count_max_results = (
                 connection.execute(text(partition_query_max)).mappings().fetchall()
             )
-            count_comment = f"Number of rows for each combination of the columns { {nc.column.name for nc in nullable_columns} } of the table {table} being null"
+            count_comment = (
+                "Number of rows for each combination of the columns"
+                f" { {nc.column.name for nc in nullable_columns} }"
+                f" of the table {table} being null"
+            )
             if self._execute_partition_queries(connection, row_partitions_maximal):
                 gens.append(
                     NullPartitionedNormalGenerator(
diff --git a/datafaker/interactive.py b/datafaker/interactive.py
index d35806d..17efa19 100644
--- a/datafaker/interactive.py
+++ b/datafaker/interactive.py
@@ -4,7 +4,7 @@
 import functools
 import re
 from abc import ABC, abstractmethod
-from collections.abc import Mapping, MutableMapping
+from collections.abc import Mapping, MutableMapping, Sequence
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
@@ -13,7 +13,7 @@
 import sqlalchemy
 from prettytable import PrettyTable
-from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table, text
+from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table
 from typing_extensions import Self
 
 from datafaker.generators import Generator, PredefinedGenerator, everything_factory
@@ -123,7 +123,9 @@ class DbCmd(ABC, cmd.Cmd):
     ROW_COUNT_MSG = "Total row count: {}"
 
     @abstractmethod
-    def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | None:
+    def make_table_entry(
+        self, table_name: str, table_config: Mapping
+    ) -> TableEntry | None:
         """
         Make a table entry suitable for this interactive command.
 
@@ -131,7 +133,6 @@ def make_table_entry(self, name: str, table_config: Mapping) -> TableEntry | Non
         :param table_config: The part of the ``config.yaml`` referring to this table.
         :return: The table entry or None if this table should not be interacted with.
         """
-        ...
 
     def __init__(
         self,
@@ -145,12 +146,12 @@ def __init__(
         self.config: MutableMapping[str, Any] = config
         self.metadata = metadata
         self._table_entries: list[TableEntry] = []
-        tables_config: Mapping = config.get("tables", {})
-        if type(tables_config) is not dict:
+        tables_config: MutableMapping = config.get("tables", {})
+        if not isinstance(tables_config, MutableMapping):
             tables_config = {}
         for name in metadata.tables.keys():
             table_config = tables_config.get(name, {})
-            if type(table_config) is not dict:
+            if not isinstance(table_config, MutableMapping):
                 table_config = {}
             entry = self.make_table_entry(name, table_config)
             if entry is not None:
@@ -224,7 +225,7 @@ def set_prompt(self) -> None:
         """Set the prompt according to the current state."""
         ...
 
-    def set_table_index(self, index: int) -> bool:
+    def _set_table_index(self, index: int) -> bool:
         """
         Move to a different table.
 
@@ -244,7 +245,7 @@ def next_table(self, report: str = "No more tables") -> bool:
         :param report: The text to print if there is no next table.
         :return: True if there is another table to move to.
         """
-        if not self.set_table_index(self.table_index + 1):
+        if not self._set_table_index(self.table_index + 1):
             self.print(report)
             return False
         return True
@@ -257,7 +258,7 @@ def table_metadata(self) -> Table:
         """Get the metadata of the current table."""
         return self.metadata.tables[self.table_name()]
 
-    def get_column_names(self) -> list[str]:
+    def _get_column_names(self) -> list[str]:
         """Get the names of the current columns."""
         return [col.name for col in self.table_metadata().columns]
 
@@ -277,23 +278,25 @@ def report_columns(self) -> None:
             ],
         )
 
-    def get_table_config(self, table_name: str) -> dict[str, Any]:
+    def get_table_config(self, table_name: str) -> MutableMapping[str, Any]:
         """Get the configuration of the named table."""
         ts = self.config.get("tables", None)
-        if type(ts) is not dict:
+        if not isinstance(ts, MutableMapping):
             return {}
         t = ts.get(table_name)
-        return t if type(t) is dict else {}
+        return t if isinstance(t, MutableMapping) else {}
 
-    def set_table_config(self, table_name: str, config: dict[str, Any]) -> None:
+    def set_table_config(
+        self, table_name: str, config: MutableMapping[str, Any]
+    ) -> None:
         """Set the configuration of the named table."""
         ts = self.config.get("tables", None)
-        if type(ts) is not dict:
+        if not isinstance(ts, MutableMapping):
             self.config["tables"] = {table_name: config}
             return
         ts[table_name] = config
 
-    def _remove_prefix_src_stats(self, prefix: str) -> list[dict[str, Any]]:
+    def _remove_prefix_src_stats(self, prefix: str) -> list[MutableMapping[str, Any]]:
         """Remove all source stats with the given prefix from the configuration."""
         src_stats = self.config.get("src-stats", [])
         new_src_stats = []
@@ -323,7 +326,7 @@ def find_entry_index_by_table_name(self, table_name: str) -> int | None:
             None,
         )
 
-    def find_entry_by_table_name(self, table_name: str) -> TableEntry | None:
+    def _find_entry_by_table_name(self, table_name: str) -> TableEntry | None:
         """Get the table entry of the named table."""
         for e in self._table_entries:
             if e.name == table_name:
@@ -339,11 +342,8 @@ def do_counts(self, _arg: str) -> None:
         colcounts = [", COUNT({0}) AS {0}".format(nnc) for nnc in nonnull_columns]
         with self.sync_engine.connect() as connection:
             result = connection.execute(
-                text(
-                    "SELECT COUNT(*) AS row_count{colcounts} FROM {table}".format(
-                        table=table_name,
-                        colcounts="".join(colcounts),
-                    )
+                sqlalchemy.text(
+                    f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}"
                 )
             ).first()
         if result is None:
@@ -362,51 +362,52 @@ def do_counts(self, _arg: str) -> None:
 
     def do_select(self, arg: str) -> None:
         """Run a select query over the database and show the first 50 results."""
-        MAX_SELECT_ROWS = 50
+        max_select_rows = 50
         with self.sync_engine.connect() as connection:
             try:
-                result = connection.execute(text("SELECT " + arg))
+                result = connection.execute(sqlalchemy.text("SELECT " + arg))
             except sqlalchemy.exc.DatabaseError as exc:
                 self.print("Failed to execute: {}", exc)
                 return
             row_count = result.rowcount
             self.print(self.ROW_COUNT_MSG, row_count)
             if 50 < row_count:
-                self.print("Showing the first {} rows", MAX_SELECT_ROWS)
+                self.print("Showing the first {} rows", max_select_rows)
             fields = list(result.keys())
-            rows = [row._tuple() for row in result.fetchmany(MAX_SELECT_ROWS)]
+            rows = [row._tuple() for row in result.fetchmany(max_select_rows)]
             self.print_table(fields, rows)
 
     def do_peek(self, arg: str) -> None:
         """
         View some data from the current table.
 
-        Use 'peek col1 col2 col3' to see a sample of values from columns col1, col2 and col3 in the current table.
+        Use 'peek col1 col2 col3' to see a sample of values from
+        columns col1, col2 and col3 in the current table.
         Use 'peek' to see a sample of the current column(s).
         Rows that are entirely null are suppressed.
         """
-        MAX_PEEK_ROWS = 25
+        max_peek_rows = 25
         if len(self._table_entries) <= self.table_index:
             return
         table_name = self.table_name()
         col_names = arg.split()
         if not col_names:
-            col_names = self.get_column_names()
+            col_names = self._get_column_names()
         nonnulls = [cn + " IS NOT NULL" for cn in col_names]
         with self.sync_engine.connect() as connection:
-            query = "SELECT {cols} FROM {table} {where} {nonnull} ORDER BY RANDOM() LIMIT {max}".format(
-                cols=",".join(col_names),
-                table=table_name,
-                where="WHERE" if nonnulls else "",
-                nonnull=" OR ".join(nonnulls),
-                max=MAX_PEEK_ROWS,
+            cols = ",".join(col_names)
+            where = "WHERE" if nonnulls else ""
+            nonnull = " OR ".join(nonnulls)
+            query = sqlalchemy.text(
+                f"SELECT {cols} FROM {table_name} {where} {nonnull}"
+                f" ORDER BY RANDOM() LIMIT {max_peek_rows}"
             )
             try:
-                result = connection.execute(text(query))
+                result = connection.execute(query)
             except Exception as exc:
                 self.print(f'SQL query "{query}" caused exception {exc}')
                 return
-            rows = [row._tuple() for row in result.fetchmany(MAX_PEEK_ROWS)]
+            rows = [row._tuple() for row in result.fetchmany(max_peek_rows)]
             self.print_table(list(result.keys()), rows)
 
     def complete_peek(
@@ -431,7 +432,10 @@ class TableCmdTableEntry(TableEntry):
 class TableCmd(DbCmd):
     """Command shell allowing the user to set the type of each table."""
 
-    intro = "Interactive table configuration (ignore, vocabulary, private, generate or empty). Type ? for help.\n"
+    intro = (
+        "Interactive table configuration (ignore,"
+        " vocabulary, private, generate or empty). Type ? for help.\n"
+    )
     doc_leader = """Use the commands 'ignore', 'vocabulary', 'private', 'empty' or
 'generate' to set the table's type. Use 'next' or 'previous' to change
 table. Use 'tables' and 'columns' for
@@ -453,7 +457,9 @@ class TableCmd(DbCmd):
     NOTE_TEXT_NO_CHANGES = "You have made no changes."
     NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}"
 
-    def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | None:
+    def make_table_entry(
+        self, table_name: str, table: Mapping
+    ) -> TableCmdTableEntry | None:
         """
         Make a table entry for the named table.
 
@@ -462,14 +468,16 @@ def make_table_entry(self, name: str, table: Mapping) -> TableCmdTableEntry | No
         :return: The newly-constructed table entry.
         """
         if table.get("ignore", False):
-            return TableCmdTableEntry(name, TableType.IGNORE, TableType.IGNORE)
+            return TableCmdTableEntry(table_name, TableType.IGNORE, TableType.IGNORE)
         if table.get("vocabulary_table", False):
-            return TableCmdTableEntry(name, TableType.VOCABULARY, TableType.VOCABULARY)
+            return TableCmdTableEntry(
+                table_name, TableType.VOCABULARY, TableType.VOCABULARY
+            )
         if table.get("primary_private", False):
-            return TableCmdTableEntry(name, TableType.PRIVATE, TableType.PRIVATE)
+            return TableCmdTableEntry(table_name, TableType.PRIVATE, TableType.PRIVATE)
         if table.get("num_rows_per_pass", 1) == 0:
-            return TableCmdTableEntry(name, TableType.EMPTY, TableType.EMPTY)
-        return TableCmdTableEntry(name, TableType.GENERATE, TableType.GENERATE)
+            return TableCmdTableEntry(table_name, TableType.EMPTY, TableType.EMPTY)
+        return TableCmdTableEntry(table_name, TableType.GENERATE, TableType.GENERATE)
 
     def __init__(
         self,
@@ -487,9 +495,9 @@ def table_entries(self) -> list[TableCmdTableEntry]:
         """Get the list of table entries."""
         return cast(list[TableCmdTableEntry], self._table_entries)
 
-    def find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None:
+    def _find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None:
         """Get the table entry of the table with the given name."""
-        entry = super().find_entry_by_table_name(table_name)
+        entry = super()._find_entry_by_table_name(table_name)
         if entry is None:
             return None
         return cast(TableCmdTableEntry, entry)
@@ -556,7 +564,7 @@ def _sanity_check_failures(self) -> list[tuple[str, str, str]]:
             if from_t == TableType.VOCABULARY:
                 referenced = self._get_referenced_tables(from_entry.name)
                 for ref in referenced:
-                    to_entry = self.find_entry_by_table_name(ref)
+                    to_entry = self._find_entry_by_table_name(ref)
                     if (
                         to_entry is not None
                         and to_entry.new_type != TableType.VOCABULARY
@@ -578,7 +586,7 @@ def _sanity_check_warnings(self) -> list[tuple[str, str, str]]:
             if from_t in {TableType.GENERATE, TableType.PRIVATE}:
                 referenced = self._get_referenced_tables(from_entry.name)
                 for ref in referenced:
-                    to_entry = self.find_entry_by_table_name(ref)
+                    to_entry = self._find_entry_by_table_name(ref)
                     if to_entry is not None and to_entry.new_type in {
                         TableType.EMPTY,
                         TableType.IGNORE,
@@ -640,7 +648,7 @@ def do_next(self, arg: str) -> None:
             if index is None:
                 self.print(self.ERROR_NO_SUCH_TABLE, arg)
                 return
-            self.set_table_index(index)
+            self._set_table_index(index)
             return
         self.next_table(self.INFO_NO_MORE_TABLES)
 
@@ -654,7 +662,7 @@ def complete_next(
 
     def do_previous(self, _arg: str) -> None:
         """Go to the previous table."""
-        if not self.set_table_index(self.table_index - 1):
+        if not self._set_table_index(self.table_index - 1):
             self.print(self.ERROR_ALREADY_AT_START)
 
     def do_ignore(self, _arg: str) -> None:
@@ -676,7 +684,7 @@ def do_private(self, _arg: str) -> None:
         self.next_table()
 
     def do_generate(self, _arg: str) -> None:
-        """Set the current table as neither a vocabulary table nor ignored nor primary private, and go to the next table."""
+        """Mark the current table as one to be generated, and go to the next table."""
         self.set_type(TableType.GENERATE)
         self.print("Table {} generate", self.table_name())
         self.next_table()
@@ -758,7 +766,7 @@ def print_column_data(self, column: str, count: int, min_length: int) -> None:
         )
         with self.sync_engine.connect() as connection:
             result = connection.execute(
-                text(
+                sqlalchemy.text(
                     "SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format(
                         table=self.table_name(),
                         column=column,
@@ -777,7 +785,7 @@ def print_row_data(self, count: int) -> None:
         """
         with self.sync_engine.connect() as connection:
             result = connection.execute(
-                text(
+                sqlalchemy.text(
                     "SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format(
                         table=self.table_name(),
                         count=count,
@@ -879,7 +887,7 @@ def find_missingness_query(
         return None
 
     def make_table_entry(
-        self, name: str, table: Mapping
+        self, table_name: str, table_config: Mapping
     ) -> MissingnessCmdTableEntry | None:
         """
         Make a table entry for a particular table.
 
@@ -888,15 +896,15 @@ def make_table_entry(
         :param table_config: The part of ``config.yaml`` relating to this table.
         :return: The newly-constructed table entry.
         """
-        if table.get("ignore", False):
+        if table_config.get("ignore", False):
             return None
-        if table.get("vocabulary_table", False):
+        if table_config.get("vocabulary_table", False):
             return None
-        if table.get("num_rows_per_pass", 1) == 0:
+        if table_config.get("num_rows_per_pass", 1) == 0:
             return None
-        mgs = table.get("missingness_generators", [])
+        mgs = table_config.get("missingness_generators", [])
         old = None
-        nonnull_columns = self.get_nonnull_columns(name)
+        nonnull_columns = self.get_nonnull_columns(table_name)
         if not nonnull_columns:
             return None
         if not mgs:
@@ -909,7 +917,7 @@ def make_table_entry(
         elif len(mgs) == 1:
             mg = mgs[0]
             mg_name = mg.get("name", None)
-            if type(mg_name) is str:
+            if isinstance(mg_name, str):
                 query_comment = self.find_missingness_query(mg)
                 if query_comment is not None:
                     (query, comment) = query_comment
@@ -922,7 +930,7 @@ def make_table_entry(
         if old is None:
             return None
         return MissingnessCmdTableEntry(
-            name=name,
+            name=table_name,
             old_type=old,
             new_type=old,
         )
@@ -950,11 +958,11 @@ def table_entries(self) -> list[MissingnessCmdTableEntry]:
         """Get the table entries list."""
         return cast(list[MissingnessCmdTableEntry], self._table_entries)
 
-    def find_entry_by_table_name(
+    def _find_entry_by_table_name(
         self, table_name: str
     ) -> MissingnessCmdTableEntry | None:
         """Find the table entry given the table name."""
-        entry = super().find_entry_by_table_name(table_name)
+        entry = super()._find_entry_by_table_name(table_name)
         if entry is None:
             return None
         return cast(MissingnessCmdTableEntry, entry)
@@ -965,9 +973,9 @@ def set_prompt(self) -> None:
             entry: MissingnessCmdTableEntry = self.table_entries[self.table_index]
             nt = entry.new_type
             if nt is None:
-                self.prompt = "(missingness for {0}) ".format(entry.name)
+                self.prompt = f"(missingness for {entry.name}) "
             else:
-                self.prompt = "(missingness for {0}: {1}) ".format(entry.name, nt.name)
+                self.prompt = f"(missingness for {entry.name}: {nt.name}) "
         else:
             self.prompt = "(missingness) "
 
@@ -985,14 +993,12 @@ def _copy_entries(self) -> None:
             if entry.new_type is None or entry.new_type.name == "none":
                 table.pop("missingness_generators", None)
             else:
-                src_stat_key = "missing_auto__{0}__0".format(entry.name)
+                src_stat_key = f"missing_auto__{entry.name}__0"
                 table["missingness_generators"] = [
                     {
                         "name": entry.new_type.name,
                         "kwargs": {
-                            "patterns": 'SRC_STATS["{0}"]["results"]'.format(
-                                src_stat_key
-                            )
+                            "patterns": f'SRC_STATS["{src_stat_key}"]["results"]'
                         },
                         "columns": entry.new_type.columns,
                     }
@@ -1048,7 +1054,7 @@ def do_tables(self, _arg: str) -> None:
         for entry in self.table_entries:
             old = "-" if entry.old_type is None else entry.old_type.name
             new = "-" if entry.new_type is None else entry.new_type.name
-            desc = new if old == new else "{0}->{1}".format(old, new)
+            desc = new if old == new else f"{old}->{new}"
             self.print("{0} {1}", entry.name, desc)
 
     def do_next(self, arg: str) -> None:
@@ -1066,7 +1072,7 @@ def do_next(self, arg: str) -> None:
             if index is None:
                 self.print(self.ERROR_NO_SUCH_TABLE, arg)
                 return
-            self.set_table_index(index)
+            self._set_table_index(index)
             return
         self.next_table(self.INFO_NO_MORE_TABLES)
 
@@ -1080,7 +1086,7 @@ def complete_next(
 
     def do_previous(self, _arg: str) -> None:
         """Go to the previous table."""
-        if not self.set_table_index(self.table_index - 1):
+        if not self._set_table_index(self.table_index - 1):
             self.print(self.ERROR_ALREADY_AT_START)
 
     def _set_type(self, name: str, query: str, comment: str) -> None:
@@ -1130,7 +1136,10 @@ def do_sampled(self, arg: str) -> None:
                 count,
                 self.get_nonnull_columns(entry.name),
             ),
-            f"The missingness patterns and how often they appear in a sample of {count} from table {entry.name}",
+            (
+                "The missingness patterns and how often they appear in a"
+                f" sample of {count} from table {entry.name}"
+            ),
         )
         self.print("Table {} set to sampled missingness", self.table_name())
         self.next_table()
@@ -1219,7 +1228,7 @@ class GeneratorCmd(DbCmd):
     )
 
     def make_table_entry(
-        self, table_name: str, table: Mapping
+        self, table_name: str, table_config: Mapping
     ) -> GeneratorCmdTableEntry | None:
         """
         Make a table entry.
 
        :param table_config: The portion of the ``config.yaml`` file describing this table.
        :return: The newly constructed table entry, or None if this table is to be ignored.
        """
-        if table.get("ignore", False):
+        if table_config.get("ignore", False):
            return None
-        if table.get("vocabulary_table", False):
+        if table_config.get("vocabulary_table", False):
            return None
-        if table.get("num_rows_per_pass", 1) == 0:
+        if table_config.get("num_rows_per_pass", 1) == 0:
            return None
        metadata_table = self.metadata.tables[table_name]
        columns = [str(colname) for colname in metadata_table.columns.keys()]

        new_generator_infos: list[GeneratorInfo] = []
        old_generator_infos: list[GeneratorInfo] = []
-        for rg in table.get("row_generators", []):
+        for rg in table_config.get("row_generators", []):
            gen_name = rg.get("name", None)
            if gen_name:
                ca = rg.get("columns_assigned", [])
@@ -1310,6 +1319,7 @@ def __init__(
        :param config: Configuration loaded from ``config.yaml``
        """
        super().__init__(src_dsn, src_schema, metadata, config)
+        self.generators: list[Generator] | None = None
        self.generator_index = 0
        self.generators_valid_columns: Optional[tuple[int, list[str]]] = None
        self.set_prompt()
@@ -1319,7 +1329,7 @@ def table_entries(self) -> list[GeneratorCmdTableEntry]:
        """Get the table entries, cast to ``GeneratorCmdTableEntry``."""
        return cast(list[GeneratorCmdTableEntry], self._table_entries)

-    def find_entry_by_table_name(
+    def _find_entry_by_table_name(
        self, table_name: str
    ) -> GeneratorCmdTableEntry | None:
        """
@@ -1328,30 +1338,30 @@ def find_entry_by_table_name(
        :param table_name: The name of the table to find.
        :return: The table entry, or None if no such table name exists.
        """
-        entry = super().find_entry_by_table_name(table_name)
+        entry = super()._find_entry_by_table_name(table_name)
        if entry is None:
            return None
        return cast(GeneratorCmdTableEntry, entry)

-    def set_table_index(self, index: int) -> bool:
+    def _set_table_index(self, index: int) -> bool:
        """
        Move to a new table.

        :param index: table index to move to.
        """
-        ret = super().set_table_index(index)
+        ret = super()._set_table_index(index)
        if ret:
            self.generator_index = 0
            self.set_prompt()
        return ret

-    def previous_table(self) -> bool:
+    def _previous_table(self) -> bool:
        """
        Move to the table before the current one.

        :return: True if there is a previous table to go to.
        """
-        ret = self.set_table_index(self.table_index - 1)
+        ret = self._set_table_index(self.table_index - 1)
        if ret:
            table = self.get_table()
            if table is None:
@@ -1371,7 +1381,7 @@ def get_table(self) -> GeneratorCmdTableEntry | None:
            return self.table_entries[self.table_index]
        return None

-    def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]:
+    def _get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]:
        """Get a pair; the table name then the generator information."""
        if self.table_index < len(self.table_entries):
            entry = self.table_entries[self.table_index]
@@ -1380,26 +1390,26 @@ def get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]:
            return (entry.name, None)
        return (None, None)

-    def get_column_names(self) -> list[str]:
+    def _get_column_names(self) -> list[str]:
        """Get the (unqualified) names for all the current columns."""
-        (_, generator_info) = self.get_table_and_generator()
+        (_, generator_info) = self._get_table_and_generator()
        return generator_info.columns if generator_info else []

-    def column_metadata(self) -> list[Column]:
+    def _column_metadata(self) -> list[Column]:
        """Get the metadata for all the current columns."""
        table = self.table_metadata()
        if table is None:
            return []
-        return [table.columns[name] for name in self.get_column_names()]
+        return [table.columns[name] for name in self._get_column_names()]

    def set_prompt(self) -> None:
        """Set the prompt according to the current table, column and generator."""
-        (table_name, gen_info) = self.get_table_and_generator()
+        (table_name, gen_info) = self._get_table_and_generator()
        if table_name is None:
            self.prompt = "(generators) "
            return
        if gen_info is None:
-            self.prompt = "({table}) ".format(table=table_name)
+            self.prompt = f"({table_name}) "
            return
        table = self.table_metadata()
        columns = [
@@ -1408,7 +1418,7 @@ def set_prompt(self) -> None:
        gen = f" ({gen_info.gen.name()})" if gen_info.gen else ""
        self.prompt = f"({table_name}.{','.join(columns)}{gen}) "

-    def _remove_auto_src_stats(self) -> list[dict[str, Any]]:
+    def _remove_auto_src_stats(self) -> list[MutableMapping[str, Any]]:
        """
        Remove all automatic source stats.
@@ -1510,7 +1520,7 @@ def do_quit(self, arg: str) -> bool:
                return True
        return False

-    def do_tables(self, arg: str) -> None:
+    def do_tables(self, _arg: str) -> None:
        """List the tables."""
        for t_entry in self.table_entries:
            entry = cast(GeneratorCmdTableEntry, t_entry)
@@ -1518,7 +1528,7 @@ def do_tables(self, arg: str) -> None:
            how_many = "one generator" if gen_count == 1 else f"{gen_count} generators"
            self.print("{0} ({1})", entry.name, how_many)

-    def do_list(self, arg: str) -> None:
+    def do_list(self, _arg: str) -> None:
        """List the generators in the current table."""
        if len(self.table_entries) <= self.table_index:
            self.print("Error: no table {0}", self.table_index)
@@ -1547,7 +1557,7 @@ def do_columns(self, _arg: str) -> None:

    def do_info(self, _arg: str) -> None:
        """Show information about the current column."""
-        for cm in self.column_metadata():
+        for cm in self._column_metadata():
            self.print(
                "Column {0} in table {1} has type {2} ({3}).",
                cm.name,
@@ -1614,7 +1624,7 @@ def go_to(self, target: str) -> bool:
            if gen_index is None:
                self.print("we cannot set the generator for column {0}", parts[1])
                return False
-        self.set_table_index(table_index)
+        self._set_table_index(table_index)
        if gen_index is not None:
            self.generator_index = gen_index
        self.set_prompt()
@@ -1694,7 +1704,7 @@ def complete_next(
    def do_previous(self, _arg: str) -> None:
        """Go to the previous generator."""
        if self.generator_index == 0:
-            self.previous_table()
+            self._previous_table()
        else:
            self.generator_index -= 1
            self.set_prompt()
@@ -1707,7 +1717,7 @@ def _generators_valid(self) -> bool:
        """Test if ``self.generators`` is still correct for the current columns."""
        return self.generators_valid_columns == (
            self.table_index,
-            self.get_column_names(),
+            self._get_column_names(),
        )

    def _get_generator_proposals(self) -> list[Generator]:
@@ -1715,13 +1725,13 @@ def _get_generator_proposals(self) -> list[Generator]:
        if not self._generators_valid():
            self.generators = None
        if self.generators is None:
-            columns = self.column_metadata()
+            columns = self._column_metadata()
            gens = everything_factory().get_generators(columns, self.sync_engine)
            sorted_gens = sorted(gens, key=lambda g: g.fit(9999))
            self.generators = sorted_gens
            self.generators_valid_columns = (
                self.table_index,
-                self.get_column_names().copy(),
+                self._get_column_names().copy(),
            )
        return self.generators

@@ -1762,7 +1772,7 @@ def do_compare(self, arg: str) -> None:
        for argument in args:
            if argument.isdigit():
                n = int(argument)
-                if 0 < n and n <= len(gens):
+                if 0 < n <= len(gens):
                    gen = gens[n - 1]
                    comparison[f"{n}. {gen.name()}"] = gen.generate_data(limit)
                    self._print_values_queried(table_name, n, gen)
@@ -1822,7 +1832,7 @@ def _print_custom_queries(self, gen: Generator) -> None:
    def _get_custom_queries_from(
        self, out: dict[str, Any], nominal: Any, actual: Any
    ) -> None:
-        if type(nominal) is str:
+        if isinstance(nominal, str):
            src_stat_groups = self.SRC_STAT_RE.search(nominal)
            # Do we have a SRC_STAT reference?
            if src_stat_groups:
@@ -1834,10 +1844,10 @@ def _get_custom_queries_from(
                    actual = {sub: actual}
                else:
                    out[cq_key] = actual
-        elif type(nominal) is list and type(actual) is list:
+        elif isinstance(nominal, Sequence) and isinstance(actual, Sequence):
            for i in range(min(len(nominal), len(actual))):
                self._get_custom_queries_from(out, nominal[i], actual[i])
-        elif type(nominal) is dict and type(actual) is dict:
+        elif isinstance(nominal, Mapping) and isinstance(actual, Mapping):
            for k, v in nominal.items():
                if k in actual:
                    self._get_custom_queries_from(out, v, actual[k])
@@ -1892,12 +1902,12 @@ def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None
    def _get_column_data(
        self, count: int, to_str: Callable[[Any], str] = repr
    ) -> list[list[str]]:
-        columns = self.get_column_names()
+        columns = self._get_column_names()
        columns_string = ", ".join(columns)
        pred = " AND ".join(f"{column} IS NOT NULL" for column in columns)
        with self.sync_engine.connect() as connection:
            result = connection.execute(
-                text(
+                sqlalchemy.text(
                    f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}"
                )
            )
@@ -1977,7 +1987,7 @@ def do_set(self, arg: str) -> None:

    def set_generator(self, gen: Generator | None) -> None:
        """Set the current column's generator."""
-        (table, gen_info) = self.get_table_and_generator()
+        (table, gen_info) = self._get_table_and_generator()
        if table is None:
            self.print("Error: no table")
            return
@@ -2155,7 +2165,7 @@ def update_config_generators(
            if line:
                if len(line) != 3:
                    logger.error(
-                        "line {0} of file {1} does not have three values",
+                        "line %d of file %s does not have three values",
                        line_no,
                        spec_path,
                    )
diff --git a/datafaker/main.py b/datafaker/main.py
index c129983..cf7bd3b 100644
--- a/datafaker/main.py
+++ b/datafaker/main.py
@@ -303,7 +303,6 @@ def make_vocab(

 @app.command()
 def make_stats(
-    orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"),
    config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"),
    stats_file: str = Option(STATS_FILENAME),
    force: bool = Option(
@@ -324,13 +323,12 @@ def make_stats(
        _check_file_non_existence(stats_file_path)

    config = read_config_file(config_file) if config_file is not None else {}
-    orm_metadata = load_metadata(orm_file, config)
    settings = get_settings()
    src_dsn: str = _require_src_db_dsn(settings)

    src_stats = asyncio.get_event_loop().run_until_complete(
-        make_src_stats(src_dsn, config, orm_metadata, settings.src_schema)
+        make_src_stats(src_dsn, config, settings.src_schema)
    )
    stats_file_path.write_text(yaml.dump(src_stats), encoding="utf-8")
    logger.debug("%s created.", stats_file)
diff --git a/datafaker/make.py b/datafaker/make.py
index 096ee8b..bca7a2b 100644
--- a/datafaker/make.py
+++ b/datafaker/make.py
@@ -139,15 +139,15 @@ class StoryGeneratorInfo:

 def _render_value(v: Any) -> str:
-    if type(v) is list:
+    if isinstance(v, list):
        return "[" + ", ".join(_render_value(x) for x in v) + "]"
-    if type(v) is set:
+    if isinstance(v, set):
        return "{" + ", ".join(_render_value(x) for x in v) + "}"
-    if type(v) is dict:
+    if isinstance(v, dict):
        return (
            "{" + ", ".join(f"{repr(k)}:{_render_value(x)}" for k, x in v.items()) + "}"
        )
-    if type(v) is str:
+    if isinstance(v, str):
        return v
    return str(v)

@@ -603,8 +603,10 @@ def make_table_generators(  # pylint: disable=too-many-locals
    Args:
        metadata: database ORM
        config: Configuration to control the generator creation.
-        orm_filename: "orm.yaml" file path so that the generator file can load the MetaData object
-        config_filename: "config.yaml" file path so that the generator file can load the MetaData object
+        orm_filename: "orm.yaml" file path so that the generator
+            file can load the MetaData object
+        config_filename: "config.yaml" file path so that the generator
+            file can load the MetaData object
        src_stats_filename: A filename for where to read src stats from.
            Optional, if `None` this feature will be skipped
        overwrite_files: Whether to overwrite pre-existing vocabulary files
@@ -765,7 +767,8 @@ async def __aexit__(
        """Exit the ``with`` section, closing the connection."""
        if isinstance(self._connection, AsyncConnection):
            await self._connection.close()
-        self._connection.close()
+        else:
+            self._connection.close()

    async def execute_raw_query(self, query: Executable) -> CursorResult:
        """Execute the query on the owned connection."""
@@ -808,7 +811,7 @@ async def execute_query(self, query_block: Mapping[str, Any]) -> Any:

 def fix_type(value: Any) -> Any:
    """Make this value suitable for yaml output."""
-    if type(value) is decimal.Decimal:
+    if isinstance(value, decimal.Decimal):
        return float(value)
    return value

@@ -819,37 +822,34 @@ def fix_types(dics: list[dict]) -> list[dict]:

 async def make_src_stats(
-    dsn: str, config: Mapping, metadata: MetaData, schema_name: Optional[str] = None
+    dsn: str, config: Mapping, schema_name: Optional[str] = None
 ) -> dict[str, dict[str, Any]]:
-    """Run the src-stats queries specified by the configuration.
+    """
+    Run the src-stats queries specified by the configuration.

    Query the src database with the queries in the src-stats block of the `config`
    dictionary, using the differential privacy parameters set in the `smartnoise-sql`
    block of `config`. Record the results in a dictionary and return it.

-    Args:
-        dsn: database connection string
-        config: a dictionary with the necessary configuration
-        metadata: the database ORM
-        schema_name: name of the database schema
-    Returns:
-        The dictionary of src-stats.
+    :param dsn: database connection string
+    :param config: a dictionary with the necessary configuration
+    :param schema_name: name of the database schema
+    :return: The dictionary of src-stats.
    """
    use_asyncio = config.get("use-asyncio", False)
    engine = create_db_engine(dsn, schema_name=schema_name, use_asyncio=use_asyncio)
    async with DbConnection(engine) as db_conn:
-        return await make_src_stats_connection(config, db_conn, metadata)
+        return await make_src_stats_connection(config, db_conn)

 async def make_src_stats_connection(
-    config: Mapping, db_conn: DbConnection, metadata: MetaData
+    config: Mapping, db_conn: DbConnection
 ) -> dict[str, dict[str, Any]]:
    """
    Make the ``src-stats.yaml`` file given the database connection to read from.

    :param config: configuration from ``config.yaml``.
    :param db_conn: Source database connection.
-    :param metadata: Source database metadata from ``orm.yaml``.
    """
    date_string = datetime.today().strftime("%Y-%m-%d %H:%M:%S")
    query_blocks = config.get("src-stats", [])
diff --git a/datafaker/utils.py b/datafaker/utils.py
index 61089ab..6d0041d 100644
--- a/datafaker/utils.py
+++ b/datafaker/utils.py
@@ -6,19 +6,10 @@
 import json
 import logging
 import sys
+from collections.abc import Mapping, Sequence
 from pathlib import Path
 from types import ModuleType
-from typing import (
-    Any,
-    Callable,
-    Final,
-    Generator,
-    Iterable,
-    Mapping,
-    Optional,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, Final, Generator, Iterable, Optional, TypeVar, Union

 import sqlalchemy
 import yaml
@@ -119,6 +110,7 @@ def table_row_count(table: Table, conn: Connection) -> int:
    :return: The number of rows in the table.
    """
    return conn.execute(
+        # pylint: disable=not-callable
        select(sqlalchemy.func.count()).select_from(
            sqlalchemy.table(
                table.name,
@@ -527,7 +519,7 @@ def stream_yaml(yaml_file_handle: io.TextIOBase) -> Generator[Any, None, None]:
        if not line or line.startswith("-"):
            if buf:
                yl = yaml.load(buf, yaml.Loader)
-                assert type(yl) is list and len(yl) == 1
+                assert isinstance(yl, Sequence) and len(yl) == 1
                yield yl[0]
            if not line:
                return
@@ -602,7 +594,7 @@ def sorted_non_vocabulary_tables(metadata: MetaData, config: Mapping) -> list[Ta
        table_names, lambda tn: get_related_table_names(metadata.tables[tn])
    )
    for cycle in cycles:
-        logger.warning(f"Cycle detected between tables: {cycle}")
+        logger.warning("Cycle detected between tables: %s", cycle)
    return [metadata.tables[tn] for tn in sorted_tables]

@@ -652,7 +644,7 @@ def generators_require_stats(config: Mapping) -> bool:
                names = (
                    node.id
                    for node in ast.walk(ast.parse(arg))
-                    if type(node) is ast.Name
+                    if isinstance(node, ast.Name)
                )
                if any(name == "SRC_STATS" for name in names):
                    stats_required = True
@@ -668,12 +660,12 @@ def generators_require_stats(config: Mapping) -> bool:
                )
            )
            for k, arg in call.get("kwargs", {}).items():
-                if type(arg) is str:
+                if isinstance(arg, str):
                    try:
                        names = (
                            node.id
                            for node in ast.walk(ast.parse(arg))
-                            if type(node) is ast.Name
+                            if isinstance(node, ast.Name)
                        )
                        if any(name == "SRC_STATS" for name in names):
                            stats_required = True
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 792dc40..ac7e51a 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -81,6 +81,13 @@ def tearDown(self) -> None:
        os.chdir(self.start_dir)
        super().tearDown()

+    def assert_silent_success(self, completed_process: Result) -> None:
+        """Assert that the process completed successfully without producing output."""
+        self.assertNoException(completed_process)
+        self.assertSuccess(completed_process)
+        self.assertEqual(completed_process.stderr, "")
+        self.assertEqual(completed_process.stdout, "")
+
    def test_workflow_minimal_args(self) -> None:
        """Test the recommended CLI workflow runs without errors."""
        shutil.copy(self.config_file_path, "config.yaml")
@@ -88,26 +95,19 @@ def test_workflow_minimal_args(self) -> None:
            "make-tables",
            "--force",
        )
-        self.assertNoException(completed_process)
-        self.assertSuccess(completed_process)
-        self.assertEqual(completed_process.stderr, "")
-        self.assertEqual(completed_process.stdout, "")
+        self.assert_silent_success(completed_process)

        completed_process = self.invoke(
            "make-vocab",
            "--force",
        )
-        self.assertNoException(completed_process)
-        self.assertSuccess(completed_process)
-        self.assertEqual(completed_process.stdout, "")
+        self.assert_silent_success(completed_process)

        completed_process = self.invoke(
            "make-stats",
            "--force",
        )
-        self.assertNoException(completed_process)
-        self.assertSuccess(completed_process)
-        self.assertEqual(completed_process.stdout, "")
+        self.assert_silent_success(completed_process)

        completed_process = self.invoke(
            "create-generators",
@@ -138,27 +138,18 @@ def test_workflow_minimal_args(self) -> None:
        completed_process = self.invoke(
            "create-tables",
        )
-        self.assertNoException(completed_process)
-        self.assertEqual("", completed_process.stderr)
-        self.assertSuccess(completed_process)
-        self.assertEqual("", completed_process.stdout)
+        self.assert_silent_success(completed_process)

        completed_process = self.invoke(
            "create-vocab",
        )
-        self.assertNoException(completed_process)
-        self.assertEqual("", completed_process.stderr)
-        self.assertSuccess(completed_process)
-        self.assertEqual("", completed_process.stdout)
+        self.assert_silent_success(completed_process)

        completed_process = self.invoke(
            "make-stats",
            "--force",
        )
-        self.assertNoException(completed_process)
-        self.assertEqual("", completed_process.stderr)
-        self.assertSuccess(completed_process)
-        self.assertEqual("", completed_process.stdout)
+        self.assert_silent_success(completed_process)

        completed_process = self.invoke("create-data")
        self.assertNoException(completed_process)
@@ -514,7 +505,6 @@ def test_unique_constraint_fail(self) -> None:
            "make-stats",
            f"--stats-file={self.stats_file_path}",
            f"--config-file={self.config_file_path}",
-            f"--orm-file={self.alt_orm_file_path}",
            "--force",
        )
        self.invoke(
diff --git a/tests/test_interactive.py b/tests/test_interactive.py
index 1082c33..fcb5ced 100644
--- a/tests/test_interactive.py
+++ b/tests/test_interactive.py
@@ -4,9 +4,10 @@
 import re
 from dataclasses import dataclass
 from typing import Any, Iterable, MutableMapping
+from unittest import TestCase
 from unittest.mock import MagicMock, Mock, patch

-from sqlalchemy import insert, select
+from sqlalchemy import Connection, MetaData, insert, select

 from datafaker.generators import NullPartitionedNormalGeneratorFactory
 from datafaker.interactive import (
@@ -20,6 +21,8 @@


 class TestDbCmdMixin(DbCmd):
+    """A mixin for capturing output from interactive commands."""
+
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize a TestDbCmdMixin"""
        super().__init__(*args, **kwargs)
@@ -46,10 +49,11 @@ def print_table_by_columns(self, columns: dict[str, list[str]]) -> None:
        """Capture the printed table."""
        self.columns = columns

-    def columnize(self, list: list[str] | None, _displaywidth: int = 80) -> None:
+    # pylint: disable=arguments-renamed
+    def columnize(self, items: list[str] | None, _displaywidth: int = 80) -> None:
        """Capture the printed table."""
-        if list is not None:
-            self.column_items.append(list)
+        if items is not None:
+            self.column_items.append(items)

    def ask_save(self) -> str:
        """Quitting always works without needing to ask the user."""
@@ -666,11 +670,11 @@ def test_weighted_choice_generator_generates_choices(self) -> None:
            gc.do_propose("")
            proposals = gc.get_proposals()
            gen_proposal = proposals[generator]
-            self.assertSubset(set(gen_proposal[2]), {str(v) for v in values})
+            self.assert_subset(set(gen_proposal[2]), {str(v) for v in values})
            gc.do_compare(str(gen_proposal[0]))
            col_heading = f"{gen_proposal[0]}. {generator}"
            self.assertIn(col_heading, gc.columns)
-            self.assertSubset(set(gc.columns[col_heading]), values)
+            self.assert_subset(set(gc.columns[col_heading]), values)

    def test_merge_columns(self) -> None:
        """Test that we can merge columns and set a multivariate generator"""
@@ -822,21 +826,16 @@ def test_aggregate_queries_merge(self) -> None:
        Test that we can set a generator that requires select aggregate clauses
        and keep an old one, resulting in a merged query.
        """
-        config = {
-            "tables": {
-                "string": {
-                    "row_generators": [
-                        {
-                            "name": "dist_gen.normal",
-                            "columns_assigned": ["frequency"],
-                            "kwargs": {
-                                "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]',
-                                "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]',
-                            },
-                        }
-                    ]
-                }
+        rg = {
+            "name": "dist_gen.normal",
+            "columns_assigned": ["frequency"],
+            "kwargs": {
+                "mean": 'SRC_STATS["auto__string"]["results"][0]["mean__frequency"]',
+                "sd": 'SRC_STATS["auto__string"]["results"][0]["stddev__frequency"]',
            },
+        }
+        config = {
+            "tables": {"string": {"row_generators": [rg]}},
            "src-stats": [
                {
                    "name": "auto__string",
@@ -1096,7 +1095,6 @@ def test_create_with_choice(self) -> None:

    def test_create_with_weighted_choice(self) -> None:
        """Smoke test weighted choice."""
-        table_name = "number_table"
        with self._get_cmd({}) as gc:
            gc.do_next("number_table.one")
            gc.reset()
@@ -1108,7 +1106,7 @@ def test_create_with_weighted_choice(self) -> None:
                "dist_gen.weighted_choice [sampled and suppressed]", proposals
            )
            prop = proposals["dist_gen.weighted_choice [sampled and suppressed]"]
-            self.assertSubset(set(prop[2]), {"1", "4"})
+            self.assert_subset(set(prop[2]), {"1", "4"})
            gc.reset()
            gc.do_compare(str(prop[0]))
            col_heading = (
@@ -1116,7 +1114,7 @@ def test_create_with_weighted_choice(self) -> None:
            )
            self.assertIn(col_heading, set(gc.columns.keys()))
            col_set: set[int] = set(gc.columns[col_heading])
-            self.assertSubset(col_set, {1, 4})
+            self.assert_subset(col_set, {1, 4})
            gc.do_set(str(prop[0]))
            gc.do_next("number_table.two")
            gc.reset()
@@ -1128,13 +1126,13 @@ def test_create_with_weighted_choice(self) -> None:
                "dist_gen.weighted_choice [sampled and suppressed]", proposals
            )
            prop = proposals["dist_gen.weighted_choice"]
-            self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"})
+            self.assert_subset(set(prop[2]), {"1", "2", "3", "4", "5"})
            gc.reset()
            gc.do_compare(str(prop[0]))
            col_heading = f"{prop[0]}. dist_gen.weighted_choice"
            self.assertIn(col_heading, set(gc.columns.keys()))
            col_set2: set[int] = set(gc.columns[col_heading])
-            self.assertSubset(col_set2, {1, 2, 3, 4, 5})
+            self.assert_subset(col_set2, {1, 2, 3, 4, 5})
            gc.do_set(str(prop[0]))
            gc.do_next("number_table.three")
            gc.reset()
@@ -1146,22 +1144,22 @@ def test_create_with_weighted_choice(self) -> None:
                "dist_gen.weighted_choice [sampled and suppressed]", proposals
            )
            prop = proposals["dist_gen.weighted_choice [sampled]"]
-            self.assertSubset(set(prop[2]), {"1", "2", "3", "4", "5"})
+            self.assert_subset(set(prop[2]), {"1", "2", "3", "4", "5"})
            gc.do_compare(str(prop[0]))
            col_heading = f"{prop[0]}. dist_gen.weighted_choice [sampled]"
            self.assertIn(col_heading, set(gc.columns.keys()))
            col_set3: set[int] = set(gc.columns[col_heading])
-            self.assertSubset(col_set3, {1, 2, 3, 4, 5})
+            self.assert_subset(col_set3, {1, 2, 3, 4, 5})
            gc.do_set(str(prop[0]))
            gc.do_quit("")
        self.generate_data(gc.config, num_passes=200)
        with self.sync_engine.connect() as conn:
-            stmt = select(self.metadata.tables[table_name])
-            rows = conn.execute(stmt).fetchall()
            ones = set()
            twos = set()
            threes = set()
-            for row in rows:
+            for row in conn.execute(
+                select(self.metadata.tables["number_table"])
+            ).fetchall():
                ones.add(row.one)
                twos.add(row.two)
                threes.add(row.three)
@@ -1447,6 +1445,60 @@ def covar(self) -> float:
        return (self.xy - self.x * self.y / self.n) / (self.n - 1)


+class EavMeasurementTableStats:
+    """The statistics for the Measurement table of eav.sql."""
+
+    def __init__(self, conn: Connection, metadata: MetaData, test: TestCase) -> None:
+        stmt = select(metadata.tables["measurement"])
+        rows = conn.execute(stmt).fetchall()
+        self.types: set[int] = set()
+        self.one_count = 0
+        self.one_yes_count = 0
+        self.two = Correlation()
+        self.three = Correlation()
+        self.four = Correlation()
+        self.fish = Stat()
+        self.fowl = Stat()
+        for row in rows:
+            self.types.add(row.type)
+            if row.type == 1:
+                # yes or no
+                test.assertIsNone(row.first_value)
+                test.assertIsNone(row.second_value)
+                test.assertIn(row.third_value, {"yes", "no"})
+                self.one_count += 1
+                if row.third_value == "yes":
+                    self.one_yes_count += 1
+            elif row.type == 2:
+                # positive correlation around 1.4, 1.8
+                test.assertIsNotNone(row.first_value)
+                test.assertIsNotNone(row.second_value)
+                test.assertIsNone(row.third_value)
+                self.two.add2(row.first_value, row.second_value)
+            elif row.type == 3:
+                # negative correlation around 11.8, 12.1
+                test.assertIsNotNone(row.first_value)
+                test.assertIsNotNone(row.second_value)
+                test.assertIsNone(row.third_value)
+                self.three.add2(row.first_value, row.second_value)
+            elif row.type == 4:
+                # positive correlation around 21.4, 23.4
+                test.assertIsNotNone(row.first_value)
+                test.assertIsNotNone(row.second_value)
+                test.assertIsNone(row.third_value)
+                self.four.add2(row.first_value, row.second_value)
+            elif row.type == 5:
+                test.assertIn(row.third_value, {"fish", "fowl"})
+                test.assertIsNotNone(row.first_value)
+                test.assertIsNone(row.second_value)
+                if row.third_value == "fish":
+                    # mean 8.1 and sd 0.755
+                    self.fish.add(row.first_value)
+                else:
+                    # mean 11.2 and sd 1.114
+                    self.fowl.add(row.first_value)
+
+
 class NullPartitionedTests(GeneratesDBTestCase):
    """Testing null-partitioned grouped multivariate generation."""

@@ -1466,7 +1518,6 @@ def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd:

    def test_create_with_null_partitioned_grouped_multivariate(self) -> None:
        """Test EAV for all columns."""
-        table_name = "measurement"
        generate_count = 800
        with self._get_cmd({}) as gc:
            gc.do_next("measurement.type")
@@ -1502,96 +1553,49 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None:
            conn.commit()
        self.create_data(gc.config, num_passes=generate_count)
        with self.sync_engine.connect() as conn:
-            stmt = select(self.metadata.tables[table_name])
-            rows = conn.execute(stmt).fetchall()
-            one_count = 0
-            one_yes_count = 0
-            two = Correlation()
-            three = Correlation()
-            four = Correlation()
-            fish = Stat()
-            fowl = Stat()
-            for row in rows:
-                if row.type == 1:
-                    # yes or no
-                    self.assertIsNone(row.first_value)
-                    self.assertIsNone(row.second_value)
-                    self.assertIn(row.third_value, {"yes", "no"})
-                    one_count += 1
-                    if row.third_value == "yes":
-                        one_yes_count += 1
-                elif row.type == 2:
-                    # positive correlation around 1.4, 1.8
-                    self.assertIsNotNone(row.first_value)
-                    self.assertIsNotNone(row.second_value)
-                    self.assertIsNone(row.third_value)
-                    two.add2(row.first_value, row.second_value)
-                elif row.type == 3:
-                    # negative correlation around 11.8, 12.1
-                    self.assertIsNotNone(row.first_value)
-                    self.assertIsNotNone(row.second_value)
-                    self.assertIsNone(row.third_value)
-                    three.add2(row.first_value, row.second_value)
-                elif row.type == 4:
-                    # positive correlation around 21.4, 23.4
-                    self.assertIsNotNone(row.first_value)
-                    self.assertIsNotNone(row.second_value)
-                    self.assertIsNone(row.third_value)
-                    four.add2(row.first_value, row.second_value)
-                elif row.type == 5:
-                    self.assertIn(row.third_value, {"fish", "fowl"})
-                    self.assertIsNotNone(row.first_value)
-                    self.assertIsNone(row.second_value)
-                    if row.third_value == "fish":
-                        # mean 8.1 and sd 0.755
-                        fish.add(row.first_value)
-                    else:
-                        # mean 11.2 and sd 1.114
-                        fowl.add(row.first_value)
-            # type 1
-            self.assertAlmostEqual(
-                one_count, generate_count * 5 / 20, delta=generate_count * 0.4
-            )
-            # about 40% are yes
-            self.assertAlmostEqual(
-                one_yes_count / one_count, 0.4, delta=generate_count * 0.4
-            )
-            # type 2
-            self.assertAlmostEqual(
-                two.count(), generate_count * 3 / 20, delta=generate_count * 0.5
-            )
-            self.assertAlmostEqual(two.x_mean(), 1.4, delta=0.6)
-            self.assertAlmostEqual(two.x_var(), 0.315, delta=0.18)
-            self.assertAlmostEqual(two.y_mean(), 1.8, delta=0.8)
-            self.assertAlmostEqual(two.y_var(), 0.105, delta=0.06)
-            self.assertAlmostEqual(two.covar(), 0.105, delta=0.07)
-            # type 3
-            self.assertAlmostEqual(
-                three.count(), generate_count * 3 / 20, delta=generate_count * 0.2
-            )
-            self.assertAlmostEqual(three.covar(), -2.085, delta=1.1)
-            # type 4
-            self.assertAlmostEqual(
-                four.count(), generate_count * 3 / 20, delta=generate_count * 0.2
-            )
-            self.assertAlmostEqual(four.covar(), 3.33, delta=1)
-            # type 5/fish
-            self.assertAlmostEqual(
-                fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2
-            )
-            self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0)
-            self.assertAlmostEqual(fish.x_var(), 0.855, delta=0.6)
-            # type 5/fowl
-            self.assertAlmostEqual(
-                fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2
-            )
-            self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0)
-            self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1)
+            stats = EavMeasurementTableStats(conn, self.metadata, self)
+        # type 1
+        self.assertAlmostEqual(
+            stats.one_count, generate_count * 5 / 20, delta=generate_count * 0.4
+        )
+        # about 40% are yes
+        self.assertAlmostEqual(
+            stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4
+        )
+        # type 2
+        self.assertAlmostEqual(
+            stats.two.count(), generate_count * 3 / 20, delta=generate_count * 0.5
+        )
+        self.assertAlmostEqual(stats.two.x_mean(), 1.4, delta=0.6)
+        self.assertAlmostEqual(stats.two.x_var(), 0.315, delta=0.18)
+        self.assertAlmostEqual(stats.two.y_mean(), 1.8, delta=0.8)
+        self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.06)
+        self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.07)
+        # type 3
+        self.assertAlmostEqual(
+            stats.three.count(), generate_count * 3 / 20, delta=generate_count * 0.2
+        )
+        self.assertAlmostEqual(stats.three.covar(), -2.085, delta=1.1)
+        # type 4
+        self.assertAlmostEqual(
+            stats.four.count(), generate_count * 3 / 20, delta=generate_count * 0.2
+        )
+        self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1)
+        # type 5/fish
+        self.assertAlmostEqual(
+            stats.fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2
+        )
+        self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0)
+        self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.6)
+        # type 5/fowl
+        self.assertAlmostEqual(
+            stats.fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2
+        )
+        self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0)
+        self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1)

    def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> None:
        """Test EAV for all columns with sampled and suppressed generation."""
-        table_name = "measurement"
-        table2_name = "observation"
        generate_count = 800
        with self._get_cmd({}) as gc:
            gc.do_next("measurement.type")
@@ -1646,61 +1650,12 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> No
            conn.commit()
        self.create_data(gc.config, num_passes=generate_count)
        with self.sync_engine.connect() as conn:
-            stmt = select(self.metadata.tables[table_name])
-            rows = conn.execute(stmt).fetchall()
-            one_count = 0
-            one_yes_count = 0
-            fish = Stat()
-            fowl = Stat()
-            types: set[int] = set()
-            for row in rows:
-                types.add(row.type)
-                if row.type == 1:
-                    # yes or no
-                    self.assertIsNone(row.first_value)
-                    self.assertIsNone(row.second_value)
-                    self.assertIn(row.third_value, {"yes", "no"})
-                    if row.third_value == "yes":
-                        one_yes_count += 1
-                    one_count += 1
-                elif row.type == 5:
-                    self.assertIn(row.third_value, {"fish", "fowl"})
-                    self.assertIsNotNone(row.first_value)
-                    self.assertIsNone(row.second_value)
-                    if row.third_value == "fish":
-                        # mean 8.1 and sd 0.755
-                        fish.add(row.first_value)
-                    else:
-                        # mean 11.2 and sd 1.114
-                        fowl.add(row.first_value)
-            self.assertSubset(types, {1, 2, 3, 4, 5})
-            self.assertEqual(len(types), 4)
-            self.assertSubset({1, 5}, types)
-            # type 1
-            self.assertAlmostEqual(
-                one_count, generate_count * 5 / 11, delta=generate_count * 0.4
-            )
-            # about 40% are yes
-            self.assertAlmostEqual(
-                one_yes_count / one_count, 0.4, delta=generate_count * 0.4
-            )
-            # type 5/fish
-            self.assertAlmostEqual(
-                fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2
-            )
-            self.assertAlmostEqual(fish.x_mean(), 8.1, delta=3.0)
-            self.assertAlmostEqual(fish.x_var(), 0.855, delta=0.5)
-            # type 5/fowl
-            self.assertAlmostEqual(
-                fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2
-            )
-            self.assertAlmostEqual(fowl.x_mean(), 11.2, delta=8.0)
-            self.assertAlmostEqual(fowl.x_var(), 1.86, delta=1)
-            stmt = select(self.metadata.tables[table2_name])
+            stats = EavMeasurementTableStats(conn, self.metadata, self)
+            stmt = select(self.metadata.tables["observation"])
            rows = conn.execute(stmt).fetchall()
            firsts = Stat()
            for row in rows:
-                types.add(row.type)
+                stats.types.add(row.type)
                self.assertEqual(row.type, 1)
                self.assertIsNotNone(row.first_value)
                self.assertIsNone(row.second_value)
@@ -1708,6 +1663,29 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> No
                firsts.add(row.first_value)
            self.assertEqual(firsts.count(), 800)
            self.assertAlmostEqual(firsts.x_mean(), 1.3, delta=generate_count * 0.3)
+        self.assert_subset(stats.types, {1, 2, 3, 4, 5})
+        self.assertEqual(len(stats.types), 4)
+        self.assert_subset({1, 5}, stats.types)
+        # type 1
+        self.assertAlmostEqual(
+            stats.one_count, generate_count * 5 / 11, delta=generate_count * 0.4
+        )
+        # about 40% are yes
+        self.assertAlmostEqual(
+            stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4
+        )
+        # type 5/fish
+        self.assertAlmostEqual(
+            stats.fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2
+        )
+        self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0)
+        self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.5)
+        # type 5/fowl
+        self.assertAlmostEqual(
+            stats.fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2
+        )
+        self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0)
+        self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1)


 class NonInteractiveTests(RequiresDBTestCase):
diff --git a/tests/test_main.py b/tests/test_main.py
index a570fe4..e318f1e 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -48,6 +48,7 @@ def test_create_vocab(
    @patch("datafaker.main.Path")
    @patch("datafaker.main.make_table_generators")
    @patch("datafaker.main.generators_require_stats")
+    # pylint: disable=too-many-positional-arguments,too-many-arguments
    def test_create_generators(
        self,
        mock_require_stats: MagicMock,
@@ -89,6 +90,7 @@ def test_create_generators(
    @patch("datafaker.main.Path")
    @patch("datafaker.main.make_table_generators")
    @patch("datafaker.main.generators_require_stats")
+    # pylint: disable=too-many-positional-arguments,too-many-arguments
    def test_create_generators_uses_default_stats_file_if_necessary(
        self,
        mock_require_stats: MagicMock,
@@ -151,6 +153,7 @@ def test_create_generators_errors_if_file_exists(
    @patch("datafaker.main.get_settings")
    @patch("datafaker.main.Path")
    @patch("datafaker.main.make_table_generators")
+    # pylint: disable=too-many-positional-arguments,too-many-arguments
    def test_create_generators_with_force_enabled(
        self,
        mock_make: MagicMock,
@@ -371,10 +374,8 @@ def test_make_tables_with_force_enabled(
    @patch("datafaker.main.Path")
    @patch("datafaker.main.make_src_stats")
    @patch("datafaker.main.get_settings")
-    @patch("datafaker.main.load_metadata", side_effect=["ms"])
    def test_make_stats(
        self,
-        _lm: MagicMock,
        mock_get_settings: MagicMock,
        mock_make: MagicMock,
        mock_path: MagicMock,
@@ -398,7 +399,7 @@ def test_make_stats(
        with open(example_conf_path, "r", encoding="utf8") as f:
            config = yaml.safe_load(f)
        mock_make.assert_called_once_with(
-            get_test_settings().src_dsn, config, "ms", None
+            get_test_settings().src_dsn, config, None
        )
        mock_path.return_value.write_text.assert_called_once_with(
            "a: 1\n", encoding="utf-8"
@@ -452,12 +453,10 @@ def test_make_stats_errors_if_no_src_dsn(self, mock_logger: MagicMock) -> None:
    @patch("datafaker.main.Path")
    @patch("datafaker.main.make_src_stats")
    @patch("datafaker.main.get_settings")
-    @patch("datafaker.main.load_metadata")
    def test_make_stats_with_force_enabled(
        self,
-        mock_meta: MagicMock,
        mock_get_settings: MagicMock,
-        mock_make: MagicMock,
+        mock_make_src_stats: MagicMock,
        mock_path: MagicMock,
    ) -> None:
        """Tests that the make-stats command overwrite files when instructed."""
@@ -469,7 +468,7 @@ def test_make_stats_with_force_enabled(
        test_settings: Settings = get_test_settings()
        mock_get_settings.return_value = test_settings
        make_test_output: dict = {"some_stat": 0}
-        mock_make.return_value = make_test_output
+        mock_make_src_stats.return_value = make_test_output

        for force_option in ["--force", "-f"]:
            with self.subTest(f"Using option {force_option}"):
@@ -479,23 +478,21 @@ def test_make_stats_with_force_enabled(
                        "make-stats",
                        "--stats-file=stats_file.yaml",
                        f"--config-file={test_config_file}",
-                        "--orm-file=tests/examples/example_config.yaml",
                        force_option,
                    ],
                )

-                mock_make.assert_called_once_with(
+                mock_make_src_stats.assert_called_once_with(
                    test_settings.src_dsn,
                    config_file_content,
-                    mock_meta.return_value,
-                    None,
+                    test_settings.src_schema,
                )
                mock_path.return_value.write_text.assert_called_once_with(
                    "some_stat: 0\n", encoding="utf-8"
                )
                self.assertSuccess(result)

-                mock_make.reset_mock()
+                mock_make_src_stats.reset_mock()
                mock_path.reset_mock()

    def test_validate_config(self) -> None:
diff --git a/tests/test_make.py b/tests/test_make.py
index b522778..49bb9e7 100644
--- a/tests/test_make.py
+++ b/tests/test_make.py
@@ -170,14 +170,14 @@ def check_make_stats_output(self, src_stats: dict) -> None:
    def test_make_stats_no_asyncio_schema(self) -> None:
        """Test that make_src_stats works when explicitly naming a schema."""
        src_stats = asyncio.get_event_loop().run_until_complete(
-            make_src_stats(self.dsn, self.config, self.metadata, self.schema_name)
+            make_src_stats(self.dsn, self.config, self.schema_name)
        )
        self.check_make_stats_output(src_stats)

    def test_make_stats_no_asyncio(self) -> None:
        """Test that make_src_stats works using the example configuration."""
        src_stats = asyncio.get_event_loop().run_until_complete(
-            make_src_stats(self.dsn, self.config, self.metadata, self.schema_name)
+            make_src_stats(self.dsn, self.config, self.schema_name)
        )
        self.check_make_stats_output(src_stats)

@@ -189,7 +189,7 @@ def test_make_stats_asyncio(self) -> None:
        asyncio.set_event_loop(loop)
        config_asyncio = {**self.config, "use-asyncio": True}
        src_stats = asyncio.get_event_loop().run_until_complete(
-            make_src_stats(self.dsn, config_asyncio, self.metadata, self.schema_name)
+            make_src_stats(self.dsn, config_asyncio, self.schema_name)
        )
        self.check_make_stats_output(src_stats)

@@ -220,7 +220,7 @@ def test_make_stats_empty_result(self, mock_logger: MagicMock) -> None:
            ]
        }
        src_stats = asyncio.get_event_loop().run_until_complete(
-            make_src_stats(self.dsn, config, self.metadata, self.schema_name)
+            make_src_stats(self.dsn, config, self.schema_name)
        )
        self.assertEqual(src_stats[query_name1]["results"], [])
        self.assertEqual(src_stats[query_name2]["results"], [])
diff --git a/tests/test_providers.py b/tests/test_providers.py
index b543783..cd88007 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -1,9 +1,8 @@
 """Tests for the providers module."""
 import datetime as dt
-from pathlib import Path
 from typing import Any

-from sqlalchemy import Column, Integer, Text, create_engine, insert
+from sqlalchemy import Column, Integer, Text, insert
 from sqlalchemy.ext.declarative import declarative_base

 from datafaker import providers
diff --git a/tests/test_remove.py b/tests/test_remove.py
index a6dbb85..0d466db 100644
--- a/tests/test_remove.py
+++ b/tests/test_remove.py
@@ -20,6 +20,7 @@ class RemoveThingsTestCase(RequiresDBTestCase):
    def count_rows(self, connection: Connection, table_name: str) -> int | None:
        """Count the rows in a table."""
        return connection.execute(
+            # pylint: disable=not-callable
            select(func.count()).select_from(self.metadata.tables[table_name])
        ).scalar()

@@ -40,8 +41,8 @@ def test_remove_data(self, mock_get_settings: MagicMock) -> None:
            },
        )
        with self.sync_engine.connect() as conn:
-            self.assertGreaterAndNotNone(self.count_rows(conn, "manufacturer"), 0)
-            self.assertGreaterAndNotNone(self.count_rows(conn, "model"), 0)
+            self.assert_greater_and_not_none(self.count_rows(conn, "manufacturer"), 0)
+            self.assert_greater_and_not_none(self.count_rows(conn, "model"), 0)
            self.assertEqual(self.count_rows(conn, "player"), 0)
            self.assertEqual(self.count_rows(conn, "string"), 0)
            self.assertEqual(self.count_rows(conn, "signature_model"), 0)
diff --git a/tests/test_rst.py b/tests/test_rst.py
index 29bf971..1a57ed6 100644
--- a/tests/test_rst.py
+++ b/tests/test_rst.py
@@ -55,7 +55,7 @@ def test_dir(self) -> None:
            for file_error in file_errors
            # Only worry about ERRORs and WARNINGs
            if file_error.level <= 2
-            if not any(filter(lambda m: m in file_error.full_message, allowed_errors))
+            if not any(m in file_error.full_message for m in allowed_errors)
        ]

        if filtered_errors:
diff --git a/tests/test_unique_generator.py b/tests/test_unique_generator.py
index 503a36f..afec078 100644
--- a/tests/test_unique_generator.py
+++ b/tests/test_unique_generator.py
@@ -1,5 +1,4 @@
 """Tests for the unique_generator module."""
-from pathlib import Path
 from unittest.mock import MagicMock

 from sqlalchemy import Boolean, Column, Integer, Text, UniqueConstraint, insert
diff --git a/tests/utils.py b/tests/utils.py
index 4e9f236..78df87c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -12,7 +12,6 @@
 import testing.postgresql
 import yaml
-from sqlalchemy import MetaData
 from sqlalchemy.schema import MetaData

 from datafaker import settings
@@ -80,7 +79,7 @@ def assertNoException(self, result: Any) -> None:  # pylint: disable=invalid-nam
            return
        self.fail("".join(traceback.format_exception(result.exception)))

-    def assertGreaterAndNotNone(self, left: float | None, right: float) -> None:
+    def assert_greater_and_not_none(self, left: float | None, right: float) -> None:
        """
        Assert left is not None and greater than right
        """
@@ -89,7 +88,7 @@ def assertGreaterAndNotNone(self, left: float | None, right: float) -> None:
        else:
            self.assertGreater(left, right)

-    def assertSubset(self, set1: set[T], set2: set[T], msg: str | None = None) -> None:
+    def assert_subset(self, set1: set[T], set2: set[T], msg: str | None = None) -> None:
        """Assert a set is a (non-strict) subset.

        :param set1: The asserted subset.
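(A usage sketch of the renamed subset assertion, illustrative calls only, not part of this patch:

    self.assert_subset({1, 4}, {1, 2, 3, 4, 5})        # passes: {1, 4} is contained in the second set
    self.assert_subset({0, 1}, {1, 2}, msg="bad ids")  # fails, listing the offending element 0
)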
@@ -100,9 +99,9 @@ def assertSubset(self, set1: set[T], set2: set[T], msg: str | None = None) -> No
         try:
             difference = set1.difference(set2)
         except TypeError as e:
-            self.fail("invalid type when attempting set difference: %s" % e)
+            self.fail(f"invalid type when attempting set difference: {e}")
         except AttributeError as e:
-            self.fail("first argument does not support set difference: %s" % e)
+            self.fail(f"first argument does not support set difference: {e}")
 
         if not difference:
             return
@@ -113,8 +112,8 @@ def assertSubset(self, set1: set[T], set2: set[T], msg: str | None = None) -> No
         for item in difference:
             lines.append(repr(item))
 
-        standardMsg = "\n".join(lines)
-        self.fail(self._formatMessage(msg, standardMsg))
+        standard_msg = "\n".join(lines)
+        self.fail(self._formatMessage(msg, standard_msg))
 
 
 @skipUnless(shutil.which("psql"), "need to find 'psql': install PostgreSQL to enable")
@@ -148,7 +147,7 @@ def tearDownClass(cls) -> None:
     def setUp(self) -> None:
         super().setUp()
         assert self.Postgresql is not None
-        self.postgresql = self.Postgresql()
+        self.postgresql = self.Postgresql()  # pylint: disable=not-callable
         if self.dump_file_path is not None:
             self.run_psql(Path(self.examples_dir) / Path(self.dump_file_path))
         self.engine = create_db_engine(
@@ -166,11 +165,12 @@ def tearDown(self) -> None:
 
     @property
     def dsn(self) -> str:
+        """Get the database connection string."""
         if self.database_name:
             url = self.postgresql.url(database=self.database_name)
         else:
             url = self.postgresql.url()
-        assert type(url) is str
+        assert isinstance(url, str)
         return url
 
     def run_psql(self, dump_file: Path) -> None:
@@ -199,7 +199,19 @@ def run_psql(self, dump_file: Path) -> None:
 
 
 class GeneratesDBTestCase(RequiresDBTestCase):
+    """A test case for which a database is generated."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        """Initialise a GeneratesDB test case."""
+        super().__init__(*args, **kwargs)
+        self.generators_file_path = ""
+        self.stats_fd = 0
+        self.stats_file_path = ""
+        self.config_file_path = ""
+        self.config_fd = 0
+
     def setUp(self) -> None:
+        """Set up the test case with an actual orm.yaml file."""
         super().setUp()
         # Generate the `orm.yaml` from the database
         (self.orm_fd, self.orm_file_path) = mkstemp(".yaml", "orm_", text=True)
@@ -207,21 +219,20 @@ def setUp(self) -> None:
             orm_fh.write(make_tables_file(self.dsn, self.schema_name, {}))
 
     def set_configuration(self, config: Mapping[str, Any]) -> None:
-        """
-        Accepts a configuration file, writes it out.
-        """
+        """Accept configuration data and write it out as a YAML file."""
        (self.config_fd, self.config_file_path) = mkstemp(".yaml", "config_", text=True)
         with os.fdopen(self.config_fd, "w", encoding="utf-8") as config_fh:
             config_fh.write(yaml.dump(config))
 
     def get_src_stats(self, config: Mapping[str, Any]) -> dict[str, Any]:
         """
-        Runs `make-stats` producing `src-stats.yaml`
+        Runs `make-stats` producing `src-stats.yaml`.
+
         :return: Python dictionary representation of the contents of the src-stats file
         """
         loop = asyncio.new_event_loop()
         src_stats = loop.run_until_complete(
-            make_src_stats(self.dsn, config, self.metadata, self.schema_name)
+            make_src_stats(self.dsn, config, self.schema_name)
         )
         loop.close()
         (self.stats_fd, self.stats_file_path) = mkstemp(
From 2894044f660f2cdc01ce9a735b47ff20b4ad6f80 Mon Sep 17 00:00:00 2001
From: Tim Band
Date: Mon, 13 Oct 2025 18:56:46 +0100
Subject: [PATCH 20/35] Precommit clean!
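
This commit splits datafaker/generators.py and datafaker/interactive.py into
packages (see the file list below). As an illustrative aside, not taken from
the patch itself: a split like this usually keeps existing imports working by
re-exporting the public names from the new package's __init__.py. A minimal
sketch, assuming each class lands in the submodule its name suggests (the
actual symbol placement here is an assumption):

    # datafaker/generators/__init__.py -- hypothetical re-export shim
    """Generator factories for making generators for single columns."""
    from datafaker.generators.base import Generator, GeneratorFactory
    from datafaker.generators.choice import ZipfChoiceGenerator
    from datafaker.generators.continuous import GaussianGenerator
    from datafaker.generators.mimesis import MimesisGenerator
    from datafaker.generators.partitioned import NullPartitionedNormalGenerator

    __all__ = [
        "Generator",
        "GeneratorFactory",
        "ZipfChoiceGenerator",
        "GaussianGenerator",
        "MimesisGenerator",
        "NullPartitionedNormalGenerator",
    ]

With a shim like that in place, call sites can keep writing
"from datafaker.generators import Generator" unchanged while the
implementation moves into the new submodules.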
---
 .pre-commit-config.yaml | 1 +
 datafaker/base.py | 53 +-
 datafaker/create.py | 19 +-
 datafaker/generators.py | 2160 ----------------
 datafaker/generators/__init__.py | 53 +
 datafaker/generators/base.py | 417 ++++
 datafaker/generators/choice.py | 398 +++
 datafaker/generators/continuous.py | 471 ++++
 datafaker/generators/mimesis.py | 418 ++++
 datafaker/generators/partitioned.py | 514 ++++
 datafaker/interactive.py | 2175 -----------------
 datafaker/interactive/__init__.py | 95 +
 datafaker/interactive/base.py | 404 +++
 datafaker/interactive/generators.py | 980 ++++++++
 datafaker/interactive/missingness.py | 355 +++
 datafaker/interactive/table.py | 376 +++
 datafaker/main.py | 5 +-
 datafaker/utils.py | 11 +-
 ...tive.py => test_interactive_generators.py} | 539 +---
 tests/test_interactive_missingness.py | 100 +
 tests/test_interactive_table.py | 398 +++
 tests/test_main.py | 4 +-
 tests/utils.py | 49 +-
 23 files changed, 5090 insertions(+), 4905 deletions(-)
 delete mode 100644 datafaker/generators.py
 create mode 100644 datafaker/generators/__init__.py
 create mode 100644 datafaker/generators/base.py
 create mode 100644 datafaker/generators/choice.py
 create mode 100644 datafaker/generators/continuous.py
 create mode 100644 datafaker/generators/mimesis.py
 create mode 100644 datafaker/generators/partitioned.py
 delete mode 100644 datafaker/interactive.py
 create mode 100644 datafaker/interactive/__init__.py
 create mode 100644 datafaker/interactive/base.py
 create mode 100644 datafaker/interactive/generators.py
 create mode 100644 datafaker/interactive/missingness.py
 create mode 100644 datafaker/interactive/table.py
 rename tests/{test_interactive.py => test_interactive_generators.py} (70%)
 create mode 100644 tests/test_interactive_missingness.py
 create mode 100644 tests/test_interactive_table.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7eba811..04464f9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -6,6 +6,7 @@ repos:
     rev: v4.2.0
     hooks:
       - id: trailing-whitespace
+        exclude: docs/(source|build/html)/_static/
      - id: end-of-file-fixer
         exclude: docs/source/_static/
       - id: check-yaml
diff --git a/datafaker/base.py b/datafaker/base.py
index 6ff1890..f75591c 100644
--- a/datafaker/base.py
+++ b/datafaker/base.py
@@ -26,6 +26,10 @@
 )
 
 
+class InappropriateGeneratorException(Exception):
+    """Exception thrown if a generator is requested that is not appropriate."""
+
+
 @functools.cache
 def zipf_weights(size: int) -> list[float]:
     """Get the weights of a Zipf distribution of a given size."""
@@ -122,19 +126,26 @@ def lognormal(self, logmean: float, logsd: float) -> float:
         """
         return random.lognormvariate(float(logmean), float(logsd))
 
-    def choice(self, a: list[T]) -> T:
+    def choice_direct(self, a: list[T]) -> T:
+        """
+        Choose a value with equal probability.
+
+        :param a: The list of values to output.
+        :return: The chosen value.
+        """
+        return random.choice(a)
+
+    def choice(self, a: list[Mapping[str, T]]) -> T | None:
         """
         Choose a value with equal probability.
 
-        :param a: The list of values to output. Each element is either
-        the value itself, or a mapping with a key ``value`` and the key
-        is the value to return.
+        :param a: The list of values to output. Each element is a mapping with
+        a key ``value`` whose associated value is returned.
         :return: The chosen value.
         """
-        c = random.choice(a)
-        return c["value"] if isinstance(c, Mapping) and "value" in c else c
+        return self.choice_direct(a).get("value", None)
 
-    def zipf_choice(self, a: list[T], n: int | None = None) -> T:
+    def zipf_choice_direct(self, a: list[T], n: int | None = None) -> T:
         """
         Choose a value according to the Zipf distribution.
 
@@ -142,14 +153,26 @@ def zipf_choice(self, a: list[T], n: int | None = None) -> T:
         1/n times as frequently as the first value is chosen.
 
         :param a: The list of values to output, most frequent first.
-        Each element is either the value itself, or a mapping with
-        a key ``value`` and the key is the value to return.
         :return: The chosen value.
         """
         if n is None:
             n = len(a)
-        c = random.choices(a, weights=zipf_weights(n))[0]
-        return c["value"] if isinstance(c, Mapping) and "value" in c else c
+        return random.choices(a, weights=zipf_weights(n))[0]
+
+    def zipf_choice(self, a: list[Mapping[str, T]], n: int | None = None) -> T | None:
+        """
+        Choose a value according to the Zipf distribution.
+
+        The nth value (starting from 1) is chosen with a frequency
+        1/n times as frequently as the first value is chosen.
+
+        :param a: The list of rows to choose between, most frequent first.
+        Each element is a mapping with a key ``value`` whose associated
+        value is returned.
+        :return: The chosen value.
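+
+        For example, ``zipf_choice([{"value": "a"}, {"value": "b"}])``
+        returns ``"a"`` about twice as often as ``"b"``.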
""" - c = random.choice(a) - return c["value"] if isinstance(c, Mapping) and "value" in c else c + return self.choice_direct(a).get("value", None) - def zipf_choice(self, a: list[T], n: int | None = None) -> T: + def zipf_choice_direct(self, a: list[T], n: int | None = None) -> T: """ Choose a value according to the Zipf distribution. @@ -142,14 +153,26 @@ def zipf_choice(self, a: list[T], n: int | None = None) -> T: 1/n times as frequently as the first value is chosen. :param a: The list of values to output, most frequent first. - Each element is either the value itself, or a mapping with - a key ``value`` and the key is the value to return. :return: The chosen value. """ if n is None: n = len(a) - c = random.choices(a, weights=zipf_weights(n))[0] - return c["value"] if isinstance(c, Mapping) and "value" in c else c + return random.choices(a, weights=zipf_weights(n))[0] + + def zipf_choice(self, a: list[Mapping[str, T]], n: int | None = None) -> T | None: + """ + Choose a value according to the Zipf distribution. + + The nth value (starting from 1) is chosen with a frequency + 1/n times as frequently as the first value is chosen. + + :param a: The list of rows to choose between, most frequent first. + Each element is a mapping with a key ``value`` and the key is the + value to return. + :return: The chosen value. + """ + c = self.zipf_choice_direct(a, n) + return c.get("value", None) def weighted_choice(self, a: list[dict[str, Any]]) -> Any: """ @@ -214,7 +237,9 @@ def _select_group(self, alts: list[dict[str, Any]]) -> Any: choice -= alt["count"] if choice < 0: return alt - raise Exception("Internal error: ran out of choices in _select_group") + raise NothingToGenerateException( + "Internal error: ran out of choices in _select_group" + ) def _find_constants(self, result: dict[str, Any]) -> dict[int, Any]: """ @@ -286,7 +311,9 @@ def grouped_multivariate_lognormal(self, covs: list[dict[str, Any]]) -> list[Any def _check_generator_name(self, name: str) -> None: if name not in self.PERMITTED_SUBGENS: - raise Exception("%s is not a permitted generator", name) + raise InappropriateGeneratorException( + f"{name} is not a permitted generator" + ) def alternatives( self, diff --git a/datafaker/create.py b/datafaker/create.py index ce2a74e..a877320 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -1,6 +1,7 @@ """Functions and classes to create and populate the target database.""" import pathlib from collections import Counter +from types import ModuleType from typing import Any, Generator, Iterable, Iterator, Mapping, Sequence, Tuple from sqlalchemy import Connection, insert, inspect @@ -97,8 +98,7 @@ def create_db_vocab( def create_db_data( sorted_tables: Sequence[Table], - table_generator_dict: Mapping[str, TableGenerator], - story_generator_list: Sequence[Mapping[str, Any]], + df_module: ModuleType, num_passes: int, ) -> RowCounts: """Connect to a database and populate it with data.""" @@ -108,8 +108,7 @@ def create_db_data( return create_db_data_into( sorted_tables, - table_generator_dict, - story_generator_list, + df_module, num_passes, dst_dsn, settings.dst_schema, @@ -118,8 +117,7 @@ def create_db_data( def create_db_data_into( sorted_tables: Sequence[Table], - table_generator_dict: Mapping[str, TableGenerator], - story_generator_list: Sequence[Mapping[str, Any]], + df_module: ModuleType, num_passes: int, db_dsn: str, schema_name: str | None, @@ -145,12 +143,13 @@ def create_db_data_into( row_counts += populate( dst_conn, sorted_tables, - table_generator_dict, - 
story_generator_list, + df_module.table_generator_dict, + df_module.story_generator_list, ) return row_counts +# pylint: disable=too-many-instance-attributes class StoryIterator: """Iterates through all the rows produced by all the stories.""" @@ -305,7 +304,9 @@ def populate( story_iterator.insert() t = story_iterator.table_name() if t is None: - raise Exception("Internal error") + raise AssertionError( + "Internal error: story iterator returns None but not is_ended" + ) row_counts[t] = row_counts.get(t, 0) + 1 story_iterator.next() diff --git a/datafaker/generators.py b/datafaker/generators.py deleted file mode 100644 index 0b760e4..0000000 --- a/datafaker/generators.py +++ /dev/null @@ -1,2160 +0,0 @@ -"""Generator factories for making generators for single columns.""" - -import decimal -import math -import re -import typing -from abc import ABC, abstractmethod -from collections.abc import Mapping -from dataclasses import dataclass -from functools import lru_cache -from itertools import chain, combinations -from typing import Any, Callable, Iterable, Sequence, Union - -import mimesis -import mimesis.locales -import sqlalchemy -from sqlalchemy import Column, Connection, CursorResult, Engine, RowMapping, text -from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time, TypeEngine -from typing_extensions import Self - -from datafaker.base import DistributionGenerator -from datafaker.utils import T, logger - -NumericType = Union[int, float] - -# How many distinct values can we have before we consider a -# choice distribution to be infeasible? -MAXIMUM_CHOICES = 500 - -dist_gen = DistributionGenerator() -generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) - - -class Generator(ABC): - """ - Random data generator. - - A generator is specific to a particular column in a particular table in - a particluar database. - - A generator knows how to fetch its summary data from the database, how to calculate - its fit (if apropriate) and which function actually does the generation. - - It also knows these summary statistics for the column it was instantiated on, - and therefore knows how to generate fake data for that column. - """ - - @abstractmethod - def function_name(self) -> str: - """Get the name of the generator function to put into df.py.""" - - def name(self) -> str: - """ - Get the name of the generator. - - Usually the same as the function name, but can be different to distinguish - between generators that have the same function but different queries. - """ - return self.function_name() - - @abstractmethod - def nominal_kwargs(self) -> dict[str, str]: - """ - Get the kwargs the generator wants to be called with. - - The values will tend to be references to something in the src-stats.yaml - file. - For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will - provide the value stored in src-stats.yaml as - SRC_STATS["auto__patient"]["results"][0]["age_mean"] as the "avg_age" argument - to the generator function. - """ - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """ - Get the SQL clauses to add to a SELECT ... FROM {table} query. 
- - Will add to SRC_STATS["auto__{table}"] - For example { - "count": { - "clause": "COUNT(*)", - "comment": "number of rows in table {table}" - }, "avg_thiscolumn": { - "clause": "AVG(thiscolumn)", - "comment": "Average value of thiscolumn in table {table}" - }} - will make the clause become: - "SELECT COUNT(*) AS count, AVG(thiscolumn) AS avg_thiscolumn FROM thistable" - and this will populate SRC_STATS["auto__thistable"]["results"][0]["count"] and - SRC_STATS["auto__thistable"]["results"][0]["avg_thiscolumn"] in the src-stats.yaml file. - """ - return {} - - def custom_queries(self) -> dict[str, dict[str, str]]: - """ - Get the SQL queries to add to SRC_STATS. - - Should be used for queries that do not follow the SELECT ... FROM table format - using aggregate queries, because these should use select_aggregate_clauses. - - For example {"myquery": { - "query": "SELECT one, too AS two FROM mytable WHERE too > 1", - "comment": "big enough one and two from table mytable" - }} - will populate SRC_STATS["myquery"]["results"][0]["one"] - and SRC_STATS["myquery"]["results"][0]["two"] - in the src-stats.yaml file. - - Keys should be chosen to minimize the chances of clashing with other queries, - for example "auto__{table}__{column}__{queryname}" - """ - return {} - - @abstractmethod - def actual_kwargs(self) -> dict[str, Any]: - """ - Get the kwargs (summary statistics) this generator is instantiated with. - - This must match `nominal_kwargs` in structure. - """ - - @abstractmethod - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - - def fit(self, default: float = -1) -> float: - """ - Return a value representing how well the distribution fits the real source data. - - 0.0 means "perfectly". - Returns default if no fitness has been defined. - """ - return default - - -class PredefinedGenerator(Generator): - """Generator built from an existing config.yaml.""" - - SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)") - AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *") - SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*') - - def _get_src_stats_mentioned(self, val: Any) -> set[str]: - if not val: - return set() - if isinstance(val, str): - ss = self.SRC_STAT_NAME_RE.match(val) - if ss: - ss_name = ss.group(1) - logger.debug("Found SRC_STATS reference %s", ss_name) - return set([ss_name]) - logger.debug("Value %s does not seem to be a SRC_STATS reference", val) - return set() - if isinstance(val, list): - return set.union(*(self._get_src_stats_mentioned(v) for v in val)) - if isinstance(val, dict): - return set.union(*(self._get_src_stats_mentioned(v) for v in val.values())) - return set() - - def __init__( - self, - table_name: str, - generator_object: Mapping[str, Any], - config: Mapping[str, Any], - ): - """ - Initialise a generator from a config.yaml. - - :param config: The entire configuration. 
- :param generator_object: The part of the configuration at tables.*.row_generators - """ - logger.debug( - "Creating a PredefinedGenerator %s from table %s", - generator_object["name"], - table_name, - ) - self._table_name = table_name - self._name: str = generator_object["name"] - self._kwn: dict[str, str] = generator_object.get("kwargs", {}) - self._src_stats_mentioned = self._get_src_stats_mentioned(self._kwn) - # Need to deal with this somehow (or remove it from the schema) - self._argn: list[str] = generator_object.get("args", []) - self._select_aggregate_clauses: dict[str, dict[str, str | Any]] = {} - self._custom_queries = {} - for sstat in config.get("src-stats", []): - name: str = sstat["name"] - dpq = sstat.get("dp-query", None) - query = sstat.get( - "query", dpq - ) # ... should we really be combining query and dp-query? - comments = sstat.get("comments", []) - if name in self._src_stats_mentioned: - logger.debug("Found a src-stats entry for %s", name) - # This query is one that this generator is interested in - sam = None if query is None else self.SELECT_AGGREGATE_RE.match(query) - # sam.group(2) is the table name from the FROM clause of the query - if sam and name == f"auto__{sam.group(2)}": - # name is auto__{table_name}, so it's a select_aggregate, so we split up its clauses - sacs = [ - self.AS_CLAUSE_RE.match(clause) - for clause in sam.group(1).split(",") - ] - # Work out what select_aggregate_clauses this represents - for sac in sacs: - if sac is not None: - comment = comments.pop() if comments else None - self._select_aggregate_clauses[sac.group(2)] = { - "clause": sac.group(1), - "comment": comment, - } - else: - # some other name, so must be a custom query - logger.debug("Custom query %s is '%s'", name, query) - self._custom_queries[name] = { - "query": query, - "comment": comments[0] if comments else None, - } - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return self._name - - def nominal_kwargs(self) -> dict[str, str]: - """Get the arguments to be entered into ``config.yaml``.""" - return self._kwn - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """Get the query fragments the generators need to call.""" - return self._select_aggregate_clauses - - def custom_queries(self) -> dict[str, dict[str, str]]: - """Get the queries the generators need to call.""" - return self._custom_queries - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - # Run the queries from nominal_kwargs - # ... - logger.error("PredefinedGenerator.actual_kwargs not implemented yet") - return {} - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - # Call the function if we can. This could be tricky... - # ... - logger.error("PredefinedGenerator.generate_data not implemented yet") - return [] - - -class GeneratorFactory(ABC): - """A factory for making generators appropriate for a database column.""" - - @abstractmethod - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - - -class Buckets: - """ - Measured buckets for a real distribution. - - Finds the real distribution of continuous data so that we can measure - the fit of generators against it. 
- """ - - def __init__( - self, - engine: Engine, - table_name: str, - column_name: str, - mean: float, - stddev: float, - count: int, - ): - """Initialise a Buckets object.""" - with engine.connect() as connection: - raw_buckets = connection.execute( - text( - f"SELECT COUNT({column_name}) AS f," - f" FLOOR(({column_name} - {mean - 2 * stddev})/{stddev / 2}) AS b" - f" FROM {table_name} GROUP BY b" - ) - ) - self.buckets: Sequence[int] = [0] * 10 - for rb in raw_buckets: - if rb.b is not None: - bucket = min(9, max(0, int(rb.b) + 1)) - self.buckets[bucket] += rb.f / count - self.mean = mean - self.stddev = stddev - - @classmethod - def make_buckets( - cls, engine: Engine, table_name: str, column_name: str - ) -> Self | None: - """ - Construct a Buckets object. - - Calculates the mean and standard deviation of the values in the column - specified and makes ten buckets, centered on the mean and each half - a standard deviation wide (except for the end two that extend to - infinity). Each bucket will be set to the count of the number of values - in the column within that bucket. - """ - with engine.connect() as connection: - result = connection.execute( - text( - f"SELECT AVG({column_name}) AS mean," - f" STDDEV({column_name}) AS stddev," - f" COUNT({column_name}) AS count FROM {table_name}" - ) - ).first() - if result is None or result.stddev is None or getattr(result, "count") < 2: - return None - try: - buckets = cls( - engine, - table_name, - column_name, - result.mean, - result.stddev, - getattr(result, "count"), - ) - except sqlalchemy.exc.DatabaseError as exc: - logger.debug("Failed to instantiate Buckets object: %s", exc) - return None - return buckets - - def fit_from_counts(self, bucket_counts: Sequence[float]) -> float: - """Figure out the fit from bucket counts from the generator distribution.""" - return fit_from_buckets(self.buckets, bucket_counts) - - def fit_from_values(self, values: list[float]) -> float: - """Figure out the fit from samples from the generator distribution.""" - buckets = [0] * 10 - x = self.mean - 2 * self.stddev - w = self.stddev / 2 - for v in values: - b = min(9, max(0, int((v - x) / w))) - buckets[b] += 1 - return self.fit_from_counts(buckets) - - -class MultiGeneratorFactory(GeneratorFactory): - """A composite factory.""" - - def __init__(self, factories: list[GeneratorFactory]): - """Initialise a MultiGeneratorFactory.""" - super().__init__() - self.factories = factories - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - return [ - generator - for factory in self.factories - for generator in factory.get_generators(columns, engine) - ] - - -class MimesisGeneratorBase(Generator): - """Base class for a generator using Mimesis.""" - - def __init__( - self, - function_name: str, - ): - """ - Initialise a generator that uses Mimesis. - - :param function_name: is relative to 'generic', for example 'person.name'. - """ - super().__init__() - f = generic - for part in function_name.split("."): - if not hasattr(f, part): - raise Exception( - f"Mimesis does not have a function {function_name}: {part} not found" - ) - f = getattr(f, part) - if not callable(f): - raise Exception( - f"Mimesis object {function_name} is not a callable," - " so cannot be used as a generator" - ) - self._name = "generic." 
+ function_name - self._generator_function = f - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return self._name - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [self._generator_function() for _ in range(count)] - - -class MimesisGenerator(MimesisGeneratorBase): - """A generator using Mimesis.""" - - def __init__( - self, - function_name: str, - value_fn: Callable[[Any], float] | None = None, - buckets: Buckets | None = None, - ): - """ - Initialise a generator using Mimesis. - - :param function_name: is relative to 'generic', for example 'person.name'. - :param value_fn: Function to convert generator output to floats, if needed. The values - thus produced are compared against the buckets to estimate the fit. - :param buckets: The distribution of string lengths in the real data. If this is None - then the fit method will return None. - """ - super().__init__(function_name) - if buckets is None: - self._fit = None - return - samples = self.generate_data(400) - if value_fn: - samples = [value_fn(s) for s in samples] - self._fit = buckets.fit_from_values(samples) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return self._name - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return {} - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return {} - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - return default if self._fit is None else self._fit - - -class MimesisGeneratorTruncated(MimesisGenerator): - """A string generator using Mimesis that must fit within a certain number of characters.""" - - def __init__( - self, - function_name: str, - length: int, - value_fn: Callable[[Any], float] | None = None, - buckets: Buckets | None = None, - ): - """Initialise a MimesisGeneratorTruncated.""" - self._length = length - super().__init__(function_name, value_fn, buckets) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.truncated_string" - - def name(self) -> str: - """Get the name of the generator.""" - return f"{self._name} [truncated to {self._length}]" - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "subgen_fn": self._name, - "params": {}, - "length": self._length, - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return { - "subgen_fn": self._name, - "params": {}, - "length": self._length, - } - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [self._generator_function()[: self._length] for _ in range(count)] - - -class MimesisDateTimeGenerator(MimesisGeneratorBase): - """DateTime generator using Mimesis.""" - - def __init__( - self, - column: Column, - function_name: str, - min_year: str, - max_year: str, - start: int, - end: int, - ) -> None: - """ - Initialise a MimesisDateTimeGenerator. 
- - :param column: The column to generate into - :param function_name: The name of the mimesis function - :param min_year: SQL expression extracting the minimum year - :param min_year: SQL expression extracting the maximum year - :param start: The actual first year found - :param end: The actual last year found - """ - super().__init__(function_name) - self._column = column - self._max_year = max_year - self._min_year = min_year - self._start = start - self._end = end - - @classmethod - def make_singleton( - cls, column: Column, engine: Engine, function_name: str - ) -> Sequence[Generator]: - """Make the appropriate generation configuration for this column.""" - extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)" - max_year = f"MAX({extract_year})" - min_year = f"MIN({extract_year})" - with engine.connect() as connection: - result = connection.execute( - text( - f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}" - ) - ).first() - if result is None or result.start is None or result.end is None: - return [] - return [ - MimesisDateTimeGenerator( - column, - function_name, - min_year, - max_year, - int(result.start), - int(result.end), - ) - ] - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "start": ( - f'SRC_STATS["auto__{self._column.table.name}"]["results"]' - f'[0]["{self._column.name}__start"]' - ), - "end": ( - f'SRC_STATS["auto__{self._column.table.name}"]["results"]' - f'[0]["{self._column.name}__end"]' - ), - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return { - "start": self._start, - "end": self._end, - } - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """Get the query fragments the generators need to call.""" - return { - f"{self._column.name}__start": { - "clause": self._min_year, - "comment": ( - f"Earliest year found for column {self._column.name}" - f" in table {self._column.table.name}" - ), - }, - f"{self._column.name}__end": { - "clause": self._max_year, - "comment": ( - f"Latest year found for column {self._column.name}" - f" in table {self._column.table.name}" - ), - }, - } - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [ - self._generator_function(start=self._start, end=self._end) - for _ in range(count) - ] - - -def get_column_type(column: Column) -> TypeEngine: - """Get the type of the column, generic if possible.""" - try: - return column.type.as_generic() - except NotImplementedError: - return column.type - - -class MimesisStringGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return strings.""" - - GENERATOR_NAMES = [ - "address.calling_code", - "address.city", - "address.continent", - "address.country", - "address.country_code", - "address.postal_code", - "address.province", - "address.street_number", - "address.street_name", - "address.street_suffix", - "person.blood_type", - "person.email", - "person.first_name", - "person.last_name", - "person.full_name", - "person.gender", - "person.language", - "person.nationality", - "person.occupation", - "person.password", - "person.title", - "person.university", - "person.username", - "person.worldview", - "text.answer", - "text.color", - "text.level", - "text.quote", - "text.sentence", - "text.text", - "text.word", - ] - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> 
Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - column_type = get_column_type(column) - if not isinstance(column_type, String): - return [] - try: - buckets = Buckets.make_buckets( - engine, - column.table.name, - f"LENGTH({column.name})", - ) - fitness_fn = len - except Exception: - # Some column types that appear to be strings (such as enums) - # cannot have their lengths measured. In this case we cannot - # detect fitness using lengths. - buckets = None - fitness_fn = None - length = column_type.length - if length: - return list( - map( - lambda gen: MimesisGeneratorTruncated( - gen, length, fitness_fn, buckets - ), - self.GENERATOR_NAMES, - ) - ) - return list( - map( - lambda gen: MimesisGenerator(gen, fitness_fn, buckets), - self.GENERATOR_NAMES, - ) - ) - - -class MimesisFloatGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return floating point numbers.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - if not isinstance(get_column_type(column), Numeric): - return [] - return list( - map( - MimesisGenerator, - [ - "person.height", - ], - ) - ) - - -class MimesisDateGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return dates.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Date): - return [] - return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.date") - - -class MimesisDateTimeGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return datetimes.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, DateTime): - return [] - return MimesisDateTimeGenerator.make_singleton( - column, engine, "datetime.datetime" - ) - - -class MimesisTimeGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return times.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Time): - return [] - return [MimesisGenerator("datetime.time")] - - -class MimesisIntegerGeneratorFactory(GeneratorFactory): - """All Mimesis generators that return integers.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - ct = get_column_type(column) - if not isinstance(ct, Numeric) and not isinstance(ct, Integer): - return [] - return [MimesisGenerator("person.weight")] - - -def fit_from_buckets(xs: Sequence[NumericType], ys: Sequence[NumericType]) -> float: - """Calculate the fit by comparing a pair of lists of buckets.""" - sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys)) - count = len(ys) - return sum_diff_squared / (count * count) - - -class 
ContinuousDistributionGenerator(Generator): - """Base class for generators producing continuous distributions.""" - - expected_buckets: Sequence[NumericType] = [] - - def __init__(self, table_name: str, column_name: str, buckets: Buckets): - """Initialise a ContinuousDistributionGenerator.""" - super().__init__() - self.table_name = table_name - self.column_name = column_name - self.buckets = buckets - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "mean": ( - f'SRC_STATS["auto__{self.table_name}"]["results"]' - f'[0]["mean__{self.column_name}"]' - ), - "sd": ( - f'SRC_STATS["auto__{self.table_name}"]["results"]' - f'[0]["stddev__{self.column_name}"]' - ), - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - if self.buckets is None: - return {} - return { - "mean": self.buckets.mean, - "sd": self.buckets.stddev, - } - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """Get the query fragments the generators need to call.""" - clauses = super().select_aggregate_clauses() - return { - **clauses, - f"mean__{self.column_name}": { - "clause": f"AVG({self.column_name})", - "comment": f"Mean of {self.column_name} from table {self.table_name}", - }, - f"stddev__{self.column_name}": { - "clause": f"STDDEV({self.column_name})", - "comment": f"Standard deviation of {self.column_name} from table {self.table_name}", - }, - } - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - if self.buckets is None: - return default - return self.buckets.fit_from_counts(self.expected_buckets) - - -class GaussianGenerator(ContinuousDistributionGenerator): - """Generator producing numbers in a Gaussian (normal) distribution.""" - - expected_buckets = [ - 0.0227, - 0.0441, - 0.0918, - 0.1499, - 0.1915, - 0.1915, - 0.1499, - 0.0918, - 0.0441, - 0.0227, - ] - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.normal" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [ - dist_gen.normal(self.buckets.mean, self.buckets.stddev) - for _ in range(count) - ] - - -class UniformGenerator(ContinuousDistributionGenerator): - """Generator producing numbers in a uniform distribution.""" - - expected_buckets = [ - 0, - 0.06698, - 0.14434, - 0.14434, - 0.14434, - 0.14434, - 0.14434, - 0.14434, - 0.06698, - 0, - ] - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.uniform_ms" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [ - dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) - for _ in range(count) - ] - - -class ContinuousDistributionGeneratorFactory(GeneratorFactory): - """All generators that want an average and standard deviation.""" - - def _get_generators_from_buckets( - self, - _engine: Engine, - table_name: str, - column_name: str, - buckets: Buckets, - ) -> Sequence[Generator]: - return [ - GaussianGenerator(table_name, column_name, buckets), - UniformGenerator(table_name, column_name, buckets), - ] - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - ct = 
get_column_type(column) - if not isinstance(ct, Numeric) and not isinstance(ct, Integer): - return [] - column_name = column.name - table_name = column.table.name - buckets = Buckets.make_buckets(engine, table_name, column_name) - if buckets is None: - return [] - return self._get_generators_from_buckets( - engine, table_name, column_name, buckets - ) - - -class LogNormalGenerator(Generator): - """Generator producing numbers in a log-normal distribution.""" - - # TODO: figure out the real buckets here (this was from a random sample in R) - expected_buckets = [ - 0, - 0, - 0, - 0.28627, - 0.40607, - 0.14937, - 0.06735, - 0.03492, - 0.01918, - 0.03684, - ] - - def __init__( - self, - table_name: str, - column_name: str, - buckets: Buckets, - logmean: float, - logstddev: float, - ): - """Initialise a LogNormalGenerator.""" - super().__init__() - self.table_name = table_name - self.column_name = column_name - self.buckets = buckets - self.logmean = logmean - self.logstddev = logstddev - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.lognormal" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "logmean": ( - f'SRC_STATS["auto__{self.table_name}"]["results"][0]' - f'["logmean__{self.column_name}"]' - ), - "logsd": ( - f'SRC_STATS["auto__{self.table_name}"]["results"][0]' - f'["logstddev__{self.column_name}"]' - ), - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return { - "logmean": self.logmean, - "logsd": self.logstddev, - } - - def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: - """Get the query fragments the generators need to call.""" - clauses = super().select_aggregate_clauses() - return { - **clauses, - f"logmean__{self.column_name}": { - "clause": ( - f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name})" - " ELSE NULL END)" - ), - "comment": f"Mean of logs of {self.column_name} from table {self.table_name}", - }, - f"logstddev__{self.column_name}": { - "clause": ( - f"STDDEV(CASE WHEN 0<{self.column_name}" - f" THEN LN({self.column_name}) ELSE NULL END)" - ), - "comment": ( - f"Standard deviation of logs of {self.column_name}" - f" from table {self.table_name}" - ), - }, - } - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - if self.buckets is None: - return default - return self.buckets.fit_from_counts(self.expected_buckets) - - -class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorFactory): - """All generators that want an average and standard deviation of log data.""" - - def _get_generators_from_buckets( - self, - engine: Engine, - table_name: str, - column_name: str, - buckets: Buckets, - ) -> Sequence[Generator]: - with engine.connect() as connection: - result = connection.execute( - text( - f"SELECT AVG(CASE WHEN 0<{column_name} THEN LN({column_name})" - " ELSE NULL END) AS logmean," - f" STDDEV(CASE WHEN 0<{column_name} THEN LN({column_name}) ELSE NULL END)" - f" AS logstddev FROM {table_name}" - ) - ).first() - if result is None or result.logstddev is None: - return [] - return [ - LogNormalGenerator( - table_name, - column_name, - buckets, - 
float(result.logmean), - float(result.logstddev), - ) - ] - - -def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: - """ - Get a zipf distribution for a certain number of items. - - :param total: The total number of items to be distributed. - :param bins: The total number of bins to distribute the items into. - :return: A generator of the number of items in each bin, from the - largest to the smallest. - """ - basic_dist = list(map(lambda n: 1 / n, range(1, bins + 1))) - bd_remaining = sum(basic_dist) - for b in basic_dist: - # yield b/bd_remaining of the `total` remaining - if bd_remaining == 0: - yield 0 - else: - x = math.floor(0.5 + total * b / bd_remaining) - bd_remaining -= x * bd_remaining / total - total -= x - yield x - - -class ChoiceGenerator(Generator): - """Base generator for all generators producing choices of items.""" - - STORE_COUNTS = False - - def __init__( - self, - table_name: str, - column_name: str, - values: list[Any], - counts: list[int], - sample_count: int | None = None, - suppress_count: int = 0, - ) -> None: - """Initialise a ChoiceGenerator.""" - super().__init__() - self.table_name = table_name - self.column_name = column_name - self.values = values - estimated_counts = self.get_estimated_counts(counts) - self._fit = fit_from_buckets(counts, estimated_counts) - - extra_results = "" - extra_expo = "" - extra_comment = "" - if self.STORE_COUNTS: - extra_results = f", COUNT({column_name}) AS count" - extra_expo = ", count" - extra_comment = " and their counts" - if suppress_count == 0: - if sample_count is None: - self._query = ( - f"SELECT {column_name} AS value{extra_results} FROM {table_name}" - f" WHERE {column_name} IS NOT NULL GROUP BY value" - f" ORDER BY COUNT({column_name}) DESC" - ) - self._comment = ( - f"All the values{extra_comment} that appear in column {column_name}" - f" of table {table_name}" - ) - self._annotation = None - else: - self._query = ( - f"SELECT {column_name} AS value{extra_results} FROM" - f" (SELECT {column_name} FROM {table_name}" - f" WHERE {column_name} IS NOT NULL" - f" ORDER BY RANDOM() LIMIT {sample_count})" - f" AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" - ) - self._comment = ( - f"The values{extra_comment} that appear in column {column_name}" - f" of a random sample of {sample_count} rows of table {table_name}" - ) - self._annotation = "sampled" - else: - if sample_count is None: - self._query = ( - f"SELECT value{extra_expo} FROM" - f" (SELECT {column_name} AS value, COUNT({column_name}) AS count" - f" FROM {table_name} WHERE {column_name} IS NOT NULL" - f" GROUP BY value ORDER BY count DESC) AS _inner" - f" WHERE {suppress_count} < count" - ) - self._comment = ( - f"All the values{extra_comment} that appear in column {column_name}" - f" of table {table_name} more than {suppress_count} times" - ) - self._annotation = "suppressed" - else: - self._query = ( - f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM" - f" (SELECT {column_name} AS value FROM {table_name}" - f" WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count})" - f" AS _inner GROUP BY value ORDER BY count DESC)" - f" AS _inner WHERE {suppress_count} < count" - ) - self._comment = ( - f"The values{extra_comment} that appear more than {suppress_count} times" - f" in column {column_name}, out of a random sample of {sample_count} rows" - f" of table {table_name}" - ) - self._annotation = "sampled and suppressed" - - @abstractmethod - def get_estimated_counts(self, counts: 
list[int]) -> list[int]: - """Get the counts that we would expect if this distribution was the correct one.""" - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', - } - - def name(self) -> str: - """Get the name of the generator.""" - n = super().name() - if self._annotation is None: - return n - return f"{n} [{self._annotation}]" - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return { - "a": self.values, - } - - def custom_queries(self) -> dict[str, dict[str, str]]: - """Get the queries the generators need to call.""" - qs = super().custom_queries() - return { - **qs, - f"auto__{self.table_name}__{self.column_name}": { - "query": self._query, - "comment": self._comment, - }, - } - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - return default if self._fit is None else self._fit - - -class ZipfChoiceGenerator(ChoiceGenerator): - """Generator producing items in a Zipf distribution.""" - - def get_estimated_counts(self, counts: list[int]) -> list[int]: - """Get the counts that we would expect if this distribution was the correct one.""" - return list(zipf_distribution(sum(counts), len(counts))) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.zipf_choice" - - def generate_data(self, count: int) -> list[float]: - """Generate ``count`` random data points for this column.""" - return [ - dist_gen.zipf_choice(self.values, len(self.values)) for _ in range(count) - ] - - -def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: - """ - Construct a distribution putting ``total`` items uniformly into ``bins`` bins. - - If they don't fit exactly evenly, the earlier bins will have one more - item than the later bins so the total is as required. - """ - p = total // bins - n = total % bins - for _ in range(0, n): - yield p + 1 - for _ in range(n, bins): - yield p - - -class UniformChoiceGenerator(ChoiceGenerator): - """A generator producing values, each roughly as frequently as each other.""" - - def get_estimated_counts(self, counts: list[int]) -> list[int]: - """Get the counts that we would expect if this distribution was the correct one.""" - return list(uniform_distribution(sum(counts), len(counts))) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.choice" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [dist_gen.choice(self.values) for _ in range(count)] - - -class WeightedChoiceGenerator(ChoiceGenerator): - """Choice generator that matches the source data's frequency.""" - - STORE_COUNTS = True - - def get_estimated_counts(self, counts: list[int]) -> list[int]: - """Get the counts that we would expect if this distribution was the correct one.""" - return counts - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.weighted_choice" - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [dist_gen.weighted_choice(self.values) for _ in range(count)] - - -class ValueGatherer: - """ - Gathers values from a query of values and counts. 
- - The query must return columns ``v`` for a value and ``f`` for the - count of how many of those values there are. - These values will be gathered into a number of properties: - ``values``: the list of ``v`` values, ``counts``: the list of ``f`` counts - in the same order as ``v``, ``cvs``: list of dicts with keys ``value`` and - ``count`` giving these values and counts. ``counts_not_suppressed``, - ``values_not_suppressed`` and ``cvs_not_suppressed`` are the - equivalents with the counts less than or equal to ``suppress_count`` - removed. - - :param suppress_count: value with a count of this or fewer will be excluded - from the suppressed values. - """ - - def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: - """Initialise a ValueGatherer.""" - values = [] # All values found - counts = [] # The number or each value - cvs: list[dict[str, Any]] = [] # list of dicts with keys "v" and "count" - values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times - counts_not_suppressed = [] # The number for each value not suppressed - cvs_not_suppressed: list[ - dict[str, Any] - ] = [] # list of dicts with keys "v" and "count" - for result in results: - c = result.f - if c != 0: - counts.append(c) - v = result.v - if isinstance(v, decimal.Decimal): - v = float(v) - values.append(v) - cvs.append({"value": v, "count": c}) - if suppress_count < c: - counts_not_suppressed.append(c) - v = result.v - if isinstance(v, decimal.Decimal): - v = float(v) - values_not_suppressed.append(v) - cvs_not_suppressed.append({"value": v, "count": c}) - self.values = values - self.counts = counts - self.cvs = cvs - self.values_not_suppressed = values_not_suppressed - self.counts_not_suppressed = counts_not_suppressed - self.cvs_not_suppressed = cvs_not_suppressed - - -class ChoiceGeneratorFactory(GeneratorFactory): - """All generators that want an average and standard deviation.""" - - SAMPLE_COUNT = MAXIMUM_CHOICES - SUPPRESS_COUNT = 5 - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate to these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - column_name = column.name - table_name = column.table.name - generators = [] - with engine.connect() as connection: - results = connection.execute( - text( - f"SELECT {column_name} AS v, COUNT({column_name})" - f" AS f FROM {table_name} GROUP BY v" - f" ORDER BY f DESC LIMIT {MAXIMUM_CHOICES + 1}" - ) - ) - if results is not None and results.rowcount <= MAXIMUM_CHOICES: - vg = ValueGatherer(results, self.SUPPRESS_COUNT) - if vg.counts: - generators += [ - ZipfChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - UniformChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - WeightedChoiceGenerator( - table_name, column_name, vg.cvs, vg.counts - ), - ] - results = connection.execute( - text( - f"SELECT v, COUNT(v) AS f FROM" - f" (SELECT {column_name} as v FROM {table_name}" - f" ORDER BY RANDOM() LIMIT {self.SAMPLE_COUNT})" - f" AS _inner GROUP BY v ORDER BY f DESC" - ) - ) - if results is not None: - vg = ValueGatherer(results, self.SUPPRESS_COUNT) - if vg.counts: - generators += [ - ZipfChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - UniformChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - WeightedChoiceGenerator( - table_name, column_name, vg.cvs, vg.counts - ), - ] - generators += [ - ZipfChoiceGenerator( - table_name, - column_name, - vg.values, - 
vg.counts, - sample_count=self.SAMPLE_COUNT, - ), - UniformChoiceGenerator( - table_name, - column_name, - vg.values, - vg.counts, - sample_count=self.SAMPLE_COUNT, - ), - WeightedChoiceGenerator( - table_name, - column_name, - vg.cvs, - vg.counts, - sample_count=self.SAMPLE_COUNT, - ), - ] - if vg.counts_not_suppressed: - generators += [ - ZipfChoiceGenerator( - table_name, - column_name, - vg.values_not_suppressed, - vg.counts_not_suppressed, - sample_count=self.SAMPLE_COUNT, - suppress_count=self.SUPPRESS_COUNT, - ), - UniformChoiceGenerator( - table_name, - column_name, - vg.values_not_suppressed, - vg.counts_not_suppressed, - sample_count=self.SAMPLE_COUNT, - suppress_count=self.SUPPRESS_COUNT, - ), - WeightedChoiceGenerator( - table_name=table_name, - column_name=column_name, - values=vg.cvs_not_suppressed, - counts=vg.counts_not_suppressed, - sample_count=self.SAMPLE_COUNT, - suppress_count=self.SUPPRESS_COUNT, - ), - ] - return generators - - -class ConstantGenerator(Generator): - """Generator that always produces the same value.""" - - def __init__(self, value: Any) -> None: - """Initialise the ConstantGenerator.""" - super().__init__() - self.value = value - self.repr = repr(value) - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.constant" - - def nominal_kwargs(self) -> dict[str, str]: - """Get the arguments to be entered into ``config.yaml``.""" - return {"value": self.repr} - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return {"value": self.value} - - def generate_data(self, count: int) -> list[Any]: - """Generate ``count`` random data points for this column.""" - return [self.value for _ in range(count)] - - -class ConstantGeneratorFactory(GeneratorFactory): - """Just the null generator.""" - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators appropriate for these columns.""" - if len(columns) != 1: - return [] - column = columns[0] - if column.nullable: - return [ConstantGenerator(None)] - c_type = get_column_type(column) - if isinstance(c_type, String): - return [ConstantGenerator("")] - if isinstance(c_type, Numeric): - return [ConstantGenerator(0.0)] - if isinstance(c_type, Integer): - return [ConstantGenerator(0)] - return [] - - -class MultivariateNormalGenerator(Generator): - """Generator of multiple values drawn from a multivariate normal distribution.""" - - def __init__( - self, - table_name: str, - column_names: list[str], - query: str, - covariates: RowMapping, - function_name: str, - ) -> None: - """Initialise a MultivariateNormalGenerator.""" - self._table = table_name - self._columns = column_names - self._query = query - self._covariates = covariates - self._function_name = function_name - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen." + self._function_name - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', - } - - def custom_queries(self) -> dict[str, Any]: - """Get the queries the generators need to call.""" - cols = ", ".join(self._columns) - return { - f"auto__cov__{self._table}": { - "comment": ( - f"Means and covariate matrix for the columns {cols}," - " so that we can produce the relatedness between these in the fake data." 
- ), - "query": self._query, - } - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return {"cov": self._covariates} - - def generate_data(self, count: int) -> list[Any]: - """Generate 'count' random data points for this column.""" - return [ - getattr(dist_gen, self._function_name)(self._covariates) - for _ in range(count) - ] - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - return default - - -class MultivariateNormalGeneratorFactory(GeneratorFactory): - """Normal distribution generator factory.""" - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "multivariate_normal" - - def query_predicate(self, column: Column) -> str: - """Get the SQL expression for whether this column should be queried.""" - return column.name + " IS NOT NULL" - - def query_var(self, column: str) -> str: - """Get the SQL expression of the value to query for this column.""" - return column - - def query( - self, - table: str, - columns: list[Column], - predicates: list[str] = [], - group_by_clause: str = "", - constant_clauses: str = "", - constants: str = "", - suppress_count: int = 1, - sample_count: int | None = None, - ) -> str: - """ - Get a query for the basics for multivariate normal/lognormal parameters. - - :param table: The name of the table to be queried. - :param columns: The columns in the multivariate distribution. - :param and_where: Additional where clause. If not ``""`` should begin with ``" AND "``. - :param group_by_clause: Any GROUP BY clause (starting with " GROUP BY " if not ""). - :param constant_clauses: Extra output columns in the outer SELECT clause, such - as ", _q.column_one AS k1, _q.column_two AS k2". Note the initial comma. - :param constants: Extra output columns in the inner SELECT clause. Used to - deliver columns to the outer select, such as ", column_one, column_two". - Note the initial comma. - :param suppress_count: a group smaller than this will be suppressed. - :param sample_count: this many samples will be taken from each partition. 
- """ - preds = [self.query_predicate(col) for col in columns] + predicates - where = " WHERE " + " AND ".join(preds) if preds else "" - avgs = "".join( - f", AVG({self.query_var(col.name)}) AS m{i}" - for i, col in enumerate(columns) - ) - multiples = "".join( - f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" - for iy, coly in enumerate(columns) - for ix, colx in enumerate(columns[: iy + 1]) - ) - means = "".join(f", _q.m{i}" for i in range(len(columns))) - covs = "".join( - ( - f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})" - f"/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" - ) - for iy in range(len(columns)) - for ix in range(iy + 1) - ) - if sample_count is None: - subquery = table + where - else: - subquery = ( - f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" - f" LIMIT {sample_count}) AS _sampled" - ) - # if there are any numeric columns we need at least two rows to make any (co)variances at all - suppress_clause = f" WHERE {suppress_count} < _q.count" if columns else "" - return ( - f"SELECT {len(columns)} AS rank{constant_clauses}, _q.count AS count{means}{covs}" - f" FROM (SELECT COUNT(*) AS count{multiples}{avgs}{constants}" - f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" - ) - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get the generators for these columns.""" - # For the case of one column we'll use GaussianGenerator - if len(columns) < 2: - return [] - # All columns must be numeric - for c in columns: - ct = get_column_type(c) - if not isinstance(ct, Numeric) and not isinstance(ct, Integer): - return [] - column_names = [c.name for c in columns] - table = columns[0].table.name - query = self.query(table, columns) - with engine.connect() as connection: - try: - covariates = connection.execute(text(query)).mappings().first() - except Exception as e: - logger.debug("SQL query %s failed with error %s", query, e) - return [] - if not covariates or covariates["c0_0"] is None: - return [] - return [ - MultivariateNormalGenerator( - table, - column_names, - query, - covariates, - self.function_name(), - ) - ] - - -class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): - """Multivariate lognormal generator factory.""" - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "multivariate_lognormal" - - def query_predicate(self, column: Column) -> str: - """Get the SQL expression for whether this column should be queried.""" - return f"COALESCE(0 < {column.name}, FALSE)" - - def query_var(self, column: str) -> str: - """Get the expression to query for, for this column.""" - return f"LN({column})" - - -def text_list(items: Iterable[str]) -> str: - """Concatenate the items with commas and one "and".""" - item_i = iter(items) - try: - last_item = next(item_i) - except StopIteration: - return "" - try: - so_far = next(item_i) - except StopIteration: - return last_item - for item in item_i: - so_far += ", " + last_item - last_item = item - return so_far + " and " + last_item - - -@dataclass -class RowPartition: - """A partition where all the rows have the same pattern of NULLs.""" - - query: str - # list of numeric columns - included_numeric: list[Column] - # map of indices to column names that are being grouped by. - # The indices are indices of where they need to be inserted into - # the generator outputs. 
- included_choice: dict[int, str] - # map of column names to clause that defines the partition - # such as "mycolumn IS NULL" - excluded_columns: dict[str, str] - # map of constant outputs that need to be inserted into the - # list of included column values (so once the generator has - # been run and the included_choice values have been - # added): {index: value} - constant_outputs: dict[int, Any] - # The actual covariates from the source database - covariates: Sequence[RowMapping] - - def comment(self) -> str: - """Make an appropriate comment for this partition.""" - caveat = "" - if self.included_choice: - caveat = f" (for each possible value of {text_list(self.included_choice.values())})" - if not self.included_numeric: - return f"Number of rows for which {text_list(self.excluded_columns.values())}{caveat}" - if not self.excluded_columns: - where = "" - else: - where = f" where {text_list(self.excluded_columns.values())}" - if len(self.included_numeric) == 1: - return ( - f"Mean and variance for column {self.included_numeric[0].name}{where}." - ) - return ( - "Means and covariate matrix for the columns " - f"{text_list(col.name for col in self.included_numeric)}{where}{caveat} so that we can" - " produce the relatedness between these in the fake data." - ) - - -class NullPartitionedNormalGenerator(Generator): - """ - A generator of mixed numeric and non-numeric data. - - Generates data that matches the source data in - missingness, choice of non-numeric data and numeric - data. - - For the numeric data to be generated, samples of rows for each - combination of non-numeric values and missingness. If any such - combination has only one line in the source data (or sample of - the source data if sampling), it will not be generated as a - covariate matrix cannot be generated from one source row - (although if the data is all non-numeric values and nulls, single - rows are used because no covariate matrix is required for this). 
- """ - - def __init__( - self, - query_name: str, - partitions: dict[int, RowPartition], - function_name: str = "grouped_multivariate_lognormal", - name_suffix: str | None = None, - partition_count_query: str | None = None, - partition_counts: Iterable[RowMapping] = [], - partition_count_comment: str | None = None, - ): - """Initialise a NullPartitionedNormalGenerator.""" - self._query_name = query_name - self._partitions = partitions - self._function_name = function_name - self._partition_count_query = partition_count_query - self._partition_counts = [dict(pc) for pc in partition_counts] - self._partition_count_comment = partition_count_comment - if name_suffix: - self._name = f"null-partitioned {function_name} [{name_suffix}]" - else: - self._name = f"null-partitioned {function_name}" - - def name(self) -> str: - """Get the name of the generator.""" - return self._name - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "dist_gen.alternatives" - - def _nominal_kwargs_with_combinations( - self, index: int, partition: RowPartition - ) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml`` for a single partition.""" - count = ( - 'sum(r["count"] for r in' - f' SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' - ) - if not partition.included_numeric and not partition.included_choice: - return { - "count": count, - "name": '"constant"', - "params": {"value": [None] * len(partition.constant_outputs)}, - } - covariates = { - "covs": f'SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"]' - } - if not partition.constant_outputs: - return { - "count": count, - "name": f'"{self._function_name}"', - "params": covariates, - } - return { - "count": count, - "name": '"with_constants_at"', - "params": { - "constants_at": partition.constant_outputs, - "subgen": f'"{self._function_name}"', - "params": covariates, - }, - } - - def _count_query_name(self) -> str: - return f"auto__cov__{self._query_name}__counts" - - def nominal_kwargs(self) -> dict[str, Any]: - """Get the arguments to be entered into ``config.yaml``.""" - return { - "alternative_configs": [ - self._nominal_kwargs_with_combinations(index, self._partitions[index]) - for index in range(len(self._partitions)) - ], - "counts": f'SRC_STATS["{self._count_query_name()}"]["results"]', - } - - def custom_queries(self) -> dict[str, Any]: - """Get the queries the generators need to call.""" - partitions = { - f"auto__cov__{self._query_name}__alt_{index}": { - "comment": partition.comment(), - "query": partition.query, - } - for index, partition in self._partitions.items() - } - if not self._partition_count_query: - return partitions - return { - self._count_query_name(): { - "comment": self._partition_count_comment, - "query": self._partition_count_query, - }, - **partitions, - } - - def _actual_kwargs_with_combinations( - self, partition: RowPartition - ) -> dict[str, Any]: - count = sum(row["count"] for row in partition.covariates) - if not partition.included_numeric and not partition.included_choice: - return { - "count": count, - "name": "constant", - "params": {"value": [None] * len(partition.excluded_columns)}, - } - covariates = { - "covs": partition.covariates, - } - if not partition.constant_outputs: - return { - "count": count, - "name": self._function_name, - "params": covariates, - } - return { - "count": count, - "name": "with_constants_at", - "params": { - "constants_at": partition.constant_outputs, - "subgen": self._function_name, - 
"params": covariates, - }, - } - - def actual_kwargs(self) -> dict[str, Any]: - """Get the kwargs (summary statistics) this generator was instantiated with.""" - return { - "alternative_configs": [ - self._actual_kwargs_with_combinations(self._partitions[index]) - for index in range(len(self._partitions)) - ], - "counts": self._partition_counts, - } - - def generate_data(self, count: int) -> list[Any]: - """Generate 'count' random data points for this column.""" - kwargs = self.actual_kwargs() - return [dist_gen.alternatives(**kwargs) for _ in range(count)] - - def fit(self, default: float = -1) -> float: - """Get this generator's fit against the real data.""" - return default - - -def is_numeric(col: Column) -> bool: - """Test if this column stores a numeric value.""" - ct = get_column_type(col) - return isinstance(ct, (Numeric, Integer)) and not col.foreign_keys - - -def powerset(xs: list[T]) -> Iterable[Iterable[T]]: - """Get a list of all sublists of ``input``.""" - return chain.from_iterable(combinations(xs, n) for n in range(len(xs) + 1)) - - -@dataclass -class NullableColumn: - """A reference to a nullable column whose nullability is part of a partitioning.""" - - column: Column - # The bit (power of two) of the number of the partition in the partition sizes list - bitmask: int - - -class NullPatternPartition: - """Get the definition of a partition (in other words, what makes it not another partition).""" - - def __init__( - self, columns: Iterable[Column], partition_nonnulls: Iterable[NullableColumn] - ): - """Initialise a pattern of nulls which can be queried for.""" - self.index = sum(nc.bitmask for nc in partition_nonnulls) - nonnull_columns = {nc.column.name for nc in partition_nonnulls} - self.included_numeric: list[Column] = [] - self.included_choice: dict[int, str] = {} - self.group_by_clause = "" - self.constant_clauses = "" - self.constants = "" - self.excluded: dict[str, str] = {} - self.predicates: list[str] = [] - self.nones: dict[int, None] = {} - for col_index, column in enumerate(columns): - col_name = column.name - if col_name in nonnull_columns or not column.nullable: - if is_numeric(column): - self.included_numeric.append(column) - else: - index = len(self.included_numeric) + len(self.included_choice) - self.included_choice[index] = col_name - if self.group_by_clause: - self.group_by_clause += ", " + col_name - else: - self.group_by_clause = " GROUP BY " + col_name - self.constant_clauses += f", _q.{col_name} AS k{index}" - self.constants += ", " + col_name - else: - self.excluded[col_name] = f"{col_name} IS NULL" - self.predicates.append(f"{col_name} IS NULL") - self.nones[col_index] = None - - -class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): - """Produces null partitioned generators, for complex interdependent data.""" - - SAMPLE_COUNT = MAXIMUM_CHOICES - SUPPRESS_COUNT = 5 - EMPTY_RESULT = [ - RowMapping( - parent=sqlalchemy.engine.result.SimpleResultMetaData(["count"]), - processors=None, - key_to_index={"count": 0}, - data=(0,), - ) - ] - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "grouped_multivariate_normal" - - def query_predicate(self, column: Column) -> str: - """Get a SQL expression that is true when ``column`` is available for analysis.""" - if is_numeric(column): - # x <> x + 1 ensures that x is not infinity or NaN - return f"COALESCE({column.name} <> {column.name} + 1, FALSE)" - return f"{column.name} IS NOT NULL" - - def query_var(self, column: str) -> str: - 
"""Return the expression we are querying for in this column.""" - return column - - def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]: - """Get a list of nullable columns together with bitmasks.""" - out: list[NullableColumn] = [] - for col in columns: - if col.nullable: - out.append( - NullableColumn( - column=col, - bitmask=2 ** len(out), - ) - ) - return out - - def get_partition_count_query( - self, ncs: list[NullableColumn], table: str, where: str | None = None - ) -> str: - """ - Get a SQL expression returning columns ``count`` and ``index``. - - Each row returned represents one of the null pattern partitions. - ``index`` is the bitmask of all those nullable columns that are not null for - this partition, and ``count`` is the total number of rows in this partition. - """ - index_exp = " + ".join( - f"CASE WHEN {self.query_predicate(nc.column)} THEN {nc.bitmask} ELSE 0 END" - for nc in ncs - ) - if where is None: - return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' - return ( - 'SELECT count, "index" FROM (SELECT COUNT(*) AS count,' - f' {index_exp} AS "index"' - f' FROM {table} GROUP BY "index") AS _q {where}' - ) - - def get_generators( - self, columns: list[Column], engine: Engine - ) -> Sequence[Generator]: - """Get any appropriate generators for these columns.""" - if len(columns) < 2: - return [] - nullable_columns = self.get_nullable_columns(columns) - if not nullable_columns: - return [] - table = columns[0].table.name - query_name = f"{table}__{columns[0].name}" - # Partitions for minimal suppression and no sampling - row_partitions_maximal: dict[int, RowPartition] = {} - # Partitions for normal suppression and severe sampling - row_partitions_ss: dict[int, RowPartition] = {} - for partition_nonnulls in powerset(nullable_columns): - partition_def = NullPatternPartition(columns, partition_nonnulls) - query = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - ) - row_partitions_maximal[partition_def.index] = RowPartition( - query, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - [], - ) - query = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - suppress_count=self.SUPPRESS_COUNT, - sample_count=self.SAMPLE_COUNT, - ) - row_partitions_ss[partition_def.index] = RowPartition( - query, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - [], - ) - gens: list[Generator] = [] - try: - with engine.connect() as connection: - partition_query_max = self.get_partition_count_query( - nullable_columns, table - ) - partition_count_max_results = ( - connection.execute(text(partition_query_max)).mappings().fetchall() - ) - count_comment = ( - "Number of rows for each combination of the columns" - f" { {nc.column.name for nc in nullable_columns} }" - f" of the table {table} being null" - ) - if self._execute_partition_queries(connection, row_partitions_maximal): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_maximal, - self.function_name(), - 
partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - ) - ) - partition_query_ss = self.get_partition_count_query( - nullable_columns, - table, - where=f"WHERE {self.SUPPRESS_COUNT} < count", - ) - partition_count_ss_results = ( - connection.execute(text(partition_query_ss)).mappings().fetchall() - ) - if self._execute_partition_queries(connection, row_partitions_ss): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_ss, - self.function_name(), - name_suffix="sampled and suppressed", - partition_count_query=partition_query_ss, - partition_counts=partition_count_ss_results, - partition_count_comment=count_comment, - ) - ) - except sqlalchemy.exc.DatabaseError as exc: - logger.debug("SQL query failed with error %s [%s]", exc, exc.statement) - return [] - return gens - - def _execute_partition_queries( - self, - connection: Connection, - partitions: dict[int, RowPartition], - ) -> bool: - """ - Execute the query in each partition, filling in the covariates. - - :return: True if all the partitions work, False if any of them fail. - """ - found_nonzero = False - for rp in partitions.values(): - covs = connection.execute(text(rp.query)).mappings().fetchall() - if not covs or covs.count == 0 or covs[0]["count"] is None: - rp.covariates = self.EMPTY_RESULT - else: - rp.covariates = covs - found_nonzero = True - return found_nonzero - - -class NullPartitionedLogNormalGeneratorFactory(NullPartitionedNormalGeneratorFactory): - """ - A generator for numeric and non-numeric columns. - - Any values could be null, the distributions of the nonnull numeric columns - depend on each other and the other non-numeric column values. - """ - - def function_name(self) -> str: - """Get the name of the generator function to call.""" - return "grouped_multivariate_lognormal" - - def query_predicate(self, column: Column) -> str: - """Get the SQL expression testing if the value in this column should be used.""" - if is_numeric(column): - # x <> x + 1 ensures that x is not infinity or NaN - return f"COALESCE({column.name} <> {column.name} + 1 AND 0 < {column.name}, FALSE)" - return f"{column.name} IS NOT NULL" - - def query_var(self, column: str) -> str: - """Get the variable or expression we are querying for this column.""" - return f"LN({column})" - - -@lru_cache(1) -def everything_factory() -> GeneratorFactory: - """Get a factory that encapsulates all the other factories.""" - return MultiGeneratorFactory( - [ - MimesisStringGeneratorFactory(), - MimesisIntegerGeneratorFactory(), - MimesisFloatGeneratorFactory(), - MimesisDateGeneratorFactory(), - MimesisDateTimeGeneratorFactory(), - MimesisTimeGeneratorFactory(), - ContinuousDistributionGeneratorFactory(), - ContinuousLogDistributionGeneratorFactory(), - ChoiceGeneratorFactory(), - ConstantGeneratorFactory(), - MultivariateNormalGeneratorFactory(), - MultivariateLogNormalGeneratorFactory(), - NullPartitionedNormalGeneratorFactory(), - NullPartitionedLogNormalGeneratorFactory(), - ] - ) diff --git a/datafaker/generators/__init__.py b/datafaker/generators/__init__.py new file mode 100644 index 0000000..c08d120 --- /dev/null +++ b/datafaker/generators/__init__.py @@ -0,0 +1,53 @@ +"""Generators write generator function definitions and queries into config.yaml.""" + +from functools import lru_cache + +from datafaker.generators.base import ( + ConstantGeneratorFactory, + GeneratorFactory, + MultiGeneratorFactory, +) +from datafaker.generators.choice import ( + 
ChoiceGeneratorFactory,
+)
+from datafaker.generators.continuous import (
+    ContinuousDistributionGeneratorFactory,
+    ContinuousLogDistributionGeneratorFactory,
+    MultivariateNormalGeneratorFactory,
+    MultivariateLogNormalGeneratorFactory,
+)
+from datafaker.generators.mimesis import (
+    MimesisStringGeneratorFactory,
+    MimesisIntegerGeneratorFactory,
+    MimesisFloatGeneratorFactory,
+    MimesisDateGeneratorFactory,
+    MimesisDateTimeGeneratorFactory,
+    MimesisTimeGeneratorFactory,
+)
+from datafaker.generators.partitioned import (
+    NullPartitionedNormalGeneratorFactory,
+    NullPartitionedLogNormalGeneratorFactory,
+)
+
+
+@lru_cache(1)
+def everything_factory() -> GeneratorFactory:
+    """Get a factory that encapsulates all the other factories."""
+    return MultiGeneratorFactory(
+        [
+            MimesisStringGeneratorFactory(),
+            MimesisIntegerGeneratorFactory(),
+            MimesisFloatGeneratorFactory(),
+            MimesisDateGeneratorFactory(),
+            MimesisDateTimeGeneratorFactory(),
+            MimesisTimeGeneratorFactory(),
+            ContinuousDistributionGeneratorFactory(),
+            ContinuousLogDistributionGeneratorFactory(),
+            ChoiceGeneratorFactory(),
+            ConstantGeneratorFactory(),
+            MultivariateNormalGeneratorFactory(),
+            MultivariateLogNormalGeneratorFactory(),
+            NullPartitionedNormalGeneratorFactory(),
+            NullPartitionedLogNormalGeneratorFactory(),
+        ]
+    )
diff --git a/datafaker/generators/base.py b/datafaker/generators/base.py
new file mode 100644
index 0000000..1adcb9c
--- /dev/null
+++ b/datafaker/generators/base.py
@@ -0,0 +1,417 @@
+"""Basic Generators and factories."""
+
+import math
+import re
+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+from typing import Any, Sequence, Union
+
+import mimesis
+import mimesis.locales
+import sqlalchemy
+from sqlalchemy import Column, Engine, text
+from sqlalchemy.types import Integer, Numeric, String, TypeEngine
+from typing_extensions import Self
+
+from datafaker.base import DistributionGenerator
+from datafaker.utils import T, logger
+
+NumericType = Union[int, float]
+
+
+dist_gen = DistributionGenerator()
+generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB)
+
+
+class Generator(ABC):
+    """
+    Random data generator.
+
+    A generator is specific to a particular column in a particular table in
+    a particular database.
+
+    A generator knows how to fetch its summary data from the database, how to calculate
+    its fit (if appropriate) and which function actually does the generation.
+
+    It also knows these summary statistics for the column it was instantiated on,
+    and therefore knows how to generate fake data for that column.
+    """
+
+    @abstractmethod
+    def function_name(self) -> str:
+        """Get the name of the generator function to put into df.py."""
+
+    def name(self) -> str:
+        """
+        Get the name of the generator.
+
+        Usually the same as the function name, but can be different to distinguish
+        between generators that have the same function but different queries.
+        """
+        return self.function_name()
+
+    @abstractmethod
+    def nominal_kwargs(self) -> dict[str, str]:
+        """
+        Get the kwargs the generator wants to be called with.
+
+        The values will tend to be references to something in the src-stats.yaml
+        file.
+        For example {"avg_age": 'SRC_STATS["auto__patient"]["results"][0]["age_mean"]'} will
+        provide the value stored in src-stats.yaml as
+        SRC_STATS["auto__patient"]["results"][0]["age_mean"] as the "avg_age" argument
+        to the generator function.
+        """
+
+    def select_aggregate_clauses(self) -> dict[str, dict[str, str]]:
+        """
+        Get the SQL clauses to add to a SELECT ... FROM {table} query.
+
+        The results will be added to SRC_STATS["auto__{table}"].
+        For example {
+            "count": {
+                "clause": "COUNT(*)",
+                "comment": "number of rows in table {table}"
+            }, "avg_thiscolumn": {
+                "clause": "AVG(thiscolumn)",
+                "comment": "Average value of thiscolumn in table {table}"
+            }}
+        will make the query become:
+        "SELECT COUNT(*) AS count, AVG(thiscolumn) AS avg_thiscolumn FROM thistable"
+        and this will populate SRC_STATS["auto__thistable"]["results"][0]["count"] and
+        SRC_STATS["auto__thistable"]["results"][0]["avg_thiscolumn"] in the src-stats.yaml file.
+        """
+        return {}
+
+    def custom_queries(self) -> dict[str, dict[str, str]]:
+        """
+        Get the SQL queries to add to SRC_STATS.
+
+        Should be used for queries that do not follow the plain aggregate
+        SELECT ... FROM table format; queries that do follow it should use
+        select_aggregate_clauses instead.
+
+        For example {"myquery": {
+            "query": "SELECT one, too AS two FROM mytable WHERE too > 1",
+            "comment": "big enough one and two from table mytable"
+        }}
+        will populate SRC_STATS["myquery"]["results"][0]["one"]
+        and SRC_STATS["myquery"]["results"][0]["two"]
+        in the src-stats.yaml file.
+
+        Keys should be chosen to minimize the chances of clashing with other queries,
+        for example "auto__{table}__{column}__{queryname}"
+        """
+        return {}
+
+    @abstractmethod
+    def actual_kwargs(self) -> dict[str, Any]:
+        """
+        Get the kwargs (summary statistics) this generator is instantiated with.
+
+        This must match `nominal_kwargs` in structure.
+        """
+
+    @abstractmethod
+    def generate_data(self, count: int) -> list[Any]:
+        """Generate ``count`` random data points for this column."""
+
+    def fit(self, default: float = -1) -> float:
+        """
+        Return a value representing how well the distribution fits the real source data.
+
+        0.0 means "perfectly".
+        Returns default if no fitness has been defined.
+        """
+        return default
+
+
+class PredefinedGenerator(Generator):
+    """Generator built from an existing config.yaml."""
+
+    SELECT_AGGREGATE_RE = re.compile(r"SELECT (.*) FROM ([A-Za-z_][A-Za-z0-9_]*)")
+    AS_CLAUSE_RE = re.compile(r" *(.+) +AS +([A-Za-z_][A-Za-z0-9_]*) *")
+    SRC_STAT_NAME_RE = re.compile(r'\bSRC_STATS\["([^]]*)"\].*')
+
+    def _get_src_stats_mentioned(self, val: Any) -> set[str]:
+        if not val:
+            return set()
+        if isinstance(val, str):
+            ss = self.SRC_STAT_NAME_RE.match(val)
+            if ss:
+                ss_name = ss.group(1)
+                logger.debug("Found SRC_STATS reference %s", ss_name)
+                return set([ss_name])
+            logger.debug("Value %s does not seem to be a SRC_STATS reference", val)
+            return set()
+        if isinstance(val, list):
+            return set.union(*(self._get_src_stats_mentioned(v) for v in val))
+        if isinstance(val, dict):
+            return set.union(*(self._get_src_stats_mentioned(v) for v in val.values()))
+        return set()
+
+    def __init__(
+        self,
+        table_name: str,
+        generator_object: Mapping[str, Any],
+        config: Mapping[str, Any],
+    ):
+        """
+        Initialise a generator from a config.yaml.
+
+        :param config: The entire configuration.
+ :param generator_object: The part of the configuration at tables.*.row_generators + """ + logger.debug( + "Creating a PredefinedGenerator %s from table %s", + generator_object["name"], + table_name, + ) + self._table_name = table_name + self._name: str = generator_object["name"] + self._kwn: dict[str, str] = generator_object.get("kwargs", {}) + self._src_stats_mentioned = self._get_src_stats_mentioned(self._kwn) + # Need to deal with this somehow (or remove it from the schema) + self._argn: list[str] = generator_object.get("args", []) + self._select_aggregate_clauses: dict[str, dict[str, str | Any]] = {} + self._custom_queries = {} + for sstat in config.get("src-stats", []): + name: str = sstat["name"] + dpq = sstat.get("dp-query", None) + query = sstat.get( + "query", dpq + ) # ... should we really be combining query and dp-query? + comments = sstat.get("comments", []) + if name in self._src_stats_mentioned: + logger.debug("Found a src-stats entry for %s", name) + # This query is one that this generator is interested in + sam = None if query is None else self.SELECT_AGGREGATE_RE.match(query) + # sam.group(2) is the table name from the FROM clause of the query + if sam and name == f"auto__{sam.group(2)}": + # name is auto__{table_name}, so it's a select_aggregate, + # so we split up its clauses + sacs = [ + self.AS_CLAUSE_RE.match(clause) + for clause in sam.group(1).split(",") + ] + # Work out what select_aggregate_clauses this represents + for sac in sacs: + if sac is not None: + comment = comments.pop() if comments else None + self._select_aggregate_clauses[sac.group(2)] = { + "clause": sac.group(1), + "comment": comment, + } + else: + # some other name, so must be a custom query + logger.debug("Custom query %s is '%s'", name, query) + self._custom_queries[name] = { + "query": query, + "comment": comments[0] if comments else None, + } + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return self._name + + def nominal_kwargs(self) -> dict[str, str]: + """Get the arguments to be entered into ``config.yaml``.""" + return self._kwn + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" + return self._select_aggregate_clauses + + def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" + return self._custom_queries + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + # Run the queries from nominal_kwargs + # ... + logger.error("PredefinedGenerator.actual_kwargs not implemented yet") + return {} + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + # Call the function if we can. This could be tricky... + # ... 
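+        # A possible sketch (hypothetical, since actual_kwargs is also still a
+        # stub): resolve the configured function name and call it repeatedly,
+        # e.g. for a "dist_gen.*" name:
+        #     fn = getattr(dist_gen, self._name.removeprefix("dist_gen."))
+        #     return [fn(**self.actual_kwargs()) for _ in range(count)]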
+        logger.error("PredefinedGenerator.generate_data not implemented yet")
+        return []
+
+
+class GeneratorFactory(ABC):
+    """A factory for making generators appropriate for a database column."""
+
+    @abstractmethod
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+
+
+def fit_from_buckets(xs: Sequence[NumericType], ys: Sequence[NumericType]) -> float:
+    """Calculate the fit by comparing a pair of lists of buckets."""
+    sum_diff_squared = sum(map(lambda t, a: (t - a) * (t - a), xs, ys))
+    count = len(ys)
+    return sum_diff_squared / (count * count)
+
+
+class Buckets:
+    """
+    Measured buckets for a real distribution.
+
+    Finds the real distribution of continuous data so that we can measure
+    the fit of generators against it.
+    """
+
+    def __init__(
+        self,
+        engine: Engine,
+        table_name: str,
+        column_name: str,
+        mean: float,
+        stddev: float,
+        count: int,
+    ):
+        """Initialise a Buckets object."""
+        with engine.connect() as connection:
+            raw_buckets = connection.execute(
+                text(
+                    f"SELECT COUNT({column_name}) AS f,"
+                    f" FLOOR(({column_name} - {mean - 2 * stddev})/{stddev / 2}) AS b"
+                    f" FROM {table_name} GROUP BY b"
+                )
+            )
+            # The proportion of rows falling into each of the ten buckets
+            self.buckets: list[float] = [0.0] * 10
+            for rb in raw_buckets:
+                if rb.b is not None:
+                    bucket = min(9, max(0, int(rb.b) + 1))
+                    self.buckets[bucket] += rb.f / count
+        self.mean = mean
+        self.stddev = stddev
+
+    @classmethod
+    def make_buckets(
+        cls, engine: Engine, table_name: str, column_name: str
+    ) -> Self | None:
+        """
+        Construct a Buckets object.
+
+        Calculates the mean and standard deviation of the values in the column
+        specified and makes ten buckets, centered on the mean and each half
+        a standard deviation wide (except for the end two that extend to
+        infinity). Each bucket will be set to the count of the number of values
+        in the column within that bucket.
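+
+        For example (illustrative numbers only): with mean 10 and standard
+        deviation 2, the finite bucket boundaries fall at 6, 7, ..., 14, so a
+        value of 9.5 lands in the bucket covering [9, 10).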
+        """
+        with engine.connect() as connection:
+            result = connection.execute(
+                text(
+                    f"SELECT AVG({column_name}) AS mean,"
+                    f" STDDEV({column_name}) AS stddev,"
+                    f" COUNT({column_name}) AS count FROM {table_name}"
+                )
+            ).first()
+        if result is None or result.stddev is None or getattr(result, "count") < 2:
+            return None
+        try:
+            buckets = cls(
+                engine,
+                table_name,
+                column_name,
+                result.mean,
+                result.stddev,
+                getattr(result, "count"),
+            )
+        except sqlalchemy.exc.DatabaseError as exc:
+            logger.debug("Failed to instantiate Buckets object: %s", exc)
+            return None
+        return buckets
+
+    def fit_from_counts(self, bucket_counts: Sequence[float]) -> float:
+        """Figure out the fit from bucket counts from the generator distribution."""
+        return fit_from_buckets(self.buckets, bucket_counts)
+
+    def fit_from_values(self, values: list[float]) -> float:
+        """Figure out the fit from samples from the generator distribution."""
+        buckets = [0] * 10
+        x = self.mean - 2 * self.stddev
+        w = self.stddev / 2
+        for v in values:
+            # Use the same bucketing as __init__: bucket 0 holds everything
+            # below mean - 2*stddev, hence FLOOR + 1, clamped to [0, 9].
+            b = min(9, max(0, math.floor((v - x) / w) + 1))
+            buckets[b] += 1
+        return self.fit_from_counts(buckets)
+
+
+class MultiGeneratorFactory(GeneratorFactory):
+    """A composite factory."""
+
+    def __init__(self, factories: list[GeneratorFactory]):
+        """Initialise a MultiGeneratorFactory."""
+        super().__init__()
+        self.factories = factories
+
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+        return [
+            generator
+            for factory in self.factories
+            for generator in factory.get_generators(columns, engine)
+        ]
+
+
+def get_column_type(column: Column) -> TypeEngine:
+    """Get the type of the column, generic if possible."""
+    try:
+        return column.type.as_generic()
+    except NotImplementedError:
+        return column.type
+
+
+class ConstantGenerator(Generator):
+    """Generator that always produces the same value."""
+
+    def __init__(self, value: Any) -> None:
+        """Initialise the ConstantGenerator."""
+        super().__init__()
+        self.value = value
+        self.repr = repr(value)
+
+    def function_name(self) -> str:
+        """Get the name of the generator function to call."""
+        return "dist_gen.constant"
+
+    def nominal_kwargs(self) -> dict[str, str]:
+        """Get the arguments to be entered into ``config.yaml``."""
+        return {"value": self.repr}
+
+    def actual_kwargs(self) -> dict[str, Any]:
+        """Get the kwargs (summary statistics) this generator was instantiated with."""
+        return {"value": self.value}
+
+    def generate_data(self, count: int) -> list[Any]:
+        """Generate ``count`` random data points for this column."""
+        return [self.value for _ in range(count)]
+
+
+class ConstantGeneratorFactory(GeneratorFactory):
+    """Just the null generator."""
+
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate for these columns."""
+        if len(columns) != 1:
+            return []
+        column = columns[0]
+        if column.nullable:
+            return [ConstantGenerator(None)]
+        c_type = get_column_type(column)
+        if isinstance(c_type, String):
+            return [ConstantGenerator("")]
+        if isinstance(c_type, Numeric):
+            return [ConstantGenerator(0.0)]
+        if isinstance(c_type, Integer):
+            return [ConstantGenerator(0)]
+        return []
diff --git a/datafaker/generators/choice.py b/datafaker/generators/choice.py
new file mode 100644
index 0000000..61147a1
--- /dev/null
+++ b/datafaker/generators/choice.py
@@ -0,0 +1,398 @@
+"""Generator factories for making generators for choices of values."""
+
+import decimal
+import math +import typing +from abc import abstractmethod +from typing import Any, Sequence, Union + +from datafaker.generators.base import ( + Generator, + GeneratorFactory, + dist_gen, + fit_from_buckets, +) +from sqlalchemy import Column, CursorResult, Engine, text + +NumericType = Union[int, float] + +# How many distinct values can we have before we consider a +# choice distribution to be infeasible? +MAXIMUM_CHOICES = 500 + + +def zipf_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: + """ + Get a zipf distribution for a certain number of items. + + :param total: The total number of items to be distributed. + :param bins: The total number of bins to distribute the items into. + :return: A generator of the number of items in each bin, from the + largest to the smallest. + """ + basic_dist = list(map(lambda n: 1 / n, range(1, bins + 1))) + bd_remaining = sum(basic_dist) + for b in basic_dist: + # yield b/bd_remaining of the `total` remaining + if bd_remaining == 0: + yield 0 + else: + x = math.floor(0.5 + total * b / bd_remaining) + bd_remaining -= x * bd_remaining / total + total -= x + yield x + + +class ChoiceGenerator(Generator): + """Base generator for all generators producing choices of items.""" + + STORE_COUNTS = False + + def __init__( + self, + table_name: str, + column_name: str, + values: list[Any], + counts: list[int], + sample_count: int | None = None, + suppress_count: int = 0, + ) -> None: + """Initialise a ChoiceGenerator.""" + super().__init__() + self.table_name = table_name + self.column_name = column_name + self.values = values + estimated_counts = self.get_estimated_counts(counts) + self._fit = fit_from_buckets(counts, estimated_counts) + + extra_results = "" + extra_expo = "" + extra_comment = "" + if self.STORE_COUNTS: + extra_results = f", COUNT({column_name}) AS count" + extra_expo = ", count" + extra_comment = " and their counts" + if suppress_count == 0: + if sample_count is None: + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL GROUP BY value" + f" ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name}" + ) + self._annotation = None + else: + self._query = ( + f"SELECT {column_name} AS value{extra_results} FROM" + f" (SELECT {column_name} FROM {table_name}" + f" WHERE {column_name} IS NOT NULL" + f" ORDER BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY COUNT({column_name}) DESC" + ) + self._comment = ( + f"The values{extra_comment} that appear in column {column_name}" + f" of a random sample of {sample_count} rows of table {table_name}" + ) + self._annotation = "sampled" + else: + if sample_count is None: + self._query = ( + f"SELECT value{extra_expo} FROM" + f" (SELECT {column_name} AS value, COUNT({column_name}) AS count" + f" FROM {table_name} WHERE {column_name} IS NOT NULL" + f" GROUP BY value ORDER BY count DESC) AS _inner" + f" WHERE {suppress_count} < count" + ) + self._comment = ( + f"All the values{extra_comment} that appear in column {column_name}" + f" of table {table_name} more than {suppress_count} times" + ) + self._annotation = "suppressed" + else: + self._query = ( + f"SELECT value{extra_expo} FROM (SELECT value, COUNT(value) AS count FROM" + f" (SELECT {column_name} AS value FROM {table_name}" + f" WHERE {column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {sample_count})" + f" AS _inner GROUP BY value ORDER BY count 
DESC)" + f" AS _inner WHERE {suppress_count} < count" + ) + self._comment = ( + f"The values{extra_comment} that appear more than {suppress_count} times" + f" in column {column_name}, out of a random sample of {sample_count} rows" + f" of table {table_name}" + ) + self._annotation = "sampled and suppressed" + + @abstractmethod + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "a": f'SRC_STATS["auto__{self.table_name}__{self.column_name}"]["results"]', + } + + def name(self) -> str: + """Get the name of the generator.""" + n = super().name() + if self._annotation is None: + return n + return f"{n} [{self._annotation}]" + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "a": self.values, + } + + def custom_queries(self) -> dict[str, dict[str, str]]: + """Get the queries the generators need to call.""" + qs = super().custom_queries() + return { + **qs, + f"auto__{self.table_name}__{self.column_name}": { + "query": self._query, + "comment": self._comment, + }, + } + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default if self._fit is None else self._fit + + +class ZipfChoiceGenerator(ChoiceGenerator): + """Generator producing items in a Zipf distribution.""" + + def get_estimated_counts(self, counts: list[int]) -> list[int]: + """Get the counts that we would expect if this distribution was the correct one.""" + return list(zipf_distribution(sum(counts), len(counts))) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.zipf_choice" + + def generate_data(self, count: int) -> list[float]: + """Generate ``count`` random data points for this column.""" + return [ + dist_gen.zipf_choice_direct(self.values, len(self.values)) + for _ in range(count) + ] + + +def uniform_distribution(total: int, bins: int) -> typing.Generator[int, None, None]: + """ + Construct a distribution putting ``total`` items uniformly into ``bins`` bins. + + If they don't fit exactly evenly, the earlier bins will have one more + item than the later bins so the total is as required. 
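+
+    For example (illustrative): ``list(uniform_distribution(10, 4))`` gives
+    ``[3, 3, 2, 2]``.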
+    """
+    p = total // bins
+    n = total % bins
+    for _ in range(0, n):
+        yield p + 1
+    for _ in range(n, bins):
+        yield p
+
+
+class UniformChoiceGenerator(ChoiceGenerator):
+    """A generator producing each value with roughly equal frequency."""
+
+    def get_estimated_counts(self, counts: list[int]) -> list[int]:
+        """Get the counts that we would expect if this distribution was the correct one."""
+        return list(uniform_distribution(sum(counts), len(counts)))
+
+    def function_name(self) -> str:
+        """Get the name of the generator function to call."""
+        return "dist_gen.choice"
+
+    def generate_data(self, count: int) -> list[Any]:
+        """Generate ``count`` random data points for this column."""
+        return [dist_gen.choice_direct(self.values) for _ in range(count)]
+
+
+class WeightedChoiceGenerator(ChoiceGenerator):
+    """Choice generator that matches the source data's frequency."""
+
+    STORE_COUNTS = True
+
+    def get_estimated_counts(self, counts: list[int]) -> list[int]:
+        """Get the counts that we would expect if this distribution was the correct one."""
+        return counts
+
+    def function_name(self) -> str:
+        """Get the name of the generator function to call."""
+        return "dist_gen.weighted_choice"
+
+    def generate_data(self, count: int) -> list[Any]:
+        """Generate ``count`` random data points for this column."""
+        return [dist_gen.weighted_choice(self.values) for _ in range(count)]
+
+
+class ValueGatherer:
+    """
+    Gathers values from a query of values and counts.
+
+    The query must return columns ``v`` for a value and ``f`` for the
+    count of how many of those values there are.
+    These values will be gathered into a number of properties:
+    ``values``: the list of ``v`` values, ``counts``: the list of ``f`` counts
+    in the same order as ``v``, ``cvs``: list of dicts with keys ``value`` and
+    ``count`` giving these values and counts. ``counts_not_suppressed``,
+    ``values_not_suppressed`` and ``cvs_not_suppressed`` are the
+    equivalents with the counts less than or equal to ``suppress_count``
+    removed.
+
+    :param suppress_count: values with a count of this or fewer are excluded
+        from the ``*_not_suppressed`` lists.
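+
+    For example (illustrative): rows ``("a", 7)`` and ``("b", 2)`` with
+    ``suppress_count=5`` give ``values == ["a", "b"]``, ``counts == [7, 2]``,
+    ``cvs == [{"value": "a", "count": 7}, {"value": "b", "count": 2}]`` and
+    ``values_not_suppressed == ["a"]``.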
+ """ + + def __init__(self, results: CursorResult, suppress_count: int = 0) -> None: + """Initialise a ValueGatherer.""" + values = [] # All values found + counts = [] # The number or each value + cvs: list[dict[str, Any]] = [] # list of dicts with keys "v" and "count" + values_not_suppressed = [] # All values found more than SUPPRESS_COUNT times + counts_not_suppressed = [] # The number for each value not suppressed + cvs_not_suppressed: list[ + dict[str, Any] + ] = [] # list of dicts with keys "v" and "count" + for result in results: + c = result.f + if c != 0: + counts.append(c) + v = result.v + if isinstance(v, decimal.Decimal): + v = float(v) + values.append(v) + cvs.append({"value": v, "count": c}) + if suppress_count < c: + counts_not_suppressed.append(c) + v = result.v + if isinstance(v, decimal.Decimal): + v = float(v) + values_not_suppressed.append(v) + cvs_not_suppressed.append({"value": v, "count": c}) + self.values = values + self.counts = counts + self.cvs = cvs + self.values_not_suppressed = values_not_suppressed + self.counts_not_suppressed = counts_not_suppressed + self.cvs_not_suppressed = cvs_not_suppressed + + +class ChoiceGeneratorFactory(GeneratorFactory): + """All generators that want an average and standard deviation.""" + + SAMPLE_COUNT = MAXIMUM_CHOICES + SUPPRESS_COUNT = 5 + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + column_name = column.name + table_name = column.table.name + generators = [] + with engine.connect() as connection: + results = connection.execute( + text( + f"SELECT {column_name} AS v, COUNT({column_name})" + f" AS f FROM {table_name} GROUP BY v" + f" ORDER BY f DESC LIMIT {MAXIMUM_CHOICES + 1}" + ) + ) + if results is not None and results.rowcount <= MAXIMUM_CHOICES: + vg = ValueGatherer(results, self.SUPPRESS_COUNT) + if vg.counts: + generators += [ + ZipfChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + UniformChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + WeightedChoiceGenerator( + table_name, column_name, vg.cvs, vg.counts + ), + ] + results = connection.execute( + text( + f"SELECT v, COUNT(v) AS f FROM" + f" (SELECT {column_name} as v FROM {table_name}" + f" ORDER BY RANDOM() LIMIT {self.SAMPLE_COUNT})" + f" AS _inner GROUP BY v ORDER BY f DESC" + ) + ) + if results is not None: + vg = ValueGatherer(results, self.SUPPRESS_COUNT) + if vg.counts: + generators += [ + ZipfChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + UniformChoiceGenerator( + table_name, column_name, vg.values, vg.counts + ), + WeightedChoiceGenerator( + table_name, column_name, vg.cvs, vg.counts + ), + ] + generators += [ + ZipfChoiceGenerator( + table_name, + column_name, + vg.values, + vg.counts, + sample_count=self.SAMPLE_COUNT, + ), + UniformChoiceGenerator( + table_name, + column_name, + vg.values, + vg.counts, + sample_count=self.SAMPLE_COUNT, + ), + WeightedChoiceGenerator( + table_name, + column_name, + vg.cvs, + vg.counts, + sample_count=self.SAMPLE_COUNT, + ), + ] + if vg.counts_not_suppressed: + generators += [ + ZipfChoiceGenerator( + table_name, + column_name, + vg.values_not_suppressed, + vg.counts_not_suppressed, + sample_count=self.SAMPLE_COUNT, + suppress_count=self.SUPPRESS_COUNT, + ), + UniformChoiceGenerator( + table_name, + column_name, + vg.values_not_suppressed, + vg.counts_not_suppressed, + 
sample_count=self.SAMPLE_COUNT, + suppress_count=self.SUPPRESS_COUNT, + ), + WeightedChoiceGenerator( + table_name=table_name, + column_name=column_name, + values=vg.cvs_not_suppressed, + counts=vg.counts_not_suppressed, + sample_count=self.SAMPLE_COUNT, + suppress_count=self.SUPPRESS_COUNT, + ), + ] + return generators diff --git a/datafaker/generators/continuous.py b/datafaker/generators/continuous.py new file mode 100644 index 0000000..a84d965 --- /dev/null +++ b/datafaker/generators/continuous.py @@ -0,0 +1,471 @@ +"""Generator factories for making generators of continuous distributions.""" + +from typing import Any, Sequence + +from datafaker.generators.base import ( + Buckets, + Generator, + GeneratorFactory, + NumericType, + get_column_type, +) +from sqlalchemy import Column, Engine, RowMapping, text +from sqlalchemy.types import Integer, Numeric + +from datafaker.generators.base import dist_gen +from datafaker.utils import logger + + +class ContinuousDistributionGenerator(Generator): + """Base class for generators producing continuous distributions.""" + + expected_buckets: Sequence[NumericType] = [] + + def __init__(self, table_name: str, column_name: str, buckets: Buckets): + """Initialise a ContinuousDistributionGenerator.""" + super().__init__() + self.table_name = table_name + self.column_name = column_name + self.buckets = buckets + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "mean": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["mean__{self.column_name}"]' + ), + "sd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"]' + f'[0]["stddev__{self.column_name}"]' + ), + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + if self.buckets is None: + return {} + return { + "mean": self.buckets.mean, + "sd": self.buckets.stddev, + } + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" + clauses = super().select_aggregate_clauses() + return { + **clauses, + f"mean__{self.column_name}": { + "clause": f"AVG({self.column_name})", + "comment": f"Mean of {self.column_name} from table {self.table_name}", + }, + f"stddev__{self.column_name}": { + "clause": f"STDDEV({self.column_name})", + "comment": f"Standard deviation of {self.column_name} from table {self.table_name}", + }, + } + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + if self.buckets is None: + return default + return self.buckets.fit_from_counts(self.expected_buckets) + + +class GaussianGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a Gaussian (normal) distribution.""" + + expected_buckets = [ + 0.0227, + 0.0441, + 0.0918, + 0.1499, + 0.1915, + 0.1915, + 0.1499, + 0.0918, + 0.0441, + 0.0227, + ] + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.normal" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [ + dist_gen.normal(self.buckets.mean, self.buckets.stddev) + for _ in range(count) + ] + + +class UniformGenerator(ContinuousDistributionGenerator): + """Generator producing numbers in a uniform distribution.""" + + expected_buckets = [ + 0, + 0.06698, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.14434, + 0.06698, + 0, + ] + + def 
function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.uniform_ms" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [ + dist_gen.uniform_ms(self.buckets.mean, self.buckets.stddev) + for _ in range(count) + ] + + +class ContinuousDistributionGeneratorFactory(GeneratorFactory): + """All generators that want an average and standard deviation.""" + + def _get_generators_from_buckets( + self, + _engine: Engine, + table_name: str, + column_name: str, + buckets: Buckets, + ) -> Sequence[Generator]: + return [ + GaussianGenerator(table_name, column_name, buckets), + UniformGenerator(table_name, column_name, buckets), + ] + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, Numeric) and not isinstance(ct, Integer): + return [] + column_name = column.name + table_name = column.table.name + buckets = Buckets.make_buckets(engine, table_name, column_name) + if buckets is None: + return [] + return self._get_generators_from_buckets( + engine, table_name, column_name, buckets + ) + + +class LogNormalGenerator(Generator): + """Generator producing numbers in a log-normal distribution.""" + + # TODO: figure out the real buckets here (this was from a random sample in R) + expected_buckets = [ + 0, + 0, + 0, + 0.28627, + 0.40607, + 0.14937, + 0.06735, + 0.03492, + 0.01918, + 0.03684, + ] + + def __init__( + self, + table_name: str, + column_name: str, + buckets: Buckets, + logmean: float, + logstddev: float, + ): + """Initialise a LogNormalGenerator.""" + super().__init__() + self.table_name = table_name + self.column_name = column_name + self.buckets = buckets + self.logmean = logmean + self.logstddev = logstddev + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.lognormal" + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [dist_gen.lognormal(self.logmean, self.logstddev) for _ in range(count)] + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "logmean": ( + f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logmean__{self.column_name}"]' + ), + "logsd": ( + f'SRC_STATS["auto__{self.table_name}"]["results"][0]' + f'["logstddev__{self.column_name}"]' + ), + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "logmean": self.logmean, + "logsd": self.logstddev, + } + + def select_aggregate_clauses(self) -> dict[str, dict[str, str]]: + """Get the query fragments the generators need to call.""" + clauses = super().select_aggregate_clauses() + return { + **clauses, + f"logmean__{self.column_name}": { + "clause": ( + f"AVG(CASE WHEN 0<{self.column_name} THEN LN({self.column_name})" + " ELSE NULL END)" + ), + "comment": f"Mean of logs of {self.column_name} from table {self.table_name}", + }, + f"logstddev__{self.column_name}": { + "clause": ( + f"STDDEV(CASE WHEN 0<{self.column_name}" + f" THEN LN({self.column_name}) ELSE NULL END)" + ), + "comment": ( + f"Standard deviation of logs of {self.column_name}" + f" from table {self.table_name}" + ), + }, + } + + def fit(self, 
default: float = -1) -> float: + """Get this generator's fit against the real data.""" + if self.buckets is None: + return default + return self.buckets.fit_from_counts(self.expected_buckets) + + +class ContinuousLogDistributionGeneratorFactory(ContinuousDistributionGeneratorFactory): + """All generators that want an average and standard deviation of log data.""" + + def _get_generators_from_buckets( + self, + engine: Engine, + table_name: str, + column_name: str, + buckets: Buckets, + ) -> Sequence[Generator]: + with engine.connect() as connection: + result = connection.execute( + text( + f"SELECT AVG(CASE WHEN 0<{column_name} THEN LN({column_name})" + " ELSE NULL END) AS logmean," + f" STDDEV(CASE WHEN 0<{column_name} THEN LN({column_name}) ELSE NULL END)" + f" AS logstddev FROM {table_name}" + ) + ).first() + if result is None or result.logstddev is None: + return [] + return [ + LogNormalGenerator( + table_name, + column_name, + buckets, + float(result.logmean), + float(result.logstddev), + ) + ] + + +class MultivariateNormalGenerator(Generator): + """Generator of multiple values drawn from a multivariate normal distribution.""" + + def __init__( + self, + table_name: str, + column_names: list[str], + query: str, + covariates: RowMapping, + function_name: str, + ) -> None: + """Initialise a MultivariateNormalGenerator.""" + self._table = table_name + self._columns = column_names + self._query = query + self._covariates = covariates + self._function_name = function_name + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen." + self._function_name + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "cov": f'SRC_STATS["auto__cov__{self._table}"]["results"][0]', + } + + def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" + cols = ", ".join(self._columns) + return { + f"auto__cov__{self._table}": { + "comment": ( + f"Means and covariate matrix for the columns {cols}," + " so that we can produce the relatedness between these in the fake data." + ), + "query": self._query, + } + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return {"cov": self._covariates} + + def generate_data(self, count: int) -> list[Any]: + """Generate 'count' random data points for this column.""" + return [ + getattr(dist_gen, self._function_name)(self._covariates) + for _ in range(count) + ] + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default + + +class MultivariateNormalGeneratorFactory(GeneratorFactory): + """Normal distribution generator factory.""" + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "multivariate_normal" + + def query_predicate(self, column: Column) -> str: + """Get the SQL expression for whether this column should be queried.""" + return column.name + " IS NOT NULL" + + def query_var(self, column: str) -> str: + """Get the SQL expression of the value to query for this column.""" + return column + + def query( + self, + table: str, + columns: list[Column], + predicates: list[str] = [], + group_by_clause: str = "", + constant_clauses: str = "", + constants: str = "", + suppress_count: int = 1, + sample_count: int | None = None, + ) -> str: + """ + Get a query for the basics for multivariate normal/lognormal parameters. 
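+
+        For two columns ``a`` and ``b`` with no extra clauses, the generated
+        SQL has roughly this shape (illustrative only; ``mytable`` is a
+        placeholder and whitespace is added for readability):
+
+            SELECT 2 AS rank, _q.count AS count, _q.m0, _q.m1,
+                   (_q.s0_0 - _q.count * _q.m0 * _q.m0)/NULLIF(_q.count - 1, 0) AS c0_0,
+                   ... AS c0_1, ... AS c1_1
+            FROM (SELECT COUNT(*) AS count, SUM(a * a) AS s0_0, SUM(a * b) AS s0_1,
+                  SUM(b * b) AS s1_1, AVG(a) AS m0, AVG(b) AS m1
+                  FROM mytable WHERE a IS NOT NULL AND b IS NOT NULL) AS _q
+            WHERE 1 < _q.count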
+ + :param table: The name of the table to be queried. + :param columns: The columns in the multivariate distribution. + :param and_where: Additional where clause. If not ``""`` should begin with ``" AND "``. + :param group_by_clause: Any GROUP BY clause (starting with " GROUP BY " if not ""). + :param constant_clauses: Extra output columns in the outer SELECT clause, such + as ", _q.column_one AS k1, _q.column_two AS k2". Note the initial comma. + :param constants: Extra output columns in the inner SELECT clause. Used to + deliver columns to the outer select, such as ", column_one, column_two". + Note the initial comma. + :param suppress_count: a group smaller than this will be suppressed. + :param sample_count: this many samples will be taken from each partition. + """ + preds = [self.query_predicate(col) for col in columns] + predicates + where = " WHERE " + " AND ".join(preds) if preds else "" + avgs = "".join( + f", AVG({self.query_var(col.name)}) AS m{i}" + for i, col in enumerate(columns) + ) + multiples = "".join( + f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" + for iy, coly in enumerate(columns) + for ix, colx in enumerate(columns[: iy + 1]) + ) + means = "".join(f", _q.m{i}" for i in range(len(columns))) + covs = "".join( + ( + f", (_q.s{ix}_{iy} - _q.count * _q.m{ix} * _q.m{iy})" + f"/NULLIF(_q.count - 1, 0) AS c{ix}_{iy}" + ) + for iy in range(len(columns)) + for ix in range(iy + 1) + ) + if sample_count is None: + subquery = table + where + else: + subquery = ( + f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" + f" LIMIT {sample_count}) AS _sampled" + ) + # if there are any numeric columns we need at least# + # two rows to make any (co)variances at all + suppress_clause = f" WHERE {suppress_count} < _q.count" if columns else "" + return ( + f"SELECT {len(columns)} AS rank{constant_clauses}, _q.count AS count{means}{covs}" + f" FROM (SELECT COUNT(*) AS count{multiples}{avgs}{constants}" + f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" + ) + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators for these columns.""" + # For the case of one column we'll use GaussianGenerator + if len(columns) < 2: + return [] + # All columns must be numeric + for c in columns: + ct = get_column_type(c) + if not isinstance(ct, Numeric) and not isinstance(ct, Integer): + return [] + column_names = [c.name for c in columns] + table = columns[0].table.name + query = self.query(table, columns) + with engine.connect() as connection: + try: + covariates = connection.execute(text(query)).mappings().first() + except Exception as e: + logger.debug("SQL query %s failed with error %s", query, e) + return [] + if not covariates or covariates["c0_0"] is None: + return [] + return [ + MultivariateNormalGenerator( + table, + column_names, + query, + covariates, + self.function_name(), + ) + ] + + +class MultivariateLogNormalGeneratorFactory(MultivariateNormalGeneratorFactory): + """Multivariate lognormal generator factory.""" + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "multivariate_lognormal" + + def query_predicate(self, column: Column) -> str: + """Get the SQL expression for whether this column should be queried.""" + return f"COALESCE(0 < {column.name}, FALSE)" + + def query_var(self, column: str) -> str: + """Get the expression to query for, for this column.""" + return f"LN({column})" diff --git a/datafaker/generators/mimesis.py 
b/datafaker/generators/mimesis.py
new file mode 100644
index 0000000..8d031ab
--- /dev/null
+++ b/datafaker/generators/mimesis.py
@@ -0,0 +1,418 @@
+"""Generators using Mimesis."""
+
+from typing import Any, Callable, Sequence, Union
+
+import mimesis
+import mimesis.locales
+from datafaker.generators.base import (
+    Buckets,
+    Generator,
+    GeneratorFactory,
+    get_column_type,
+)
+from sqlalchemy import Column, Engine, text
+from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time
+
+from datafaker.base import DistributionGenerator
+
+NumericType = Union[int, float]
+
+# How many distinct values can we have before we consider a
+# choice distribution to be infeasible?
+MAXIMUM_CHOICES = 500
+
+dist_gen = DistributionGenerator()
+generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB)
+
+
+class MimesisGeneratorBase(Generator):
+    """Base class for a generator using Mimesis."""
+
+    def __init__(
+        self,
+        function_name: str,
+    ):
+        """
+        Initialise a generator that uses Mimesis.
+
+        :param function_name: is relative to 'generic', for example 'person.name'.
+        """
+        super().__init__()
+        f = generic
+        for part in function_name.split("."):
+            if not hasattr(f, part):
+                raise Exception(
+                    f"Mimesis does not have a function {function_name}: {part} not found"
+                )
+            f = getattr(f, part)
+        if not callable(f):
+            raise Exception(
+                f"Mimesis object {function_name} is not a callable,"
+                " so cannot be used as a generator"
+            )
+        self._name = "generic." + function_name
+        self._generator_function = f
+
+    def function_name(self) -> str:
+        """Get the name of the generator function to call."""
+        return self._name
+
+    def generate_data(self, count: int) -> list[Any]:
+        """Generate ``count`` random data points for this column."""
+        return [self._generator_function() for _ in range(count)]
+
+
+class MimesisGenerator(MimesisGeneratorBase):
+    """A generator using Mimesis."""
+
+    def __init__(
+        self,
+        function_name: str,
+        value_fn: Callable[[Any], float] | None = None,
+        buckets: Buckets | None = None,
+    ):
+        """
+        Initialise a generator using Mimesis.
+
+        :param function_name: is relative to 'generic', for example 'person.name'.
+        :param value_fn: Function to convert generator output to floats, if needed. The values
+            thus produced are compared against the buckets to estimate the fit.
+        :param buckets: The distribution of string lengths in the real data. If this is None
+            then the ``fit`` method will return its default.
+ """ + super().__init__(function_name) + if buckets is None: + self._fit = None + return + samples = self.generate_data(400) + if value_fn: + samples = [value_fn(s) for s in samples] + self._fit = buckets.fit_from_values(samples) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return self._name + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return {} + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return {} + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default if self._fit is None else self._fit + + +class MimesisGeneratorTruncated(MimesisGenerator): + """A string generator using Mimesis that must fit within a certain number of characters.""" + + def __init__( + self, + function_name: str, + length: int, + value_fn: Callable[[Any], float] | None = None, + buckets: Buckets | None = None, + ): + """Initialise a MimesisGeneratorTruncated.""" + self._length = length + super().__init__(function_name, value_fn, buckets) + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.truncated_string" + + def name(self) -> str: + """Get the name of the generator.""" + return f"{self._name} [truncated to {self._length}]" + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "subgen_fn": self._name, + "params": {}, + "length": self._length, + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "subgen_fn": self._name, + "params": {}, + "length": self._length, + } + + def generate_data(self, count: int) -> list[Any]: + """Generate ``count`` random data points for this column.""" + return [self._generator_function()[: self._length] for _ in range(count)] + + +class MimesisDateTimeGenerator(MimesisGeneratorBase): + """DateTime generator using Mimesis.""" + + def __init__( + self, + column: Column, + function_name: str, + min_year: str, + max_year: str, + start: int, + end: int, + ) -> None: + """ + Initialise a MimesisDateTimeGenerator. 
+
+        :param column: The column to generate into
+        :param function_name: The name of the mimesis function
+        :param min_year: SQL expression extracting the minimum year
+        :param max_year: SQL expression extracting the maximum year
+        :param start: The actual first year found
+        :param end: The actual last year found
+        """
+        super().__init__(function_name)
+        self._column = column
+        self._max_year = max_year
+        self._min_year = min_year
+        self._start = start
+        self._end = end
+
+    @classmethod
+    def make_singleton(
+        cls, column: Column, engine: Engine, function_name: str
+    ) -> Sequence[Generator]:
+        """Make the appropriate generation configuration for this column."""
+        extract_year = f"CAST(EXTRACT(YEAR FROM {column.name}) AS INT)"
+        max_year = f"MAX({extract_year})"
+        min_year = f"MIN({extract_year})"
+        with engine.connect() as connection:
+            result = connection.execute(
+                text(
+                    f"SELECT {min_year} AS start, {max_year} AS end FROM {column.table.name}"
+                )
+            ).first()
+        if result is None or result.start is None or result.end is None:
+            return []
+        return [
+            MimesisDateTimeGenerator(
+                column,
+                function_name,
+                min_year,
+                max_year,
+                int(result.start),
+                int(result.end),
+            )
+        ]
+
+    def nominal_kwargs(self) -> dict[str, Any]:
+        """Get the arguments to be entered into ``config.yaml``."""
+        return {
+            "start": (
+                f'SRC_STATS["auto__{self._column.table.name}"]["results"]'
+                f'[0]["{self._column.name}__start"]'
+            ),
+            "end": (
+                f'SRC_STATS["auto__{self._column.table.name}"]["results"]'
+                f'[0]["{self._column.name}__end"]'
+            ),
+        }
+
+    def actual_kwargs(self) -> dict[str, Any]:
+        """Get the kwargs (summary statistics) this generator was instantiated with."""
+        return {
+            "start": self._start,
+            "end": self._end,
+        }
+
+    def select_aggregate_clauses(self) -> dict[str, dict[str, str]]:
+        """Get the query fragments the generators need to call."""
+        return {
+            f"{self._column.name}__start": {
+                "clause": self._min_year,
+                "comment": (
+                    f"Earliest year found for column {self._column.name}"
+                    f" in table {self._column.table.name}"
+                ),
+            },
+            f"{self._column.name}__end": {
+                "clause": self._max_year,
+                "comment": (
+                    f"Latest year found for column {self._column.name}"
+                    f" in table {self._column.table.name}"
+                ),
+            },
+        }
+
+    def generate_data(self, count: int) -> list[Any]:
+        """Generate ``count`` random data points for this column."""
+        return [
+            self._generator_function(start=self._start, end=self._end)
+            for _ in range(count)
+        ]
+
+
+class MimesisStringGeneratorFactory(GeneratorFactory):
+    """All Mimesis generators that return strings."""
+
+    GENERATOR_NAMES = [
+        "address.calling_code",
+        "address.city",
+        "address.continent",
+        "address.country",
+        "address.country_code",
+        "address.postal_code",
+        "address.province",
+        "address.street_number",
+        "address.street_name",
+        "address.street_suffix",
+        "person.blood_type",
+        "person.email",
+        "person.first_name",
+        "person.last_name",
+        "person.full_name",
+        "person.gender",
+        "person.language",
+        "person.nationality",
+        "person.occupation",
+        "person.password",
+        "person.title",
+        "person.university",
+        "person.username",
+        "person.worldview",
+        "text.answer",
+        "text.color",
+        "text.level",
+        "text.quote",
+        "text.sentence",
+        "text.text",
+        "text.word",
+    ]
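+
+    # Illustrative sketch, not part of the API: each name above is resolved
+    # attribute by attribute against the module-level ``generic`` Mimesis
+    # instance, so for example
+    #
+    #     MimesisGenerator("person.full_name").generate_data(2)
+    #
+    # returns two random EN_GB names. A name that does not resolve to a
+    # callable raises at construction time.
+
+    def get_generators(
+        self, columns: list[Column], engine: Engine
+    ) -> Sequence[Generator]:
+        """Get the generators appropriate to these columns."""
+        if len(columns) != 1:
+            return []
+        column = columns[0]
+        column_type = get_column_type(column)
+        if not isinstance(column_type,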
String): + return [] + try: + buckets = Buckets.make_buckets( + engine, + column.table.name, + f"LENGTH({column.name})", + ) + fitness_fn = len + except Exception: + # Some column types that appear to be strings (such as enums) + # cannot have their lengths measured. In this case we cannot + # detect fitness using lengths. + buckets = None + fitness_fn = None + length = column_type.length + if length: + return list( + map( + lambda gen: MimesisGeneratorTruncated( + gen, length, fitness_fn, buckets + ), + self.GENERATOR_NAMES, + ) + ) + return list( + map( + lambda gen: MimesisGenerator(gen, fitness_fn, buckets), + self.GENERATOR_NAMES, + ) + ) + + +class MimesisFloatGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return floating point numbers.""" + + def get_generators( + self, columns: list[Column], _engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + if not isinstance(get_column_type(column), Numeric): + return [] + return list( + map( + MimesisGenerator, + [ + "person.height", + ], + ) + ) + + +class MimesisDateGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return dates.""" + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, Date): + return [] + return MimesisDateTimeGenerator.make_singleton(column, engine, "datetime.date") + + +class MimesisDateTimeGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return datetimes.""" + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, DateTime): + return [] + return MimesisDateTimeGenerator.make_singleton( + column, engine, "datetime.datetime" + ) + + +class MimesisTimeGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return times.""" + + def get_generators( + self, columns: list[Column], _engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, Time): + return [] + return [MimesisGenerator("datetime.time")] + + +class MimesisIntegerGeneratorFactory(GeneratorFactory): + """All Mimesis generators that return integers.""" + + def get_generators( + self, columns: list[Column], _engine: Engine + ) -> Sequence[Generator]: + """Get the generators appropriate to these columns.""" + if len(columns) != 1: + return [] + column = columns[0] + ct = get_column_type(column) + if not isinstance(ct, Numeric) and not isinstance(ct, Integer): + return [] + return [MimesisGenerator("person.weight")] diff --git a/datafaker/generators/partitioned.py b/datafaker/generators/partitioned.py new file mode 100644 index 0000000..395f261 --- /dev/null +++ b/datafaker/generators/partitioned.py @@ -0,0 +1,514 @@ +"""Powerful generators for numbers, choices and related missingness.""" + +from dataclasses import dataclass +from itertools import chain, combinations +from typing import Any, Iterable, Sequence, Union + +import sqlalchemy +from datafaker.generators.base import ( + Generator, + dist_gen, + get_column_type, +) 
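+
+# The factories below build on MultivariateNormalGeneratorFactory (imported
+# next), partitioning rows by their pattern of NULLs and grouping by
+# non-numeric values before computing per-partition covariance statistics.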
+from datafaker.generators.continuous import (
+    MultivariateNormalGeneratorFactory,
+)
+from sqlalchemy import Column, Connection, Engine, RowMapping, text
+from sqlalchemy.types import Integer, Numeric
+
+from datafaker.utils import T, logger
+
+NumericType = Union[int, float]
+
+# How many distinct values can we have before we consider a
+# choice distribution to be infeasible?
+MAXIMUM_CHOICES = 500
+
+
+def text_list(items: Iterable[str]) -> str:
+    """Concatenate the items with commas and one "and"."""
+    item_i = iter(items)
+    try:
+        so_far = next(item_i)
+    except StopIteration:
+        return ""
+    try:
+        last_item = next(item_i)
+    except StopIteration:
+        return so_far
+    for item in item_i:
+        so_far += ", " + last_item
+        last_item = item
+    return so_far + " and " + last_item
+
+
+@dataclass
+class RowPartition:
+    """A partition where all the rows have the same pattern of NULLs."""
+
+    query: str
+    # list of numeric columns
+    included_numeric: list[Column]
+    # map of indices to column names that are being grouped by.
+    # The indices are indices of where they need to be inserted into
+    # the generator outputs.
+    included_choice: dict[int, str]
+    # map of column names to clause that defines the partition
+    # such as "mycolumn IS NULL"
+    excluded_columns: dict[str, str]
+    # map of constant outputs that need to be inserted into the
+    # list of included column values (so once the generator has
+    # been run and the included_choice values have been
+    # added): {index: value}
+    constant_outputs: dict[int, Any]
+    # The actual covariates from the source database
+    covariates: Sequence[RowMapping]
+
+    def comment(self) -> str:
+        """Make an appropriate comment for this partition."""
+        caveat = ""
+        if self.included_choice:
+            caveat = f" (for each possible value of {text_list(self.included_choice.values())})"
+        if not self.included_numeric:
+            return f"Number of rows for which {text_list(self.excluded_columns.values())}{caveat}"
+        if not self.excluded_columns:
+            where = ""
+        else:
+            where = f" where {text_list(self.excluded_columns.values())}"
+        if len(self.included_numeric) == 1:
+            return (
+                f"Mean and variance for column {self.included_numeric[0].name}{where}."
+            )
+        return (
+            "Means and covariate matrix for the columns "
+            f"{text_list(col.name for col in self.included_numeric)}{where}{caveat} so that we can"
+            " produce the relatedness between these in the fake data."
+        )
+
+
+class NullPartitionedNormalGenerator(Generator):
+    """
+    A generator of mixed numeric and non-numeric data.
+
+    Generates data that matches the source data in
+    missingness, choice of non-numeric data and numeric
+    data.
+
+    For the numeric data to be generated, rows are sampled for each
+    combination of non-numeric values and missingness pattern. If any
+    such combination has only one row in the source data (or in the
+    sample of the source data, if sampling), it will not be generated,
+    as a covariate matrix cannot be computed from a single source row
+    (although if the data is all non-numeric values and nulls, single
+    rows are used, because no covariate matrix is required for this).
+ """ + + def __init__( + self, + query_name: str, + partitions: dict[int, RowPartition], + function_name: str = "grouped_multivariate_lognormal", + name_suffix: str | None = None, + partition_count_query: str | None = None, + partition_counts: Iterable[RowMapping] = [], + partition_count_comment: str | None = None, + ): + """Initialise a NullPartitionedNormalGenerator.""" + self._query_name = query_name + self._partitions = partitions + self._function_name = function_name + self._partition_count_query = partition_count_query + self._partition_counts = [dict(pc) for pc in partition_counts] + self._partition_count_comment = partition_count_comment + if name_suffix: + self._name = f"null-partitioned {function_name} [{name_suffix}]" + else: + self._name = f"null-partitioned {function_name}" + + def name(self) -> str: + """Get the name of the generator.""" + return self._name + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "dist_gen.alternatives" + + def _nominal_kwargs_with_combinations( + self, index: int, partition: RowPartition + ) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml`` for a single partition.""" + count = ( + 'sum(r["count"] for r in' + f' SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"])' + ) + if not partition.included_numeric and not partition.included_choice: + return { + "count": count, + "name": '"constant"', + "params": {"value": [None] * len(partition.constant_outputs)}, + } + covariates = { + "covs": f'SRC_STATS["auto__cov__{self._query_name}__alt_{index}"]["results"]' + } + if not partition.constant_outputs: + return { + "count": count, + "name": f'"{self._function_name}"', + "params": covariates, + } + return { + "count": count, + "name": '"with_constants_at"', + "params": { + "constants_at": partition.constant_outputs, + "subgen": f'"{self._function_name}"', + "params": covariates, + }, + } + + def _count_query_name(self) -> str: + return f"auto__cov__{self._query_name}__counts" + + def nominal_kwargs(self) -> dict[str, Any]: + """Get the arguments to be entered into ``config.yaml``.""" + return { + "alternative_configs": [ + self._nominal_kwargs_with_combinations(index, self._partitions[index]) + for index in range(len(self._partitions)) + ], + "counts": f'SRC_STATS["{self._count_query_name()}"]["results"]', + } + + def custom_queries(self) -> dict[str, Any]: + """Get the queries the generators need to call.""" + partitions = { + f"auto__cov__{self._query_name}__alt_{index}": { + "comment": partition.comment(), + "query": partition.query, + } + for index, partition in self._partitions.items() + } + if not self._partition_count_query: + return partitions + return { + self._count_query_name(): { + "comment": self._partition_count_comment, + "query": self._partition_count_query, + }, + **partitions, + } + + def _actual_kwargs_with_combinations( + self, partition: RowPartition + ) -> dict[str, Any]: + count = sum(row["count"] for row in partition.covariates) + if not partition.included_numeric and not partition.included_choice: + return { + "count": count, + "name": "constant", + "params": {"value": [None] * len(partition.excluded_columns)}, + } + covariates = { + "covs": partition.covariates, + } + if not partition.constant_outputs: + return { + "count": count, + "name": self._function_name, + "params": covariates, + } + return { + "count": count, + "name": "with_constants_at", + "params": { + "constants_at": partition.constant_outputs, + "subgen": self._function_name, + 
"params": covariates, + }, + } + + def actual_kwargs(self) -> dict[str, Any]: + """Get the kwargs (summary statistics) this generator was instantiated with.""" + return { + "alternative_configs": [ + self._actual_kwargs_with_combinations(self._partitions[index]) + for index in range(len(self._partitions)) + ], + "counts": self._partition_counts, + } + + def generate_data(self, count: int) -> list[Any]: + """Generate 'count' random data points for this column.""" + kwargs = self.actual_kwargs() + return [dist_gen.alternatives(**kwargs) for _ in range(count)] + + def fit(self, default: float = -1) -> float: + """Get this generator's fit against the real data.""" + return default + + +def is_numeric(col: Column) -> bool: + """Test if this column stores a numeric value.""" + ct = get_column_type(col) + return isinstance(ct, (Numeric, Integer)) and not col.foreign_keys + + +def powerset(xs: list[T]) -> Iterable[Iterable[T]]: + """Get a list of all sublists of ``input``.""" + return chain.from_iterable(combinations(xs, n) for n in range(len(xs) + 1)) + + +@dataclass +class NullableColumn: + """A reference to a nullable column whose nullability is part of a partitioning.""" + + column: Column + # The bit (power of two) of the number of the partition in the partition sizes list + bitmask: int + + +class NullPatternPartition: + """Get the definition of a partition (in other words, what makes it not another partition).""" + + def __init__( + self, columns: Iterable[Column], partition_nonnulls: Iterable[NullableColumn] + ): + """Initialise a pattern of nulls which can be queried for.""" + self.index = sum(nc.bitmask for nc in partition_nonnulls) + nonnull_columns = {nc.column.name for nc in partition_nonnulls} + self.included_numeric: list[Column] = [] + self.included_choice: dict[int, str] = {} + self.group_by_clause = "" + self.constant_clauses = "" + self.constants = "" + self.excluded: dict[str, str] = {} + self.predicates: list[str] = [] + self.nones: dict[int, None] = {} + for col_index, column in enumerate(columns): + col_name = column.name + if col_name in nonnull_columns or not column.nullable: + if is_numeric(column): + self.included_numeric.append(column) + else: + index = len(self.included_numeric) + len(self.included_choice) + self.included_choice[index] = col_name + if self.group_by_clause: + self.group_by_clause += ", " + col_name + else: + self.group_by_clause = " GROUP BY " + col_name + self.constant_clauses += f", _q.{col_name} AS k{index}" + self.constants += ", " + col_name + else: + self.excluded[col_name] = f"{col_name} IS NULL" + self.predicates.append(f"{col_name} IS NULL") + self.nones[col_index] = None + + +class NullPartitionedNormalGeneratorFactory(MultivariateNormalGeneratorFactory): + """Produces null partitioned generators, for complex interdependent data.""" + + SAMPLE_COUNT = MAXIMUM_CHOICES + SUPPRESS_COUNT = 5 + EMPTY_RESULT = [ + RowMapping( + parent=sqlalchemy.engine.result.SimpleResultMetaData(["count"]), + processors=None, + key_to_index={"count": 0}, + data=(0,), + ) + ] + + def function_name(self) -> str: + """Get the name of the generator function to call.""" + return "grouped_multivariate_normal" + + def query_predicate(self, column: Column) -> str: + """Get a SQL expression that is true when ``column`` is available for analysis.""" + if is_numeric(column): + # x <> x + 1 ensures that x is not infinity or NaN + return f"COALESCE({column.name} <> {column.name} + 1, FALSE)" + return f"{column.name} IS NOT NULL" + + def query_var(self, column: str) -> str: + 
"""Return the expression we are querying for in this column.""" + return column + + def get_nullable_columns(self, columns: list[Column]) -> list[NullableColumn]: + """Get a list of nullable columns together with bitmasks.""" + out: list[NullableColumn] = [] + for col in columns: + if col.nullable: + out.append( + NullableColumn( + column=col, + bitmask=2 ** len(out), + ) + ) + return out + + def get_partition_count_query( + self, ncs: list[NullableColumn], table: str, where: str | None = None + ) -> str: + """ + Get a SQL expression returning columns ``count`` and ``index``. + + Each row returned represents one of the null pattern partitions. + ``index`` is the bitmask of all those nullable columns that are not null for + this partition, and ``count`` is the total number of rows in this partition. + """ + index_exp = " + ".join( + f"CASE WHEN {self.query_predicate(nc.column)} THEN {nc.bitmask} ELSE 0 END" + for nc in ncs + ) + if where is None: + return f'SELECT COUNT(*) AS count, {index_exp} AS "index" FROM {table} GROUP BY "index"' + return ( + 'SELECT count, "index" FROM (SELECT COUNT(*) AS count,' + f' {index_exp} AS "index"' + f' FROM {table} GROUP BY "index") AS _q {where}' + ) + + def get_generators( + self, columns: list[Column], engine: Engine + ) -> Sequence[Generator]: + """Get any appropriate generators for these columns.""" + if len(columns) < 2: + return [] + nullable_columns = self.get_nullable_columns(columns) + if not nullable_columns: + return [] + table = columns[0].table.name + query_name = f"{table}__{columns[0].name}" + # Partitions for minimal suppression and no sampling + row_partitions_maximal: dict[int, RowPartition] = {} + # Partitions for normal suppression and severe sampling + row_partitions_ss: dict[int, RowPartition] = {} + for partition_nonnulls in powerset(nullable_columns): + partition_def = NullPatternPartition(columns, partition_nonnulls) + query = self.query( + table=table, + columns=partition_def.included_numeric, + predicates=partition_def.predicates, + group_by_clause=partition_def.group_by_clause, + constants=partition_def.constants, + constant_clauses=partition_def.constant_clauses, + ) + row_partitions_maximal[partition_def.index] = RowPartition( + query, + partition_def.included_numeric, + partition_def.included_choice, + partition_def.excluded, + partition_def.nones, + [], + ) + query = self.query( + table=table, + columns=partition_def.included_numeric, + predicates=partition_def.predicates, + group_by_clause=partition_def.group_by_clause, + constants=partition_def.constants, + constant_clauses=partition_def.constant_clauses, + suppress_count=self.SUPPRESS_COUNT, + sample_count=self.SAMPLE_COUNT, + ) + row_partitions_ss[partition_def.index] = RowPartition( + query, + partition_def.included_numeric, + partition_def.included_choice, + partition_def.excluded, + partition_def.nones, + [], + ) + gens: list[Generator] = [] + try: + with engine.connect() as connection: + partition_query_max = self.get_partition_count_query( + nullable_columns, table + ) + partition_count_max_results = ( + connection.execute(text(partition_query_max)).mappings().fetchall() + ) + count_comment = ( + "Number of rows for each combination of the columns" + f" { {nc.column.name for nc in nullable_columns} }" + f" of the table {table} being null" + ) + if self._execute_partition_queries(connection, row_partitions_maximal): + gens.append( + NullPartitionedNormalGenerator( + query_name, + row_partitions_maximal, + self.function_name(), + 
partition_count_query=partition_query_max,
+                        partition_counts=partition_count_max_results,
+                        partition_count_comment=count_comment,
+                    )
+                )
+            partition_query_ss = self.get_partition_count_query(
+                nullable_columns,
+                table,
+                where=f"WHERE {self.SUPPRESS_COUNT} < count",
+            )
+            partition_count_ss_results = (
+                connection.execute(text(partition_query_ss)).mappings().fetchall()
+            )
+            if self._execute_partition_queries(connection, row_partitions_ss):
+                gens.append(
+                    NullPartitionedNormalGenerator(
+                        query_name,
+                        row_partitions_ss,
+                        self.function_name(),
+                        name_suffix="sampled and suppressed",
+                        partition_count_query=partition_query_ss,
+                        partition_counts=partition_count_ss_results,
+                        partition_count_comment=count_comment,
+                    )
+                )
+        except sqlalchemy.exc.DatabaseError as exc:
+            logger.debug("SQL query failed with error %s [%s]", exc, exc.statement)
+            return []
+        return gens
+
+    def _execute_partition_queries(
+        self,
+        connection: Connection,
+        partitions: dict[int, RowPartition],
+    ) -> bool:
+        """
+        Execute the query in each partition, filling in the covariates.
+
+        :return: True if any partition yielded data, False if all of them were empty.
+        """
+        found_nonzero = False
+        for rp in partitions.values():
+            covs = connection.execute(text(rp.query)).mappings().fetchall()
+            if not covs or covs[0]["count"] is None or covs[0]["count"] == 0:
+                rp.covariates = self.EMPTY_RESULT
+            else:
+                rp.covariates = covs
+                found_nonzero = True
+        return found_nonzero
+
+
+class NullPartitionedLogNormalGeneratorFactory(NullPartitionedNormalGeneratorFactory):
+    """
+    A factory for generators of numeric and non-numeric columns.
+
+    Any value could be null; the distributions of the non-null numeric columns
+    depend on each other and on the other non-numeric column values.
+    """
+
+    def function_name(self) -> str:
+        """Get the name of the generator function to call."""
+        return "grouped_multivariate_lognormal"
+
+    def query_predicate(self, column: Column) -> str:
+        """Get the SQL expression testing if the value in this column should be used."""
+        if is_numeric(column):
+            # x <> x + 1 ensures that x is not infinity or NaN
+            return f"COALESCE({column.name} <> {column.name} + 1 AND 0 < {column.name}, FALSE)"
+        return f"{column.name} IS NOT NULL"
+
+    def query_var(self, column: str) -> str:
+        """Get the variable or expression we are querying for this column."""
+        return f"LN({column})"
diff --git a/datafaker/interactive.py b/datafaker/interactive.py
deleted file mode 100644
index 17efa19..0000000
--- a/datafaker/interactive.py
+++ /dev/null
@@ -1,2175 +0,0 @@
-"""Interactive configuration commands."""
-import cmd
-import csv
-import functools
-import re
-from abc import ABC, abstractmethod
-from collections.abc import Mapping, MutableMapping, Sequence
-from dataclasses import dataclass
-from enum import Enum
-from pathlib import Path
-from types import TracebackType
-from typing import Any, Callable, Iterable, Optional, Type, cast
-
-import sqlalchemy
-from prettytable import PrettyTable
-from sqlalchemy import Column, Engine, ForeignKey, MetaData, Table
-from typing_extensions import Self
-
-from datafaker.generators import Generator, PredefinedGenerator, everything_factory
-from datafaker.utils import (
-    T,
-    create_db_engine,
-    fk_refers_to_ignored_table,
-    get_sync_engine,
-    logger,
-    primary_private_fks,
-    table_is_private,
-)
-
-# Monkey patch pyreadline3 v3.5 so that it works with Python 3.13
-# Windows users can install pyreadline3 to get tab completion working.
-# See https://github.com/pyreadline3/pyreadline3/issues/37 -try: - import readline - - if not hasattr(readline, "backend"): - setattr(readline, "backend", "readline") -except: - pass - - -def or_default(v: T | None, d: T) -> T: - """Return v if it isn't None, otherwise d.""" - return d if v is None else v - - -class TableType(Enum): - """Types of table to be configured.""" - - GENERATE = "generate" - IGNORE = "ignore" - VOCABULARY = "vocabulary" - PRIVATE = "private" - EMPTY = "empty" - - -TYPE_LETTER = { - TableType.GENERATE: "G", - TableType.IGNORE: "I", - TableType.VOCABULARY: "V", - TableType.PRIVATE: "P", - TableType.EMPTY: "e", -} - -TYPE_PROMPT = { - TableType.GENERATE: "(table: {}) ", - TableType.IGNORE: "(table: {} (ignore)) ", - TableType.VOCABULARY: "(table: {} (vocab)) ", - TableType.PRIVATE: "(table: {} (private)) ", - TableType.EMPTY: "(table: {} (empty))", -} - - -@dataclass -class TableEntry: - """Base class for table entries for interactive commands.""" - - name: str # name of the table - - -class AskSaveCmd(cmd.Cmd): - """Interactive shell for whether to save and quit.""" - - intro = "Do you want to save this configuration?" - prompt = "(yes/no/cancel) " - file = None - - def __init__(self) -> None: - """Initialise a save command.""" - super().__init__() - self.result = "" - - def do_yes(self, _arg: str) -> bool: - """Save the new config.yaml.""" - self.result = "yes" - return True - - def do_no(self, _arg: str) -> bool: - """Exit without saving.""" - self.result = "no" - return True - - def do_cancel(self, _arg: str) -> bool: - """Do not exit.""" - self.result = "cancel" - return True - - -def fk_column_name(fk: ForeignKey) -> str: - """Display name for a foreign key.""" - if fk_refers_to_ignored_table(fk): - return f"{fk.target_fullname} (ignored)" - return str(fk.target_fullname) - - -class DbCmd(ABC, cmd.Cmd): - """Base class for interactive configuration commands.""" - - INFO_NO_MORE_TABLES = "There are no more tables" - ERROR_ALREADY_AT_START = "Error: Already at the start" - ERROR_NO_SUCH_TABLE = "Error: '{0}' is not the name of a table in this database" - ERROR_NO_SUCH_TABLE_OR_COLUMN = "Error: '{0}' is not the name of a table in this database or a column in this table" - ROW_COUNT_MSG = "Total row count: {}" - - @abstractmethod - def make_table_entry( - self, table_name: str, table_config: Mapping - ) -> TableEntry | None: - """ - Make a table entry suitable for this interactive command. - - :param name: The name of the table to make an entry for. - :param table_config: The part of the ``config.yaml`` referring to this table. - :return: The table entry or None if this table should not be interacted with. 
- """ - - def __init__( - self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], - ): - """Initialise a DbCmd.""" - super().__init__() - self.config: MutableMapping[str, Any] = config - self.metadata = metadata - self._table_entries: list[TableEntry] = [] - tables_config: MutableMapping = config.get("tables", {}) - if not isinstance(tables_config, MutableMapping): - tables_config = {} - for name in metadata.tables.keys(): - table_config = tables_config.get(name, {}) - if not isinstance(table_config, MutableMapping): - table_config = {} - entry = self.make_table_entry(name, table_config) - if entry is not None: - self._table_entries.append(entry) - self.table_index = 0 - self.engine = create_db_engine(src_dsn, schema_name=src_schema) - - @property - def sync_engine(self) -> Engine: - """Get the synchronous version of the engine.""" - return get_sync_engine(self.engine) - - def __enter__(self) -> Self: - """Enter a ``with`` statement.""" - return self - - def __exit__( - self, - _exc_type: Optional[Type[BaseException]], - _exc_val: Optional[BaseException], - _exc_tb: Optional[TracebackType], - ) -> None: - """Dispose of this object.""" - self.engine.dispose() - - def print(self, text: str, *args: Any, **kwargs: Any) -> None: - """Print text, formatted with positional and keyword arguments.""" - print(text.format(*args, **kwargs)) - - def print_table(self, headings: list[str], rows: list[list[Any]]) -> None: - """ - Print a table. - - :param headings: List of headings for the table. - :param rows: List of rows of values. - """ - output = PrettyTable() - output.field_names = headings - for row in rows: - output.add_row(row) - print(output) - - def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: - """ - Print a table. - - :param columns: Dict of column names to the values in the column. - """ - output = PrettyTable() - row_count = max([len(col) for col in columns.values()]) - for field_name, data in columns.items(): - output.add_column(field_name, data + [None] * (row_count - len(data))) - print(output) - - def print_results(self, result: sqlalchemy.CursorResult) -> None: - """Print the rows resulting from a database query.""" - self.print_table(list(result.keys()), [list(row) for row in result.all()]) - - def ask_save(self) -> str: - """ - Ask the user if they want to save. - - :return: ``yes``, ``no`` or ``cancel``. - """ - ask = AskSaveCmd() - ask.cmdloop() - return ask.result - - @abstractmethod - def set_prompt(self) -> None: - """Set the prompt according to the current state.""" - ... - - def _set_table_index(self, index: int) -> bool: - """ - Move to a different table. - - :param index: Index of the table to move to. - :return: True if there is a table with such an index to move to. - """ - if 0 <= index and index < len(self._table_entries): - self.table_index = index - self.set_prompt() - return True - return False - - def next_table(self, report: str = "No more tables") -> bool: - """ - Move to the next table. - - :param report: The text to print if there is no next table. - :return: True if there is another table to move to. 
- """ - if not self._set_table_index(self.table_index + 1): - self.print(report) - return False - return True - - def table_name(self) -> str: - """Get the name of the current table.""" - return str(self._table_entries[self.table_index].name) - - def table_metadata(self) -> Table: - """Get the metadata of the current table.""" - return self.metadata.tables[self.table_name()] - - def _get_column_names(self) -> list[str]: - """Get the names of the current columns.""" - return [col.name for col in self.table_metadata().columns] - - def report_columns(self) -> None: - """Print information about the current columns.""" - self.print_table( - ["name", "type", "primary", "nullable", "foreign key"], - [ - [ - name, - str(col.type), - col.primary_key, - col.nullable, - ", ".join([fk_column_name(fk) for fk in col.foreign_keys]), - ] - for name, col in self.table_metadata().columns.items() - ], - ) - - def get_table_config(self, table_name: str) -> MutableMapping[str, Any]: - """Get the configuration of the named table.""" - ts = self.config.get("tables", None) - if not isinstance(ts, MutableMapping): - return {} - t = ts.get(table_name) - return t if isinstance(t, MutableMapping) else {} - - def set_table_config( - self, table_name: str, config: MutableMapping[str, Any] - ) -> None: - """Set the configuration of the named table.""" - ts = self.config.get("tables", None) - if not isinstance(ts, MutableMapping): - self.config["tables"] = {table_name: config} - return - ts[table_name] = config - - def _remove_prefix_src_stats(self, prefix: str) -> list[MutableMapping[str, Any]]: - """Remove all source stats with the given prefix from the configuration.""" - src_stats = self.config.get("src-stats", []) - new_src_stats = [] - for stat in src_stats: - if not stat.get("name", "").startswith(prefix): - new_src_stats.append(stat) - self.config["src-stats"] = new_src_stats - return new_src_stats - - def get_nonnull_columns(self, table_name: str) -> list[str]: - """Get the names of the nullable columns in the named table.""" - metadata_table = self.metadata.tables[table_name] - return [ - str(name) - for name, column in metadata_table.columns.items() - if column.nullable - ] - - def find_entry_index_by_table_name(self, table_name: str) -> int | None: - """Get the index of the table entry of the named table.""" - return next( - ( - i - for i, entry in enumerate(self._table_entries) - if entry.name == table_name - ), - None, - ) - - def _find_entry_by_table_name(self, table_name: str) -> TableEntry | None: - """Get the table entry of the named table.""" - for e in self._table_entries: - if e.name == table_name: - return e - return None - - def do_counts(self, _arg: str) -> None: - """Report the column names with the counts of nulls in them.""" - if len(self._table_entries) <= self.table_index: - return - table_name = self.table_name() - nonnull_columns = self.get_nonnull_columns(table_name) - colcounts = [", COUNT({0}) AS {0}".format(nnc) for nnc in nonnull_columns] - with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}" - ) - ).first() - if result is None: - self.print("Could not count rows in table {0}", table_name) - return - row_count = result.row_count - self.print(self.ROW_COUNT_MSG, row_count) - self.print_table( - ["Column", "NULL count"], - [ - [name, row_count - count] - for name, count in result._mapping.items() - if name != "row_count" - ], - ) - - def do_select(self, arg: str) -> None: - 
"""Run a select query over the database and show the first 50 results.""" - max_select_rows = 50 - with self.sync_engine.connect() as connection: - try: - result = connection.execute(sqlalchemy.text("SELECT " + arg)) - except sqlalchemy.exc.DatabaseError as exc: - self.print("Failed to execute: {}", exc) - return - row_count = result.rowcount - self.print(self.ROW_COUNT_MSG, row_count) - if 50 < row_count: - self.print("Showing the first {} rows", max_select_rows) - fields = list(result.keys()) - rows = [row._tuple() for row in result.fetchmany(max_select_rows)] - self.print_table(fields, rows) - - def do_peek(self, arg: str) -> None: - """ - View some data from the current table. - - Use 'peek col1 col2 col3' to see a sample of values from - columns col1, col2 and col3 in the current table. - Use 'peek' to see a sample of the current column(s). - Rows that are enitrely null are suppressed. - """ - max_peek_rows = 25 - if len(self._table_entries) <= self.table_index: - return - table_name = self.table_name() - col_names = arg.split() - if not col_names: - col_names = self._get_column_names() - nonnulls = [cn + " IS NOT NULL" for cn in col_names] - with self.sync_engine.connect() as connection: - cols = (",".join(col_names),) - where = ("WHERE" if nonnulls else "",) - nonnull = (" OR ".join(nonnulls),) - query = sqlalchemy.text( - f"SELECT {cols} FROM {table_name} {where} {nonnull}" - f" ORDER BY RANDOM() LIMIT {max_peek_rows}" - ) - try: - result = connection.execute(query) - except Exception as exc: - self.print(f'SQL query "{query}" caused exception {exc}') - return - rows = [row._tuple() for row in result.fetchmany(max_peek_rows)] - self.print_table(list(result.keys()), rows) - - def complete_peek( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Completions for the ``peek`` command.""" - if len(self._table_entries) <= self.table_index: - return [] - return [ - col for col in self.table_metadata().columns.keys() if col.startswith(text) - ] - - -@dataclass -class TableCmdTableEntry(TableEntry): - """Table entry for the table command shell.""" - - old_type: TableType - new_type: TableType - - -class TableCmd(DbCmd): - """Command shell allowing the user to set the type of each table.""" - - intro = ( - "Interactive table configuration (ignore," - " vocabulary, private, generate or empty). Type ? for help.\n" - ) - doc_leader = """Use the commands 'ignore', 'vocabulary', -'private', 'empty' or 'generate' to set the table's type. Use 'next' or -'previous' to change table. Use 'tables' and 'columns' for -information about the database. Use 'data', 'peek', 'select' or -'count' to see some data contained in the current table. Use 'quit' -to exit this program.""" - prompt = "(tableconf) " - file = None - WARNING_TEXT_VOCAB_TO_NON_VOCAB = ( - "Vocabulary table {0} references non-vocabulary table {1}" - ) - WARNING_TEXT_NON_EMPTY_TO_EMPTY = ( - "Empty table {1} referenced from non-empty table {0}. {1} will need stories." - ) - WARNING_TEXT_PROBLEMS_EXIST = "WARNING: The following table types have problems:" - WARNING_TEXT_POTENTIAL_PROBLEMS = ( - "NOTE: The following table types might cause problems later:" - ) - NOTE_TEXT_NO_CHANGES = "You have made no changes." - NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}" - - def make_table_entry( - self, table_name: str, table: Mapping - ) -> TableCmdTableEntry | None: - """ - Make a table entry for the named table. - - :param name: The name of the table. 
- :param table: The part of ``config.yaml`` corresponding to this table. - :return: The newly-constructed table entry. - """ - if table.get("ignore", False): - return TableCmdTableEntry(table_name, TableType.IGNORE, TableType.IGNORE) - if table.get("vocabulary_table", False): - return TableCmdTableEntry( - table_name, TableType.VOCABULARY, TableType.VOCABULARY - ) - if table.get("primary_private", False): - return TableCmdTableEntry(table_name, TableType.PRIVATE, TableType.PRIVATE) - if table.get("num_rows_per_pass", 1) == 0: - return TableCmdTableEntry(table_name, TableType.EMPTY, TableType.EMPTY) - return TableCmdTableEntry(table_name, TableType.GENERATE, TableType.GENERATE) - - def __init__( - self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], - ) -> None: - """Initialise a TableCmd.""" - super().__init__(src_dsn, src_schema, metadata, config) - self.set_prompt() - - @property - def table_entries(self) -> list[TableCmdTableEntry]: - """Get the list of table entries.""" - return cast(list[TableCmdTableEntry], self._table_entries) - - def _find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None: - """Get the table entry of the table with the given name.""" - entry = super()._find_entry_by_table_name(table_name) - if entry is None: - return None - return cast(TableCmdTableEntry, entry) - - def set_prompt(self) -> None: - """Set the prompt according to the current table and its type.""" - if self.table_index < len(self.table_entries): - entry = self.table_entries[self.table_index] - self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name) - else: - self.prompt = "(table) " - - def set_type(self, t_type: TableType) -> None: - """Set the type of the current table.""" - if self.table_index < len(self.table_entries): - entry = self.table_entries[self.table_index] - entry.new_type = t_type - - def _copy_entries(self) -> None: - """Alter the configuration to match the new table entries.""" - for entry in self.table_entries: - if entry.old_type != entry.new_type: - table = self.get_table_config(entry.name) - if ( - entry.old_type == TableType.EMPTY - and table.get("num_rows_per_pass", 1) == 0 - ): - table["num_rows_per_pass"] = 1 - if entry.new_type == TableType.IGNORE: - table["ignore"] = True - table.pop("vocabulary_table", None) - table.pop("primary_private", None) - elif entry.new_type == TableType.VOCABULARY: - table.pop("ignore", None) - table["vocabulary_table"] = True - table.pop("primary_private", None) - elif entry.new_type == TableType.PRIVATE: - table.pop("ignore", None) - table.pop("vocabulary_table", None) - table["primary_private"] = True - elif entry.new_type == TableType.EMPTY: - table.pop("ignore", None) - table.pop("vocabulary_table", None) - table.pop("primary_private", None) - table["num_rows_per_pass"] = 0 - else: - table.pop("ignore", None) - table.pop("vocabulary_table", None) - table.pop("primary_private", None) - self.set_table_config(entry.name, table) - - def _get_referenced_tables(self, from_table_name: str) -> set[str]: - """Get all the tables referenced by this table's foreign keys.""" - from_meta = self.metadata.tables[from_table_name] - return { - fk.column.table.name for col in from_meta.columns for fk in col.foreign_keys - } - - def _sanity_check_failures(self) -> list[tuple[str, str, str]]: - """Find tables that reference each other that should not given their types.""" - failures = [] - for from_entry in self.table_entries: - from_t = from_entry.new_type - if from_t == 
TableType.VOCABULARY: - referenced = self._get_referenced_tables(from_entry.name) - for ref in referenced: - to_entry = self._find_entry_by_table_name(ref) - if ( - to_entry is not None - and to_entry.new_type != TableType.VOCABULARY - ): - failures.append( - ( - self.WARNING_TEXT_VOCAB_TO_NON_VOCAB, - from_entry.name, - to_entry.name, - ) - ) - return failures - - def _sanity_check_warnings(self) -> list[tuple[str, str, str]]: - """Find tables that reference each other that might cause problems given their types.""" - warnings = [] - for from_entry in self.table_entries: - from_t = from_entry.new_type - if from_t in {TableType.GENERATE, TableType.PRIVATE}: - referenced = self._get_referenced_tables(from_entry.name) - for ref in referenced: - to_entry = self._find_entry_by_table_name(ref) - if to_entry is not None and to_entry.new_type in { - TableType.EMPTY, - TableType.IGNORE, - }: - warnings.append( - ( - self.WARNING_TEXT_NON_EMPTY_TO_EMPTY, - from_entry.name, - to_entry.name, - ) - ) - return warnings - - def do_quit(self, _arg: str) -> bool: - """Check the updates, save them if desired and quit the configurer.""" - count = 0 - for entry in self.table_entries: - if entry.old_type != entry.new_type: - count += 1 - self.print( - self.NOTE_TEXT_CHANGING, - entry.name, - entry.old_type.value, - entry.new_type.value, - ) - if count == 0: - self.print(self.NOTE_TEXT_NO_CHANGES) - failures = self._sanity_check_failures() - if failures: - self.print(self.WARNING_TEXT_PROBLEMS_EXIST) - for text, from_t, to_t in failures: - self.print(text, from_t, to_t) - warnings = self._sanity_check_warnings() - if warnings: - self.print(self.WARNING_TEXT_POTENTIAL_PROBLEMS) - for text, from_t, to_t in warnings: - self.print(text, from_t, to_t) - reply = self.ask_save() - if reply == "yes": - self._copy_entries() - return True - if reply == "no": - return True - return False - - def do_tables(self, _arg: str) -> None: - """List the tables with their types.""" - for entry in self.table_entries: - old = entry.old_type - new = entry.new_type - becomes = " " if old == new else "->" + TYPE_LETTER[new] - self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name) - - def do_next(self, arg: str) -> None: - """'next' = go to the next table, 'next tablename' = go to table 'tablename'.""" - if arg: - # Find the index of the table called _arg, if any - index = self.find_entry_index_by_table_name(arg) - if index is None: - self.print(self.ERROR_NO_SUCH_TABLE, arg) - return - self._set_table_index(index) - return - self.next_table(self.INFO_NO_MORE_TABLES) - - def complete_next( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Get the completions for tables and columns.""" - return [ - entry.name for entry in self.table_entries if entry.name.startswith(text) - ] - - def do_previous(self, _arg: str) -> None: - """Go to the previous table.""" - if not self._set_table_index(self.table_index - 1): - self.print(self.ERROR_ALREADY_AT_START) - - def do_ignore(self, _arg: str) -> None: - """Set the current table as ignored, and go to the next table.""" - self.set_type(TableType.IGNORE) - self.print("Table {} set as ignored", self.table_name()) - self.next_table() - - def do_vocabulary(self, _arg: str) -> None: - """Set the current table as a vocabulary table, and go to the next table.""" - self.set_type(TableType.VOCABULARY) - self.print("Table {} set to be a vocabulary table", self.table_name()) - self.next_table() - - def do_private(self, _arg: str) -> None: - """Set the current table as a 
primary private table (such as the table of patients).""" - self.set_type(TableType.PRIVATE) - self.print("Table {} set to be a primary private table", self.table_name()) - self.next_table() - - def do_generate(self, _arg: str) -> None: - """Set the current table as to be generated, and go to the next table.""" - self.set_type(TableType.GENERATE) - self.print("Table {} generate", self.table_name()) - self.next_table() - - def do_empty(self, _arg: str) -> None: - """Set the current table as empty; no generators will be run for it.""" - self.set_type(TableType.EMPTY) - self.print("Table {} empty", self.table_name()) - self.next_table() - - def do_columns(self, _arg: str) -> None: - """Report the column names and metadata.""" - self.report_columns() - - def do_data(self, arg: str) -> None: - """ - Report some data. - - 'data' = report a random ten lines, - 'data 20' = report a random 20 lines, - 'data 20 ColumnName' = report a random twenty entries from ColumnName, - 'data 20 ColumnName 30' = report a random twenty entries from ColumnName of length at least 30, - """ - args = arg.split() - column = None - number = None - arg_index = 0 - min_length = 0 - table_metadata = self.table_metadata() - if arg_index < len(args) and args[arg_index].isdigit(): - number = int(args[arg_index]) - arg_index += 1 - if arg_index < len(args) and args[arg_index] in table_metadata.columns: - column = args[arg_index] - arg_index += 1 - if arg_index < len(args) and args[arg_index].isdigit(): - min_length = int(args[arg_index]) - arg_index += 1 - if arg_index != len(args): - self.print( - """Did not understand these arguments -The format is 'data [entries] [column-name [minimum-length]]' where [] means optional text. -Type 'columns' to find out valid column names for this table. -Type 'help data' for examples.""" - ) - return - if column is None: - if number is None: - number = 10 - self.print_row_data(number) - else: - if number is None: - number = 48 - self.print_column_data(column, number, min_length) - - def complete_data( - self, text: str, line: str, begidx: int, _endidx: int - ) -> list[str]: - """Get completions for arguments to ``data``.""" - previous_parts = line[: begidx - 1].split() - if len(previous_parts) != 2: - return [] - table_metadata = self.table_metadata() - return [k for k in table_metadata.columns.keys() if k.startswith(text)] - - def print_column_data(self, column: str, count: int, min_length: int) -> None: - """ - Print a sample of data from a certain column of the current table. - - :param column: The name of the column to report on. - :param count: The number of rows to sample. - :param min_length: The minimum length of text to choose from (0 for any text). - """ - where = f"WHERE {column} IS NOT NULL" - if 0 < min_length: - where = "WHERE LENGTH({column}) >= {len}".format( - column=column, - len=min_length, - ) - with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - "SELECT {column} FROM {table} {where} ORDER BY RANDOM() LIMIT {count}".format( - table=self.table_name(), - column=column, - count=count, - where=where, - ) - ) - ) - self.columnize([str(x[0]) for x in result.all()]) - - def print_row_data(self, count: int) -> None: - """ - Print a sample or rows from the current table. - - :param count: The number of rows to report. 
- """ - with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - "SELECT * FROM {table} ORDER BY RANDOM() LIMIT {count}".format( - table=self.table_name(), - count=count, - ) - ) - ) - if result is None: - self.print("No rows in this table!") - return - self.print_results(result) - - -def update_config_tables( - src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping -) -> Mapping[str, Any]: - """Ask the user to specify what should happen to each table.""" - with TableCmd(src_dsn, src_schema, metadata, config) as tc: - tc.cmdloop() - return tc.config - - -@dataclass -class MissingnessType: - """The functions required for applying missingness.""" - - SAMPLED = "column_presence.sampled" - SAMPLED_QUERY = ( - "SELECT COUNT(*) AS row_count, {result_names} FROM " - "(SELECT {column_is_nulls} FROM {table} ORDER BY RANDOM() LIMIT {count})" - " AS __t GROUP BY {result_names}" - ) - name: str - query: str - comment: str - columns: list[str] - - @classmethod - def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str: - """ - Construct a query to make a sampling of the named rows of the table. - - :param table: The name of the table to sample. - :param count: The number of samples to get. - :param column_names: The columns to fetch. - :return: The SQL query to do the sampling. - """ - result_names = ", ".join(["{0}__is_null".format(c) for c in column_names]) - column_is_nulls = ", ".join( - ["{0} IS NULL AS {0}__is_null".format(c) for c in column_names] - ) - return cls.SAMPLED_QUERY.format( - result_names=result_names, - column_is_nulls=column_is_nulls, - table=table, - count=count, - ) - - -@dataclass -class MissingnessCmdTableEntry(TableEntry): - """Table entry for the missingness command shell.""" - - old_type: MissingnessType - new_type: MissingnessType | None - - -class MissingnessCmd(DbCmd): - """ - Interactive shell for the user to set missingness. - - Can only be used for Missingness Completely At Random. - """ - - intro = "Interactive missingness configuration. Type ? for help.\n" - doc_leader = """Use commands 'sampled' and 'none' to choose the missingness style for -the current table. Use commands 'next' and 'previous' to change the -current table. Use 'tables' to list the tables and 'count' to show -how many NULLs exist in each column. Use 'peek' or 'select' to see -data from the database. Use 'quit' to exit this tool.""" - prompt = "(missingness) " - file = None - PATTERN_RE = re.compile(r'SRC_STATS\["([^"]*)"\]') - - def find_missingness_query( - self, missingness_generator: Mapping - ) -> tuple[str, str] | None: - """Find query and comment from src-stats for the passed missingness generator.""" - kwargs = missingness_generator.get("kwargs", {}) - patterns = kwargs.get("patterns", "") - pattern_match = self.PATTERN_RE.match(patterns) - if pattern_match: - key = pattern_match.group(1) - for src_stat in self.config["src-stats"]: - if src_stat.get("name") == key: - query = src_stat.get("query", None) - if type(query) is not str: - return None - return (query, src_stat.get("comment", "")) - return None - - def make_table_entry( - self, table_name: str, table_config: Mapping - ) -> MissingnessCmdTableEntry | None: - """ - Make a table entry for a particular table. - - :param name: The name of the table to make an entry for. - :param table: The part of ``config.yaml`` relating to this table. - :return: The newly-constructed table entry. 
- """ - if table_config.get("ignore", False): - return None - if table_config.get("vocabulary_table", False): - return None - if table_config.get("num_rows_per_pass", 1) == 0: - return None - mgs = table_config.get("missingness_generators", []) - old = None - nonnull_columns = self.get_nonnull_columns(table_name) - if not nonnull_columns: - return None - if not mgs: - old = MissingnessType( - name="none", - query="", - comment="", - columns=[], - ) - elif len(mgs) == 1: - mg = mgs[0] - mg_name = mg.get("name", None) - if isinstance(mg_name, str): - query_comment = self.find_missingness_query(mg) - if query_comment is not None: - (query, comment) = query_comment - old = MissingnessType( - name=mg_name, - query=query, - comment=comment, - columns=mg.get("columns_assigned", []), - ) - if old is None: - return None - return MissingnessCmdTableEntry( - name=table_name, - old_type=old, - new_type=old, - ) - - def __init__( - self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping, - ): - """ - Initialise a MissingnessCmd. - - :param src_dsn: connection string for the source database. - :param src_schema: schema name for the source database. - :param metadata: SQLAlchemy metadata for the source database. - :param config: Configuration from the ``config.yaml`` file. - """ - super().__init__(src_dsn, src_schema, metadata, config) - self.set_prompt() - - @property - def table_entries(self) -> list[MissingnessCmdTableEntry]: - """Get the table entries list.""" - return cast(list[MissingnessCmdTableEntry], self._table_entries) - - def _find_entry_by_table_name( - self, table_name: str - ) -> MissingnessCmdTableEntry | None: - """Find the table entry given the table name.""" - entry = super()._find_entry_by_table_name(table_name) - if entry is None: - return None - return cast(MissingnessCmdTableEntry, entry) - - def set_prompt(self) -> None: - """Set the prompt according to the current table and missingness.""" - if self.table_index < len(self.table_entries): - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] - nt = entry.new_type - if nt is None: - self.prompt = f"(missingness for {entry.name}) " - else: - self.prompt = f"(missingness for {entry.name}: {nt.name}) " - else: - self.prompt = "(missingness) " - - def set_type(self, t_type: MissingnessType) -> None: - """Set the missingness of the current table.""" - if self.table_index < len(self.table_entries): - entry = self.table_entries[self.table_index] - entry.new_type = t_type - - def _copy_entries(self) -> None: - """Set the new missingness into the configuration.""" - src_stats = self._remove_prefix_src_stats("missing_auto__") - for entry in self.table_entries: - table = self.get_table_config(entry.name) - if entry.new_type is None or entry.new_type.name == "none": - table.pop("missingness_generators", None) - else: - src_stat_key = f"missing_auto__{entry.name}__0" - table["missingness_generators"] = [ - { - "name": entry.new_type.name, - "kwargs": { - "patterns": f'SRC_STATS["{src_stat_key}"]["results"]' - }, - "columns": entry.new_type.columns, - } - ] - src_stats.append( - { - "name": src_stat_key, - "query": entry.new_type.query, - "comments": [] - if entry.new_type.comment is None - else [entry.new_type.comment], - } - ) - self.set_table_config(entry.name, table) - - def do_quit(self, _arg: str) -> bool: - """Check the updates, save them if desired and quit the configurer.""" - count = 0 - for entry in self.table_entries: - if entry.old_type != entry.new_type: - count += 1 - if 
entry.old_type is None: - self.print( - "Putting generator {0} on table {1}", - entry.name, - entry.new_type.name, - ) - elif entry.new_type is None: - self.print( - "Deleting generator {1} from table {0}", - entry.name, - entry.old_type.name, - ) - else: - self.print( - "Changing {0} from {1} to {2}", - entry.name, - entry.old_type.name, - entry.new_type.name, - ) - if count == 0: - self.print("You have made no changes.") - reply = self.ask_save() - if reply == "yes": - self._copy_entries() - return True - if reply == "no": - return True - return False - - def do_tables(self, _arg: str) -> None: - """List the tables with their types.""" - for entry in self.table_entries: - old = "-" if entry.old_type is None else entry.old_type.name - new = "-" if entry.new_type is None else entry.new_type.name - desc = new if old == new else f"{old}->{new}" - self.print("{0} {1}", entry.name, desc) - - def do_next(self, arg: str) -> None: - """ - Go to the next table, or a specified table. - - 'next' = go to the next table, 'next tablename' = go to table 'tablename' - """ - if arg: - # Find the index of the table called _arg, if any - index = next( - (i for i, entry in enumerate(self.table_entries) if entry.name == arg), - None, - ) - if index is None: - self.print(self.ERROR_NO_SUCH_TABLE, arg) - return - self._set_table_index(index) - return - self.next_table(self.INFO_NO_MORE_TABLES) - - def complete_next( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Get completions for tables and columns.""" - return [ - entry.name for entry in self.table_entries if entry.name.startswith(text) - ] - - def do_previous(self, _arg: str) -> None: - """Go to the previous table.""" - if not self._set_table_index(self.table_index - 1): - self.print(self.ERROR_ALREADY_AT_START) - - def _set_type(self, name: str, query: str, comment: str) -> None: - """Set the current table entry's query.""" - if len(self.table_entries) <= self.table_index: - return - entry: MissingnessCmdTableEntry = self.table_entries[self.table_index] - entry.new_type = MissingnessType( - name=name, - query=query, - comment=comment, - columns=self.get_nonnull_columns(entry.name), - ) - - def _set_none(self) -> None: - """Set the current table to have no missingness applied.""" - if len(self.table_entries) <= self.table_index: - return - self.table_entries[self.table_index].new_type = None - - def do_sampled(self, arg: str) -> None: - """ - Set the current table missingness as 'sampled', and go to the next table. - - 'sampled 3000' means sample 3000 rows at random and choose the - missingness to be the same as one of those 3000 at random. - 'sampled' means the same, but with a default number of rows sampled (1000). - """ - if len(self.table_entries) <= self.table_index: - self.print("Error! not on a table") - return - entry = self.table_entries[self.table_index] - if arg == "": - count = 1000 - elif arg.isdecimal(): - count = int(arg) - else: - self.print( - "Error: sampled can be used alone or with an integer argument. 
{0} is not permitted", - arg, - ) - return - self._set_type( - MissingnessType.SAMPLED, - MissingnessType.sampled_query( - entry.name, - count, - self.get_nonnull_columns(entry.name), - ), - ( - "The missingness patterns and how often they appear in a" - f" sample of {count} from table {entry.name}" - ), - ) - self.print("Table {} set to sampled missingness", self.table_name()) - self.next_table() - - def do_none(self, _arg: str) -> None: - """Set the current table to have no missingness, and go to the next table.""" - self._set_none() - self.print("Table {} set to have no missingness", self.table_name()) - self.next_table() - - -def update_missingness( - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], -) -> Mapping[str, Any]: - """ - Ask the user to update the missingness information in ``config.yaml``. - - :param src_dsn: The connection string for the source database. - :param src_schema: The name of the source database schema (or None - for the default). - :param metadata: The SQLAlchemy metadata object from ``orm.yaml``. - :param config: The starting configuration, - :return: The updated configuration. - """ - with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc: - mc.cmdloop() - return mc.config - - -@dataclass -class GeneratorInfo: - """A generator and the columns it assigns to.""" - - columns: list[str] - gen: Generator | None - - -@dataclass -class GeneratorCmdTableEntry(TableEntry): - """ - List of generators set for a table. - - Includes the original setting and the currently configured - generators. - """ - - old_generators: list[GeneratorInfo] - new_generators: list[GeneratorInfo] - - -class GeneratorCmd(DbCmd): - """Interactive command shell for setting generators.""" - - intro = "Interactive generator configuration. Type ? for help.\n" - doc_leader = """Use command 'propose' for a list of generators applicable to the -current column, then command 'compare' to see how these perform -against the source data, then command 'set' to choose your favourite. -Use 'unset' to remove the column's generator. Use commands 'next' and -'previous' to change which column we are examining. Use 'info' -for useful information about the current column. Use 'tables' and -'list' to see available tables and columns. Use 'columns' to see -information about the columns in the current table. Use 'peek', -'count' or 'select' to fetch data from the source database. Use -'quit' to exit this program.""" - prompt = "(generatorconf) " - file = None - - PROPOSE_SOURCE_SAMPLE_TEXT = "Sample of actual source data: {0}..." - PROPOSE_SOURCE_EMPTY_TEXT = "Source database has no data in this column." - PROPOSE_GENERATOR_SAMPLE_TEXT = "{index}. {name}: {fit} {sample} ..." - PRIMARY_PRIVATE_TEXT = "Primary Private" - SECONDARY_PRIVATE_TEXT = "Secondary Private on columns {0}" - NOT_PRIVATE_TEXT = "Not private" - ERROR_NO_SUCH_TABLE = "No such (non-vocabulary, non-ignored) table name {0}" - ERROR_NO_SUCH_COLUMN = "No such column {0} in this table" - ERROR_COLUMN_ALREADY_MERGED = "Column {0} is already merged" - ERROR_COLUMN_ALREADY_UNMERGED = "Column {0} is not merged" - ERROR_CANNOT_UNMERGE_ALL = "You cannot unmerge all the generator's columns" - PROPOSE_NOTHING = "No proposed generators, sorry." - - SRC_STAT_RE = re.compile( - r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' - ) - - def make_table_entry( - self, table_name: str, table_config: Mapping - ) -> GeneratorCmdTableEntry | None: - """ - Make a table entry. 
- - :param table_name: The name of the table. - :param table: The portion of the ``config.yaml`` file describing this table. - :return: The newly constructed table entry, or None if this table is to be ignored. - """ - if table_config.get("ignore", False): - return None - if table_config.get("vocabulary_table", False): - return None - if table_config.get("num_rows_per_pass", 1) == 0: - return None - metadata_table = self.metadata.tables[table_name] - columns = [str(colname) for colname in metadata_table.columns.keys()] - column_set = frozenset(columns) - columns_assigned_so_far: set[str] = set() - - new_generator_infos: list[GeneratorInfo] = [] - old_generator_infos: list[GeneratorInfo] = [] - for rg in table_config.get("row_generators", []): - gen_name = rg.get("name", None) - if gen_name: - ca = rg.get("columns_assigned", []) - collist: list[str] = ( - [ca] if isinstance(ca, str) else [str(c) for c in ca] - ) - colset: set[str] = set(collist) - for unknown in colset - column_set: - logger.warning( - "table '%s' has '%s' assigned to column '%s' which is not in this table", - table_name, - gen_name, - unknown, - ) - for mult in columns_assigned_so_far & colset: - logger.warning( - "table '%s' has column '%s' assigned to multiple times", - table_name, - mult, - ) - actual_collist = [c for c in collist if c in columns] - if actual_collist: - gen = PredefinedGenerator(table_name, rg, self.config) - new_generator_infos.append( - GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - ) - ) - old_generator_infos.append( - GeneratorInfo( - columns=actual_collist.copy(), - gen=gen, - ) - ) - columns_assigned_so_far |= colset - for colname in columns: - if colname not in columns_assigned_so_far: - new_generator_infos.append( - GeneratorInfo( - columns=[colname], - gen=None, - ) - ) - if len(new_generator_infos) == 0: - return None - return GeneratorCmdTableEntry( - name=table_name, - old_generators=old_generator_infos, - new_generators=new_generator_infos, - ) - - def __init__( - self, - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], - ) -> None: - """ - Initialise a ``GeneratorCmd``. - - :param src_dsn: connection address for source database - :param src_schema: database schema name - :param metadata: SQLAlchemy metadata for the source database - :param config: Configuration loaded from ``config.yaml`` - """ - super().__init__(src_dsn, src_schema, metadata, config) - self.generators: list[Generator] | None = None - self.generator_index = 0 - self.generators_valid_columns: Optional[tuple[int, list[str]]] = None - self.set_prompt() - - @property - def table_entries(self) -> list[GeneratorCmdTableEntry]: - """Get the talbe entries, cast to ``GeneratorCmdTableEntry``.""" - return cast(list[GeneratorCmdTableEntry], self._table_entries) - - def _find_entry_by_table_name( - self, table_name: str - ) -> GeneratorCmdTableEntry | None: - """ - Find the table entry by name. - - :param table_name: The name of the table to find. - :return: The table entry, or None if no such table name exists. - """ - entry = super()._find_entry_by_table_name(table_name) - if entry is None: - return None - return cast(GeneratorCmdTableEntry, entry) - - def _set_table_index(self, index: int) -> bool: - """ - Move to a new table. - - :param index: table index to move to. - """ - ret = super()._set_table_index(index) - if ret: - self.generator_index = 0 - self.set_prompt() - return ret - - def _previous_table(self) -> bool: - """ - Move to the table before the current one. 
- - :return: True if there is a previous table to go to. - """ - ret = self._set_table_index(self.table_index - 1) - if ret: - table = self.get_table() - if table is None: - self.print( - "Internal error! table {0} does not have any generators!", - self.table_index, - ) - return False - self.generator_index = len(table.new_generators) - 1 - else: - self.print(self.ERROR_ALREADY_AT_START) - return ret - - def get_table(self) -> GeneratorCmdTableEntry | None: - """Get the current table entry.""" - if self.table_index < len(self.table_entries): - return self.table_entries[self.table_index] - return None - - def _get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]: - """Get a pair; the table name then the generator information.""" - if self.table_index < len(self.table_entries): - entry = self.table_entries[self.table_index] - if self.generator_index < len(entry.new_generators): - return (entry.name, entry.new_generators[self.generator_index]) - return (entry.name, None) - return (None, None) - - def _get_column_names(self) -> list[str]: - """Get the (unqualified) names for all the current columns.""" - (_, generator_info) = self._get_table_and_generator() - return generator_info.columns if generator_info else [] - - def _column_metadata(self) -> list[Column]: - """Get the metadata for all the current columns.""" - table = self.table_metadata() - if table is None: - return [] - return [table.columns[name] for name in self._get_column_names()] - - def set_prompt(self) -> None: - """Set the prompt according to the current table, column and generator.""" - (table_name, gen_info) = self._get_table_and_generator() - if table_name is None: - self.prompt = "(generators) " - return - if gen_info is None: - self.prompt = f"({table_name}) " - return - table = self.table_metadata() - columns = [ - c + "[pk]" if table.columns[c].primary_key else c for c in gen_info.columns - ] - gen = f" ({gen_info.gen.name()})" if gen_info.gen else "" - self.prompt = f"({table_name}.{','.join(columns)}{gen}) " - - def _remove_auto_src_stats(self) -> list[MutableMapping[str, Any]]: - """ - Remove all automatic source stats. - - We assume every source stats query whose name begins with ``auto__` - :return: The new ``src_stats`` configuration. 
- """ - return self._remove_prefix_src_stats("auto__") - - def _copy_entries(self) -> None: - """Set generator and query information in the configuration.""" - src_stats = self._remove_auto_src_stats() - for entry in self.table_entries: - rgs = [] - new_gens: list[Generator] = [] - for generator in entry.new_generators: - if generator.gen is not None: - new_gens.append(generator.gen) - cqs = generator.gen.custom_queries() - for cq_key, cq in cqs.items(): - src_stats.append( - { - "name": cq_key, - "query": cq["query"], - "comments": [cq["comment"]] - if "comment" in cq and cq["comment"] - else [], - } - ) - rg: dict[str, Any] = { - "name": generator.gen.function_name(), - "columns_assigned": generator.columns, - } - kwn = generator.gen.nominal_kwargs() - if kwn: - rg["kwargs"] = kwn - rgs.append(rg) - aq = self._get_aggregate_query(new_gens, entry.name) - if aq: - src_stats.append( - { - "name": f"auto__{entry.name}", - "query": aq, - "comments": [ - q["comment"] - for gen in new_gens - for q in gen.select_aggregate_clauses().values() - if "comment" in q and q["comment"] is not None - ], - } - ) - table_config = self.get_table_config(entry.name) - if rgs: - table_config["row_generators"] = rgs - elif "row_generators" in table_config: - del table_config["row_generators"] - self.set_table_config(entry.name, table_config) - self.config["src-stats"] = src_stats - - def _find_old_generator( - self, entry: GeneratorCmdTableEntry, columns: Iterable[str] - ) -> Generator | None: - """Find any generator that previously assigned to these exact same columns.""" - fc = frozenset(columns) - for gen in entry.old_generators: - if frozenset(gen.columns) == fc: - return gen.gen - return None - - def do_quit(self, arg: str) -> bool: - """Check the updates, save them if desired and quit the configurer.""" - count = 0 - for entry in self.table_entries: - header_shown = False - g_entry = cast(GeneratorCmdTableEntry, entry) - for gen in g_entry.new_generators: - old_gen = self._find_old_generator(g_entry, gen.columns) - new_gen = None if gen is None else gen.gen - if old_gen != new_gen: - if not header_shown: - header_shown = True - self.print("Table {0}:", entry.name) - count += 1 - self.print( - "...changing {0} from {1} to {2}", - ", ".join(gen.columns), - old_gen.name() if old_gen else "nothing", - gen.gen.name() if gen.gen else "nothing", - ) - if count == 0: - self.print("You have made no changes.") - if arg in {"yes", "no"}: - reply = arg - else: - reply = self.ask_save() - if reply == "yes": - self._copy_entries() - return True - if reply == "no": - return True - return False - - def do_tables(self, _arg: str) -> None: - """List the tables.""" - for t_entry in self.table_entries: - entry = cast(GeneratorCmdTableEntry, t_entry) - gen_count = len(entry.new_generators) - how_many = "one generator" if gen_count == 1 else f"{gen_count} generators" - self.print("{0} ({1})", entry.name, how_many) - - def do_list(self, _arg: str) -> None: - """List the generators in the current table.""" - if len(self.table_entries) <= self.table_index: - self.print("Error: no table {0}", self.table_index) - return - g_entry = cast(GeneratorCmdTableEntry, self.table_entries[self.table_index]) - table = self.table_metadata() - for gen in g_entry.new_generators: - old_gen = self._find_old_generator(g_entry, gen.columns) - old = "" if old_gen is None else old_gen.name() - if old_gen == gen.gen: - becomes = "" - if old == "": - old = "(not set)" - elif gen.gen is None: - becomes = "(delete)" - else: - becomes = 
f"->{gen.gen.name()}" - primary = "" - if len(gen.columns) == 1 and table.columns[gen.columns[0]].primary_key: - primary = "[primary-key]" - self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns) - - def do_columns(self, _arg: str) -> None: - """Report the column names and metadata.""" - self.report_columns() - - def do_info(self, _arg: str) -> None: - """Show information about the current column.""" - for cm in self._column_metadata(): - self.print( - "Column {0} in table {1} has type {2} ({3}).", - cm.name, - cm.table.name, - str(cm.type), - "nullable" if cm.nullable else "not nullable", - ) - if cm.primary_key: - self.print( - "It is a primary key, which usually does not need a generator (it will auto-increment)" - ) - if cm.foreign_keys: - fk_names = [fk_column_name(fk) for fk in cm.foreign_keys] - self.print( - "It is a foreign key referencing column {0}", ", ".join(fk_names) - ) - if len(fk_names) == 1 and not cm.primary_key: - self.print( - "You do not need a generator if you just want a uniform choice over the referenced table's rows" - ) - - def _get_table_index(self, table_name: str) -> int | None: - """Get the index of the named table in the table entries list.""" - for n, entry in enumerate(self.table_entries): - if entry.name == table_name: - return n - return None - - def _get_generator_index(self, table_index: int, column_name: str) -> int | None: - """ - Get the index number of a column within the list of generators in this table. - - :param table_index: The index of the table in which to search. - :param column_name: The name of the column to search for. - :return: The index in the ``new_generators`` attribute of the table entry - containing the specified column, or None if this does not exist. - """ - entry = self.table_entries[table_index] - for n, gen in enumerate(entry.new_generators): - if column_name in gen.columns: - return n - return None - - def go_to(self, target: str) -> bool: - """ - Go to a particular column. - - :return: True on success. - """ - parts = target.split(".", 1) - table_index = self._get_table_index(parts[0]) - if table_index is None: - if len(parts) == 1: - gen_index = self._get_generator_index(self.table_index, parts[0]) - if gen_index is not None: - self.generator_index = gen_index - self.set_prompt() - return True - self.print(self.ERROR_NO_SUCH_TABLE_OR_COLUMN, parts[0]) - return False - gen_index = None - if 1 < len(parts) and parts[1]: - gen_index = self._get_generator_index(table_index, parts[1]) - if gen_index is None: - self.print("we cannot set the generator for column {0}", parts[1]) - return False - self._set_table_index(table_index) - if gen_index is not None: - self.generator_index = gen_index - self.set_prompt() - return True - - def do_next(self, arg: str) -> None: - """ - Go to the next generator. or a specified generator. - - Go to a named table: 'next tablename', - go to a column: 'next tablename.columnname', - or go to a column within this table: 'next columnname'. 
- """ - if arg: - self.go_to(arg) - else: - self._go_next() - - def do_n(self, arg: str) -> None: - """Go to the next generator, or a specified generator.""" - self.do_next(arg) - - def complete_n(self, text: str, line: str, begidx: int, endidx: int) -> list[str]: - """Complete the ``n`` command's arguments.""" - return self.complete_next(text, line, begidx, endidx) - - def _go_next(self) -> None: - """Go to the next column.""" - table = self.get_table() - if table is None: - self.print("No more tables") - return - next_gi = self.generator_index + 1 - if next_gi == len(table.new_generators): - self.next_table(self.INFO_NO_MORE_TABLES) - return - self.generator_index = next_gi - self.set_prompt() - - def complete_next( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Completions for the arguments of the ``next`` command.""" - parts = text.split(".", 1) - first_part = parts[0] - if 1 < len(parts): - column_name = parts[1] - table_index = self._get_table_index(first_part) - if table_index is None: - return [] - table_entry = self.table_entries[table_index] - return [ - f"{first_part}.{column}" - for gen in table_entry.new_generators - for column in gen.columns - if column.startswith(column_name) - ] - table_names = [ - entry.name - for entry in self.table_entries - if entry.name.startswith(first_part) - ] - if first_part in table_names: - table_names.append(f"{first_part}.") - current_table = self.get_table() - if current_table: - column_names = [ - col - for gen in current_table.new_generators - for col in gen.columns - if col.startswith(first_part) - ] - else: - column_names = [] - return table_names + column_names - - def do_previous(self, _arg: str) -> None: - """Go to the previous generator.""" - if self.generator_index == 0: - self._previous_table() - else: - self.generator_index -= 1 - self.set_prompt() - - def do_b(self, arg: str) -> None: - """Synonym for previous.""" - self.do_previous(arg) - - def _generators_valid(self) -> bool: - """Test if ``self.generators`` is still correct for the current columns.""" - return self.generators_valid_columns == ( - self.table_index, - self._get_column_names(), - ) - - def _get_generator_proposals(self) -> list[Generator]: - """Get a list of acceptable generators, sorted by decreasing fit to the actual data.""" - if not self._generators_valid(): - self.generators = None - if self.generators is None: - columns = self._column_metadata() - gens = everything_factory().get_generators(columns, self.sync_engine) - sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) - self.generators = sorted_gens - self.generators_valid_columns = ( - self.table_index, - self._get_column_names().copy(), - ) - return self.generators - - def _print_privacy(self) -> None: - """Print the privacy status of the current table.""" - table = self.table_metadata() - if table is None: - return - if table_is_private(self.config, table.name): - self.print(self.PRIMARY_PRIVATE_TEXT) - return - pfks = primary_private_fks(self.config, table) - if not pfks: - self.print(self.NOT_PRIVATE_TEXT) - return - self.print(self.SECONDARY_PRIVATE_TEXT, pfks) - - def do_compare(self, arg: str) -> None: - """ - Compare the real data with some generators. - - 'compare': just look at some source data from this column. - 'compare 5 6 10': compare a sample of the source data with a sample - from generators 5, 6 and 10. You can find out which numbers - correspond to which generators using the 'propose' command. 
- """ - self._print_privacy() - args = arg.split() - limit = 20 - comparison = { - "source": [ - x[0] if len(x) == 1 else ", ".join(x) - for x in self._get_column_data(limit, to_str=str) - ] - } - gens: list[Generator] = self._get_generator_proposals() - table_name = self.table_name() - for argument in args: - if argument.isdigit(): - n = int(argument) - if 0 < n <= len(gens): - gen = gens[n - 1] - comparison[f"{n}. {gen.name()}"] = gen.generate_data(limit) - self._print_values_queried(table_name, n, gen) - self.print_table_by_columns(comparison) - - def do_c(self, arg: str) -> None: - """Synonym for compare.""" - self.do_compare(arg) - - def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None: - """ - Print the values queried from the database for this generator. - - :param table_name: The name of the table the generator applies to. - :param n: A number to print at the start of the output. - :param gen: The generator to report. - """ - if not gen.select_aggregate_clauses() and not gen.custom_queries(): - self.print( - "{0}. {1} requires no data from the source database.", - n, - gen.name(), - ) - else: - self.print( - "{0}. {1} requires the following data from the source database:", - n, - gen.name(), - ) - self._print_select_aggregate_query(table_name, gen) - self._print_custom_queries(gen) - - def _print_custom_queries(self, gen: Generator) -> None: - """ - Print all the custom queries and all the values they get in this case. - - :param gen: The generator to print the custom queries for. - """ - cqs = gen.custom_queries() - if not cqs: - return - cq_key2args: dict[str, Any] = {} - nominal = gen.nominal_kwargs() - actual = gen.actual_kwargs() - self._get_custom_queries_from( - cq_key2args, - nominal, - actual, - ) - for cq_key, cq in cqs.items(): - self.print( - "{0}; providing the following values: {1}", - cq["query"], - cq_key2args[cq_key], - ) - - def _get_custom_queries_from( - self, out: dict[str, Any], nominal: Any, actual: Any - ) -> None: - if isinstance(nominal, str): - src_stat_groups = self.SRC_STAT_RE.search(nominal) - # Do we have a SRC_STAT reference? - if src_stat_groups: - # Get its name - cq_key = src_stat_groups.group(1) - # Are we pulling a specific part of this result? - sub = src_stat_groups.group(3) - if sub: - actual = {sub: actual} - else: - out[cq_key] = actual - elif isinstance(nominal, Sequence) and isinstance(actual, Sequence): - for i in range(min(len(nominal), len(actual))): - self._get_custom_queries_from(out, nominal[i], actual[i]) - elif isinstance(nominal, Mapping) and isinstance(actual, Mapping): - for k, v in nominal.items(): - if k in actual: - self._get_custom_queries_from(out, v, actual[k]) - - def _get_aggregate_query( - self, gens: list[Generator], table_name: str - ) -> str | None: - clauses = [ - f'{q["clause"]} AS {n}' - for gen in gens - for n, q in or_default(gen.select_aggregate_clauses(), {}).items() - ] - if not clauses: - return None - return f"SELECT {', '.join(clauses)} FROM {table_name}" - - def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None: - """ - Print the select aggregate query and all the values it gets in this case. - - This is not the entire query that will be executed, but only the part of it - that is required by a certain generator. - :param table_name: The table name. - :param gen: The generator to limit the aggregate query to. 
- """ - sacs = gen.select_aggregate_clauses() - if not sacs: - return - kwa = gen.actual_kwargs() - vals = [] - src_stat2kwarg = {v: k for k, v in gen.nominal_kwargs().items()} - for n in sacs.keys(): - src_stat = f'SRC_STATS["auto__{table_name}"]["results"][0]["{n}"]' - if src_stat in src_stat2kwarg: - ak = src_stat2kwarg[src_stat] - if ak in kwa: - vals.append(kwa[ak]) - else: - logger.warning( - "actual_kwargs for %s does not report %s", gen.name(), ak - ) - else: - logger.warning( - 'nominal_kwargs for %s does not have a value SRC_STATS["auto__%s"]["results"][0]["%s"]', - gen.name(), - table_name, - n, - ) - select_q = self._get_aggregate_query([gen], table_name) - self.print("{0}; providing the following values: {1}", select_q, vals) - - def _get_column_data( - self, count: int, to_str: Callable[[Any], str] = repr - ) -> list[list[str]]: - columns = self._get_column_names() - columns_string = ", ".join(columns) - pred = " AND ".join(f"{column} IS NOT NULL" for column in columns) - with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - f"SELECT {columns_string} FROM {self.table_name()} WHERE {pred} ORDER BY RANDOM() LIMIT {count}" - ) - ) - return [[to_str(x) for x in xs] for xs in result.all()] - - def do_propose(self, _arg: str) -> None: - """ - Display a list of possible generators for this column. - - They will be listed in order of fit, the most likely matches first. - The results can be compared (against a sample of the real data in - the column and against each other) with the 'compare' command. - """ - limit = 5 - gens = self._get_generator_proposals() - sample = self._get_column_data(limit) - if sample: - rep = [x[0] if len(x) == 1 else ",".join(x) for x in sample] - self.print(self.PROPOSE_SOURCE_SAMPLE_TEXT, "; ".join(rep)) - else: - self.print(self.PROPOSE_SOURCE_EMPTY_TEXT) - if not gens: - self.print(self.PROPOSE_NOTHING) - for index, gen in enumerate(gens): - fit = gen.fit(-1) - if fit == -1: - fit_s = "(no fit)" - elif fit < 100: - fit_s = f"(fit: {fit:.3g})" - else: - fit_s = f"(fit: {fit:.0f})" - self.print( - self.PROPOSE_GENERATOR_SAMPLE_TEXT, - index=index + 1, - name=gen.name(), - fit=fit_s, - sample="; ".join(map(repr, gen.generate_data(limit))), - ) - - def do_p(self, arg: str) -> None: - """Synonym for propose.""" - self.do_propose(arg) - - def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None: - """Find a generator by name from the list of proposals.""" - for gen in self._get_generator_proposals(): - if gen.name() == gen_name: - return gen - return None - - def do_set(self, arg: str) -> None: - """Set one of the proposals as a generator.""" - if arg.isdigit() and not self._generators_valid(): - self.print("Please run 'propose' before 'set '") - return - gens = self._get_generator_proposals() - new_gen: Generator | None - if arg.isdigit(): - index = int(arg) - if index < 1: - self.print("set's integer argument must be at least 1") - return - if len(gens) < index: - self.print( - "There are currently only {0} generators proposed, please select one of them.", - len(gens), - ) - return - new_gen = gens[index - 1] - else: - new_gen = self.get_proposed_generator_by_name(arg) - if new_gen is None: - self.print("'{0}' is not an appropriate generator for this column", arg) - return - self.set_generator(new_gen) - self._go_next() - - def set_generator(self, gen: Generator | None) -> None: - """Set the current column's generator.""" - (table, gen_info) = self._get_table_and_generator() - if table is None: 
- self.print("Error: no table") - return - if gen_info is None: - self.print("Error: no column") - return - gen_info.gen = gen - - def do_s(self, arg: str) -> None: - """Synonym for set.""" - self.do_set(arg) - - def do_unset(self, _arg: str) -> None: - """Remove any generator set for this column.""" - self.set_generator(None) - self._go_next() - - def do_merge(self, arg: str) -> None: - """ - Add this column(s) to the specified column(s). - - After this, one generator will cover them all. - """ - cols = arg.split() - if not cols: - self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry | None = self.get_table() - if table_entry is None: - self.print(self.ERROR_NO_SUCH_TABLE) - return - cols_available = functools.reduce( - lambda x, y: x | y, - [frozenset(gen.columns) for gen in table_entry.new_generators], - ) - cols_to_merge = frozenset(cols) - unknown_cols = cols_to_merge - cols_available - if unknown_cols: - for uc in unknown_cols: - self.print(self.ERROR_NO_SUCH_COLUMN, uc) - return - gen_info = table_entry.new_generators[self.generator_index] - current_columns = frozenset(gen_info.columns) - stated_current_columns = cols_to_merge & current_columns - if stated_current_columns: - for c in stated_current_columns: - self.print(self.ERROR_COLUMN_ALREADY_MERGED, c) - return - # Remove cols_to_merge from each generator - new_new_generators: list[GeneratorInfo] = [] - for gen in table_entry.new_generators: - if gen is gen_info: - # Add columns to this generator - self.generator_index = len(new_new_generators) - new_new_generators.append( - GeneratorInfo( - columns=gen.columns + cols, - gen=None, - ) - ) - else: - # Remove columns if applicable - new_columns = [c for c in gen.columns if c not in cols_to_merge] - is_changed = len(new_columns) != len(gen.columns) - if new_columns: - # We have not removed this generator completely - new_new_generators.append( - GeneratorInfo( - columns=new_columns, - gen=None if is_changed else gen.gen, - ) - ) - table_entry.new_generators = new_new_generators - self.set_prompt() - - def complete_merge( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Complete column names.""" - last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry | None = self.get_table() - if table_entry is None: - return [] - return [ - column - for i, gen in enumerate(table_entry.new_generators) - if i != self.generator_index - for column in gen.columns - if column.startswith(last_arg) - ] - - def do_unmerge(self, arg: str) -> None: - """Remove this column(s) from this generator, make them a separate generator.""" - cols = arg.split() - if not cols: - self.print("Error: merge requires a column argument") - table_entry: GeneratorCmdTableEntry | None = self.get_table() - if table_entry is None: - self.print(self.ERROR_NO_SUCH_TABLE) - return - gen_info = table_entry.new_generators[self.generator_index] - current_columns = frozenset(gen_info.columns) - cols_to_unmerge = frozenset(cols) - unknown_cols = cols_to_unmerge - current_columns - if unknown_cols: - for uc in unknown_cols: - self.print(self.ERROR_NO_SUCH_COLUMN, uc) - return - stated_unmerged_columns = cols_to_unmerge - current_columns - if stated_unmerged_columns: - for c in stated_unmerged_columns: - self.print(self.ERROR_COLUMN_ALREADY_UNMERGED, c) - return - if cols_to_unmerge == current_columns: - self.print(self.ERROR_CANNOT_UNMERGE_ALL) - return - # Remove unmerged columns - for um in cols_to_unmerge: - gen_info.columns.remove(um) - # The existing 
generator will not work - gen_info.gen = None - # And put them into a new (empty) generator - table_entry.new_generators.insert( - self.generator_index + 1, - GeneratorInfo( - columns=cols, - gen=None, - ), - ) - self.set_prompt() - - def complete_unmerge( - self, text: str, _line: str, _begidx: int, _endidx: int - ) -> list[str]: - """Complete column names to unmerge.""" - last_arg = text.split()[-1] - table_entry: GeneratorCmdTableEntry | None = self.get_table() - if table_entry is None: - return [] - return [ - column - for column in table_entry.new_generators[self.generator_index].columns - if column.startswith(last_arg) - ] - - -def update_config_generators( - src_dsn: str, - src_schema: str | None, - metadata: MetaData, - config: MutableMapping[str, Any], - spec_path: Path | None, -) -> Mapping[str, Any]: - """ - Update configuration with the specification from a CSV file. - - The specification is a headerless CSV file with columns: Table name, - Column name (or space-separated list of column names), Generator - name required, Second choice generator name, Third choice generator - name, etcetera. - :param src_dsn: Address of the source database - :param src_schema: Name of the source database schema to read from - :param metadata: SQLAlchemy representation of the source database - :param config: Existing configuration (will be destructively updated) - :param spec_path: The path of the CSV file containing the specification - :return: Updated configuration. - """ - with GeneratorCmd(src_dsn, src_schema, metadata, config) as gc: - if spec_path is None: - gc.cmdloop() - return gc.config - spec = spec_path.open() - line_no = 0 - for line in csv.reader(spec): - line_no += 1 - if line: - if len(line) != 3: - logger.error( - "line %d of file %s does not have three values", - line_no, - spec_path, - ) - if gc.go_to(f"{line[0]}.{line[1]}"): - gc.do_set(line[2]) - gc.do_quit("yes") - return gc.config diff --git a/datafaker/interactive/__init__.py b/datafaker/interactive/__init__.py new file mode 100644 index 0000000..72ec00e --- /dev/null +++ b/datafaker/interactive/__init__.py @@ -0,0 +1,95 @@ +"""Interactive configuration commands.""" +import csv +from collections.abc import Mapping, MutableMapping +from pathlib import Path +from typing import Any + +from sqlalchemy import MetaData + +from datafaker.interactive.table import TableCmd +from datafaker.interactive.generators import GeneratorCmd +from datafaker.interactive.missingness import MissingnessCmd +from datafaker.utils import logger + +# Monkey patch pyreadline3 v3.5 so that it works with Python 3.13 +# Windows users can install pyreadline3 to get tab completion working. +# See https://github.com/pyreadline3/pyreadline3/issues/37 +try: + import readline + + if not hasattr(readline, "backend"): + setattr(readline, "backend", "readline") +except: + pass + + +def update_config_tables( + src_dsn: str, src_schema: str | None, metadata: MetaData, config: MutableMapping +) -> Mapping[str, Any]: + """Ask the user to specify what should happen to each table.""" + with TableCmd(src_dsn, src_schema, metadata, config) as tc: + tc.cmdloop() + return tc.config + + +def update_missingness( + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], +) -> Mapping[str, Any]: + """ + Ask the user to update the missingness information in ``config.yaml``. + + :param src_dsn: The connection string for the source database. + :param src_schema: The name of the source database schema (or None + for the default). 
+    :param metadata: The SQLAlchemy metadata object from ``orm.yaml``.
+    :param config: The starting configuration.
+    :return: The updated configuration.
+    """
+    with MissingnessCmd(src_dsn, src_schema, metadata, config) as mc:
+        mc.cmdloop()
+        return mc.config
+
+
+def update_config_generators(
+    src_dsn: str,
+    src_schema: str | None,
+    metadata: MetaData,
+    config: MutableMapping[str, Any],
+    spec_path: Path | None,
+) -> Mapping[str, Any]:
+    """
+    Update configuration with the specification from a CSV file.
+
+    The specification is a headerless CSV file with columns: Table name,
+    Column name (or space-separated list of column names), Generator
+    name required, Second choice generator name, Third choice generator
+    name, etcetera.
+    :param src_dsn: Address of the source database
+    :param src_schema: Name of the source database schema to read from
+    :param metadata: SQLAlchemy representation of the source database
+    :param config: Existing configuration (will be destructively updated)
+    :param spec_path: The path of the CSV file containing the specification
+    :return: Updated configuration.
+    """
+    with GeneratorCmd(src_dsn, src_schema, metadata, config) as gc:
+        if spec_path is None:
+            gc.cmdloop()
+            return gc.config
+        spec = spec_path.open()
+        line_no = 0
+        for line in csv.reader(spec):
+            line_no += 1
+            if line:
+                if len(line) != 3:
+                    logger.error(
+                        "line %d of file %s does not have three values",
+                        line_no,
+                        spec_path,
+                    )
+                    if len(line) < 3:
+                        # Skip too-short rows to avoid an IndexError below
+                        continue
+                if gc.go_to(f"{line[0]}.{line[1]}"):
+                    gc.do_set(line[2])
+        gc.do_quit("yes")
+        return gc.config
diff --git a/datafaker/interactive/base.py b/datafaker/interactive/base.py
new file mode 100644
index 0000000..51793fe
--- /dev/null
+++ b/datafaker/interactive/base.py
@@ -0,0 +1,404 @@
+"""Base configuration command shells."""
+import cmd
+from abc import ABC, abstractmethod
+from collections.abc import Mapping, MutableMapping, Sequence
+from dataclasses import dataclass
+from enum import Enum
+from types import TracebackType
+from typing import Any, Optional, Type
+
+import sqlalchemy
+from prettytable import PrettyTable
+from sqlalchemy import Engine, ForeignKey, MetaData, Table
+from typing_extensions import Self
+
+from datafaker.utils import (
+    T,
+    create_db_engine,
+    fk_refers_to_ignored_table,
+    get_sync_engine,
+)
+
+
+def or_default(v: T | None, d: T) -> T:
+    """Return v if it isn't None, otherwise d."""
+    return d if v is None else v
+
+
+class TableType(Enum):
+    """Types of table to be configured."""
+
+    GENERATE = "generate"
+    IGNORE = "ignore"
+    VOCABULARY = "vocabulary"
+    PRIVATE = "private"
+    EMPTY = "empty"
+
+
+TYPE_LETTER = {
+    TableType.GENERATE: "G",
+    TableType.IGNORE: "I",
+    TableType.VOCABULARY: "V",
+    TableType.PRIVATE: "P",
+    TableType.EMPTY: "e",
+}
+
+TYPE_PROMPT = {
+    TableType.GENERATE: "(table: {}) ",
+    TableType.IGNORE: "(table: {} (ignore)) ",
+    TableType.VOCABULARY: "(table: {} (vocab)) ",
+    TableType.PRIVATE: "(table: {} (private)) ",
+    TableType.EMPTY: "(table: {} (empty)) ",
+}
+
+
+@dataclass
+class TableEntry:
+    """Base class for table entries for interactive commands."""
+
+    name: str  # name of the table
+
+
+class AskSaveCmd(cmd.Cmd):
+    """Interactive shell for whether to save and quit."""
+
+    intro = "Do you want to save this configuration?"
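# A sketch of a specification file as consumed by update_config_generators
# above. Table, column and generator names here are hypothetical; each row is
# "table,column,generator", and rows without exactly three values are
# reported (too-short rows are skipped by the guard above):
#
#   person,height,dist_gen.normal
#   person,sex,dist_gen.choice
#   hospital_visit,visit_start,dist_gen.uniform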
+ prompt = "(yes/no/cancel) " + file = None + + def __init__(self) -> None: + """Initialise a save command.""" + super().__init__() + self.result = "" + + def do_yes(self, _arg: str) -> bool: + """Save the new config.yaml.""" + self.result = "yes" + return True + + def do_no(self, _arg: str) -> bool: + """Exit without saving.""" + self.result = "no" + return True + + def do_cancel(self, _arg: str) -> bool: + """Do not exit.""" + self.result = "cancel" + return True + + +def fk_column_name(fk: ForeignKey) -> str: + """Display name for a foreign key.""" + if fk_refers_to_ignored_table(fk): + return f"{fk.target_fullname} (ignored)" + return str(fk.target_fullname) + + +class DbCmd(ABC, cmd.Cmd): + """Base class for interactive configuration commands.""" + + INFO_NO_MORE_TABLES = "There are no more tables" + ERROR_ALREADY_AT_START = "Error: Already at the start" + ERROR_NO_SUCH_TABLE = "Error: '{0}' is not the name of a table in this database" + ERROR_NO_SUCH_TABLE_OR_COLUMN = "Error: '{0}' is not the name of a table in this database or a column in this table" + ROW_COUNT_MSG = "Total row count: {}" + + @abstractmethod + def make_table_entry( + self, table_name: str, table_config: Mapping + ) -> TableEntry | None: + """ + Make a table entry suitable for this interactive command. + + :param name: The name of the table to make an entry for. + :param table_config: The part of the ``config.yaml`` referring to this table. + :return: The table entry or None if this table should not be interacted with. + """ + + def __init__( + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], + ): + """Initialise a DbCmd.""" + super().__init__() + self.config: MutableMapping[str, Any] = config + self.metadata = metadata + self._table_entries: list[TableEntry] = [] + tables_config: MutableMapping = config.get("tables", {}) + if not isinstance(tables_config, MutableMapping): + tables_config = {} + for name in metadata.tables.keys(): + table_config = tables_config.get(name, {}) + if not isinstance(table_config, MutableMapping): + table_config = {} + entry = self.make_table_entry(name, table_config) + if entry is not None: + self._table_entries.append(entry) + self.table_index = 0 + self.engine = create_db_engine(src_dsn, schema_name=src_schema) + + @property + def sync_engine(self) -> Engine: + """Get the synchronous version of the engine.""" + return get_sync_engine(self.engine) + + def __enter__(self) -> Self: + """Enter a ``with`` statement.""" + return self + + def __exit__( + self, + _exc_type: Optional[Type[BaseException]], + _exc_val: Optional[BaseException], + _exc_tb: Optional[TracebackType], + ) -> None: + """Dispose of this object.""" + self.engine.dispose() + + def print(self, text: str, *args: Any, **kwargs: Any) -> None: + """Print text, formatted with positional and keyword arguments.""" + print(text.format(*args, **kwargs)) + + def print_table( + self, headings: Sequence[str], rows: Sequence[Sequence[Any]] + ) -> None: + """ + Print a table. + + :param headings: List of headings for the table. + :param rows: List of rows of values. + """ + output = PrettyTable() + output.field_names = headings + for row in rows: + # Hopefully PrettyTable will accept Sequence in the future, not list + output.add_row(list(row)) + print(output) + + def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None: + """ + Print a table. + + :param columns: Dict of column names to the values in the column. 
+ """ + output = PrettyTable() + row_count = max([len(col) for col in columns.values()]) + for field_name, data in columns.items(): + output.add_column(field_name, list(data) + [None] * (row_count - len(data))) + print(output) + + def print_results(self, result: sqlalchemy.CursorResult) -> None: + """Print the rows resulting from a database query.""" + self.print_table(list(result.keys()), [list(row) for row in result.all()]) + + def ask_save(self) -> str: + """ + Ask the user if they want to save. + + :return: ``yes``, ``no`` or ``cancel``. + """ + ask = AskSaveCmd() + ask.cmdloop() + return ask.result + + @abstractmethod + def set_prompt(self) -> None: + """Set the prompt according to the current state.""" + ... + + def _set_table_index(self, index: int) -> bool: + """ + Move to a different table. + + :param index: Index of the table to move to. + :return: True if there is a table with such an index to move to. + """ + if 0 <= index < len(self._table_entries): + self.table_index = index + self.set_prompt() + return True + return False + + def next_table(self, report: str = "No more tables") -> bool: + """ + Move to the next table. + + :param report: The text to print if there is no next table. + :return: True if there is another table to move to. + """ + if not self._set_table_index(self.table_index + 1): + self.print(report) + return False + return True + + def table_name(self) -> str: + """Get the name of the current table.""" + return str(self._table_entries[self.table_index].name) + + def table_metadata(self) -> Table: + """Get the metadata of the current table.""" + return self.metadata.tables[self.table_name()] + + def _get_column_names(self) -> list[str]: + """Get the names of the current columns.""" + return [col.name for col in self.table_metadata().columns] + + def report_columns(self) -> None: + """Print information about the current columns.""" + self.print_table( + ["name", "type", "primary", "nullable", "foreign key"], + [ + [ + name, + str(col.type), + col.primary_key, + col.nullable, + ", ".join([fk_column_name(fk) for fk in col.foreign_keys]), + ] + for name, col in self.table_metadata().columns.items() + ], + ) + + def get_table_config(self, table_name: str) -> MutableMapping[str, Any]: + """Get the configuration of the named table.""" + ts = self.config.get("tables", None) + if not isinstance(ts, MutableMapping): + return {} + t = ts.get(table_name) + return t if isinstance(t, MutableMapping) else {} + + def set_table_config( + self, table_name: str, config: MutableMapping[str, Any] + ) -> None: + """Set the configuration of the named table.""" + ts = self.config.get("tables", None) + if not isinstance(ts, MutableMapping): + self.config["tables"] = {table_name: config} + return + ts[table_name] = config + + def _remove_prefix_src_stats(self, prefix: str) -> list[MutableMapping[str, Any]]: + """Remove all source stats with the given prefix from the configuration.""" + src_stats = self.config.get("src-stats", []) + new_src_stats = [] + for stat in src_stats: + if not stat.get("name", "").startswith(prefix): + new_src_stats.append(stat) + self.config["src-stats"] = new_src_stats + return new_src_stats + + def get_nonnull_columns(self, table_name: str) -> list[str]: + """Get the names of the nullable columns in the named table.""" + metadata_table = self.metadata.tables[table_name] + return [ + str(name) + for name, column in metadata_table.columns.items() + if column.nullable + ] + + def find_entry_index_by_table_name(self, table_name: str) -> int | None: + """Get the 
index of the table entry of the named table."""
+        return next(
+            (
+                i
+                for i, entry in enumerate(self._table_entries)
+                if entry.name == table_name
+            ),
+            None,
+        )
+
+    def _find_entry_by_table_name(self, table_name: str) -> TableEntry | None:
+        """Get the table entry of the named table."""
+        for e in self._table_entries:
+            if e.name == table_name:
+                return e
+        return None
+
+    def do_counts(self, _arg: str) -> None:
+        """Report the column names with the counts of nulls in them."""
+        if len(self._table_entries) <= self.table_index:
+            return
+        table_name = self.table_name()
+        nonnull_columns = self.get_nonnull_columns(table_name)
+        colcounts = [f", COUNT({nnc}) AS {nnc}" for nnc in nonnull_columns]
+        with self.sync_engine.connect() as connection:
+            result = connection.execute(
+                sqlalchemy.text(
+                    f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}"
+                )
+            ).first()
+        if result is None:
+            self.print("Could not count rows in table {0}", table_name)
+            return
+        row_count = result.row_count
+        self.print(self.ROW_COUNT_MSG, row_count)
+        self.print_table(
+            ["Column", "NULL count"],
+            [
+                [name, row_count - count]
+                for name, count in result._mapping.items()
+                if name != "row_count"
+            ],
+        )
+
+    def do_select(self, arg: str) -> None:
+        """Run a select query over the database and show the first 50 results."""
+        max_select_rows = 50
+        with self.sync_engine.connect() as connection:
+            try:
+                result = connection.execute(sqlalchemy.text("SELECT " + arg))
+            except sqlalchemy.exc.DatabaseError as exc:
+                self.print("Failed to execute: {}", exc)
+                return
+            row_count = result.rowcount
+            self.print(self.ROW_COUNT_MSG, row_count)
+            if max_select_rows < row_count:
+                self.print("Showing the first {} rows", max_select_rows)
+            fields = list(result.keys())
+            rows = result.fetchmany(max_select_rows)
+            self.print_table(fields, rows)
+
+    def do_peek(self, arg: str) -> None:
+        """
+        View some data from the current table.
+
+        Use 'peek col1 col2 col3' to see a sample of values from
+        columns col1, col2 and col3 in the current table.
+        Use 'peek' to see a sample of the current column(s).
+        Rows that are entirely null are suppressed.
+        """
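# A note on do_counts above: for a hypothetical table "person" with nullable
# columns "height" and "weight" it issues
#
#   SELECT COUNT(*) AS row_count, COUNT(height) AS height,
#          COUNT(weight) AS weight FROM person
#
# COUNT(col) skips NULLs, so each column's NULL count is row_count - count.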
+ """ + max_peek_rows = 25 + if len(self._table_entries) <= self.table_index: + return + table_name = self.table_name() + col_names = arg.split() + if not col_names: + col_names = self._get_column_names() + nonnulls = [cn + " IS NOT NULL" for cn in col_names] + with self.sync_engine.connect() as connection: + cols = ",".join(col_names) + where = "WHERE" if nonnulls else "" + nonnull = " OR ".join(nonnulls) + query = sqlalchemy.text( + f"SELECT {cols} FROM {table_name} {where} {nonnull}" + f" ORDER BY RANDOM() LIMIT {max_peek_rows}" + ) + try: + result = connection.execute(query) + except Exception as exc: + self.print(f'SQL query "{query}" caused exception {exc}') + return + self.print_table(list(result.keys()), result.fetchmany(max_peek_rows)) + + def complete_peek( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Completions for the ``peek`` command.""" + if len(self._table_entries) <= self.table_index: + return [] + return [ + col for col in self.table_metadata().columns.keys() if col.startswith(text) + ] diff --git a/datafaker/interactive/generators.py b/datafaker/interactive/generators.py new file mode 100644 index 0000000..4c87c6b --- /dev/null +++ b/datafaker/interactive/generators.py @@ -0,0 +1,980 @@ +"""Generator configuration shell.""" +from dataclasses import dataclass +from collections.abc import Mapping, Sequence, Iterable, MutableMapping +import functools +import re +from typing import Any, Optional, cast, Callable + +import sqlalchemy +from sqlalchemy import Column, MetaData + +from datafaker.generators import everything_factory +from datafaker.generators.base import Generator, PredefinedGenerator +from datafaker.interactive.base import TableEntry, DbCmd, fk_column_name, or_default +from datafaker.utils import logger, table_is_private, primary_private_fks + +@dataclass +class GeneratorInfo: + """A generator and the columns it assigns to.""" + + columns: list[str] + gen: Generator | None + + +@dataclass +class GeneratorCmdTableEntry(TableEntry): + """ + List of generators set for a table. + + Includes the original setting and the currently configured + generators. + """ + + old_generators: list[GeneratorInfo] + new_generators: list[GeneratorInfo] + + +class GeneratorCmd(DbCmd): + """Interactive command shell for setting generators.""" + + intro = "Interactive generator configuration. Type ? for help.\n" + doc_leader = """Use command 'propose' for a list of generators applicable to the +current column, then command 'compare' to see how these perform +against the source data, then command 'set' to choose your favourite. +Use 'unset' to remove the column's generator. Use commands 'next' and +'previous' to change which column we are examining. Use 'info' +for useful information about the current column. Use 'tables' and +'list' to see available tables and columns. Use 'columns' to see +information about the columns in the current table. Use 'peek', +'count' or 'select' to fetch data from the source database. Use +'quit' to exit this program.""" + prompt = "(generatorconf) " + file = None + + PROPOSE_SOURCE_SAMPLE_TEXT = "Sample of actual source data: {0}..." + PROPOSE_SOURCE_EMPTY_TEXT = "Source database has no data in this column." + PROPOSE_GENERATOR_SAMPLE_TEXT = "{index}. {name}: {fit} {sample} ..." 
+ PRIMARY_PRIVATE_TEXT = "Primary Private" + SECONDARY_PRIVATE_TEXT = "Secondary Private on columns {0}" + NOT_PRIVATE_TEXT = "Not private" + ERROR_NO_SUCH_TABLE = "No such (non-vocabulary, non-ignored) table name {0}" + ERROR_NO_SUCH_COLUMN = "No such column {0} in this table" + ERROR_COLUMN_ALREADY_MERGED = "Column {0} is already merged" + ERROR_COLUMN_ALREADY_UNMERGED = "Column {0} is not merged" + ERROR_CANNOT_UNMERGE_ALL = "You cannot unmerge all the generator's columns" + PROPOSE_NOTHING = "No proposed generators, sorry." + + SRC_STAT_RE = re.compile( + r'\bSRC_STATS\["([^"]+)"\](\["results"\]\[0\]\["([^"]+)"\])?' + ) + + def make_table_entry( + self, table_name: str, table_config: Mapping + ) -> GeneratorCmdTableEntry | None: + """ + Make a table entry. + + :param table_name: The name of the table. + :param table: The portion of the ``config.yaml`` file describing this table. + :return: The newly constructed table entry, or None if this table is to be ignored. + """ + if table_config.get("ignore", False): + return None + if table_config.get("vocabulary_table", False): + return None + if table_config.get("num_rows_per_pass", 1) == 0: + return None + metadata_table = self.metadata.tables[table_name] + columns = [str(colname) for colname in metadata_table.columns.keys()] + column_set = frozenset(columns) + columns_assigned_so_far: set[str] = set() + + new_generator_infos: list[GeneratorInfo] = [] + old_generator_infos: list[GeneratorInfo] = [] + for rg in table_config.get("row_generators", []): + gen_name = rg.get("name", None) + if gen_name: + ca = rg.get("columns_assigned", []) + collist: list[str] = ( + [ca] if isinstance(ca, str) else [str(c) for c in ca] + ) + colset: set[str] = set(collist) + for unknown in colset - column_set: + logger.warning( + "table '%s' has '%s' assigned to column '%s' which is not in this table", + table_name, + gen_name, + unknown, + ) + for mult in columns_assigned_so_far & colset: + logger.warning( + "table '%s' has column '%s' assigned to multiple times", + table_name, + mult, + ) + actual_collist = [c for c in collist if c in columns] + if actual_collist: + gen = PredefinedGenerator(table_name, rg, self.config) + new_generator_infos.append( + GeneratorInfo( + columns=actual_collist.copy(), + gen=gen, + ) + ) + old_generator_infos.append( + GeneratorInfo( + columns=actual_collist.copy(), + gen=gen, + ) + ) + columns_assigned_so_far |= colset + for colname in columns: + if colname not in columns_assigned_so_far: + new_generator_infos.append( + GeneratorInfo( + columns=[colname], + gen=None, + ) + ) + if len(new_generator_infos) == 0: + return None + return GeneratorCmdTableEntry( + name=table_name, + old_generators=old_generator_infos, + new_generators=new_generator_infos, + ) + + def __init__( + self, + src_dsn: str, + src_schema: str | None, + metadata: MetaData, + config: MutableMapping[str, Any], + ) -> None: + """ + Initialise a ``GeneratorCmd``. 
+
+        :param src_dsn: connection address for source database
+        :param src_schema: database schema name
+        :param metadata: SQLAlchemy metadata for the source database
+        :param config: Configuration loaded from ``config.yaml``
+        """
+        super().__init__(src_dsn, src_schema, metadata, config)
+        self.generators: list[Generator] | None = None
+        self.generator_index = 0
+        self.generators_valid_columns: Optional[tuple[int, list[str]]] = None
+        self.set_prompt()
+
+    @property
+    def table_entries(self) -> list[GeneratorCmdTableEntry]:
+        """Get the table entries, cast to ``GeneratorCmdTableEntry``."""
+        return cast(list[GeneratorCmdTableEntry], self._table_entries)
+
+    def _find_entry_by_table_name(
+        self, table_name: str
+    ) -> GeneratorCmdTableEntry | None:
+        """
+        Find the table entry by name.
+
+        :param table_name: The name of the table to find.
+        :return: The table entry, or None if no such table name exists.
+        """
+        entry = super()._find_entry_by_table_name(table_name)
+        if entry is None:
+            return None
+        return cast(GeneratorCmdTableEntry, entry)
+
+    def _set_table_index(self, index: int) -> bool:
+        """
+        Move to a new table.
+
+        :param index: table index to move to.
+        """
+        ret = super()._set_table_index(index)
+        if ret:
+            self.generator_index = 0
+            self.set_prompt()
+        return ret
+
+    def _previous_table(self) -> bool:
+        """
+        Move to the table before the current one.
+
+        :return: True if there is a previous table to go to.
+        """
+        ret = self._set_table_index(self.table_index - 1)
+        if ret:
+            table = self.get_table()
+            if table is None:
+                self.print(
+                    "Internal error! table {0} does not have any generators!",
+                    self.table_index,
+                )
+                return False
+            self.generator_index = len(table.new_generators) - 1
+        else:
+            self.print(self.ERROR_ALREADY_AT_START)
+        return ret
+
+    def get_table(self) -> GeneratorCmdTableEntry | None:
+        """Get the current table entry."""
+        if self.table_index < len(self.table_entries):
+            return self.table_entries[self.table_index]
+        return None
+
+    def _get_table_and_generator(self) -> tuple[str | None, GeneratorInfo | None]:
+        """Get a pair; the table name then the generator information."""
+        if self.table_index < len(self.table_entries):
+            entry = self.table_entries[self.table_index]
+            if self.generator_index < len(entry.new_generators):
+                return (entry.name, entry.new_generators[self.generator_index])
+            return (entry.name, None)
+        return (None, None)
+
+    def _get_column_names(self) -> list[str]:
+        """Get the (unqualified) names for all the current columns."""
+        (_, generator_info) = self._get_table_and_generator()
+        return generator_info.columns if generator_info else []
+
+    def _column_metadata(self) -> list[Column]:
+        """Get the metadata for all the current columns."""
+        table = self.table_metadata()
+        if table is None:
+            return []
+        return [table.columns[name] for name in self._get_column_names()]
+
+    def set_prompt(self) -> None:
+        """Set the prompt according to the current table, column and generator."""
+        (table_name, gen_info) = self._get_table_and_generator()
+        if table_name is None:
+            self.prompt = "(generators) "
+            return
+        if gen_info is None:
+            self.prompt = f"({table_name}) "
+            return
+        table = self.table_metadata()
+        columns = [
+            c + "[pk]" if table.columns[c].primary_key else c for c in gen_info.columns
+        ]
+        gen = f" ({gen_info.gen.name()})" if gen_info.gen else ""
+        self.prompt = f"({table_name}.{','.join(columns)}{gen}) "
+
+    def _remove_auto_src_stats(self) -> list[MutableMapping[str, Any]]:
+        """
+        Remove all automatic source stats.
+
+        We assume that every source stats query whose name begins with
+        ``auto__`` was generated by this tool, so all such queries are
+        removed.
+
+        :return: The new ``src_stats`` configuration.
+        """
+        return self._remove_prefix_src_stats("auto__")
+
+    def _copy_entries(self) -> None:
+        """Set generator and query information in the configuration."""
+        src_stats = self._remove_auto_src_stats()
+        for entry in self.table_entries:
+            rgs = []
+            new_gens: list[Generator] = []
+            for generator in entry.new_generators:
+                if generator.gen is not None:
+                    new_gens.append(generator.gen)
+                    cqs = generator.gen.custom_queries()
+                    for cq_key, cq in cqs.items():
+                        src_stats.append(
+                            {
+                                "name": cq_key,
+                                "query": cq["query"],
+                                "comments": [cq["comment"]]
+                                if "comment" in cq and cq["comment"]
+                                else [],
+                            }
+                        )
+                    rg: dict[str, Any] = {
+                        "name": generator.gen.function_name(),
+                        "columns_assigned": generator.columns,
+                    }
+                    kwn = generator.gen.nominal_kwargs()
+                    if kwn:
+                        rg["kwargs"] = kwn
+                    rgs.append(rg)
+            aq = self._get_aggregate_query(new_gens, entry.name)
+            if aq:
+                src_stats.append(
+                    {
+                        "name": f"auto__{entry.name}",
+                        "query": aq,
+                        "comments": [
+                            q["comment"]
+                            for gen in new_gens
+                            for q in gen.select_aggregate_clauses().values()
+                            if "comment" in q and q["comment"] is not None
+                        ],
+                    }
+                )
+            table_config = self.get_table_config(entry.name)
+            if rgs:
+                table_config["row_generators"] = rgs
+            elif "row_generators" in table_config:
+                del table_config["row_generators"]
+            self.set_table_config(entry.name, table_config)
+        self.config["src-stats"] = src_stats
+
+    def _find_old_generator(
+        self, entry: GeneratorCmdTableEntry, columns: Iterable[str]
+    ) -> Generator | None:
+        """Find any generator that was previously assigned to these exact columns."""
+        fc = frozenset(columns)
+        for gen in entry.old_generators:
+            if frozenset(gen.columns) == fc:
+                return gen.gen
+        return None
+
+    def do_quit(self, arg: str) -> bool:
+        """Check the updates, save them if desired and quit the configurer."""
+        count = 0
+        for entry in self.table_entries:
+            header_shown = False
+            g_entry = cast(GeneratorCmdTableEntry, entry)
+            for gen in g_entry.new_generators:
+                old_gen = self._find_old_generator(g_entry, gen.columns)
+                new_gen = None if gen is None else gen.gen
+                if old_gen != new_gen:
+                    if not header_shown:
+                        header_shown = True
+                        self.print("Table {0}:", entry.name)
+                    count += 1
+                    self.print(
+                        "...changing {0} from {1} to {2}",
+                        ", ".join(gen.columns),
+                        old_gen.name() if old_gen else "nothing",
+                        gen.gen.name() if gen.gen else "nothing",
+                    )
+        if count == 0:
+            self.print("You have made no changes.")
+        if arg in {"yes", "no"}:
+            reply = arg
+        else:
+            reply = self.ask_save()
+        if reply == "yes":
+            self._copy_entries()
+            return True
+        if reply == "no":
+            return True
+        return False
+
+    def do_tables(self, _arg: str) -> None:
+        """List the tables."""
+        for t_entry in self.table_entries:
+            entry = cast(GeneratorCmdTableEntry, t_entry)
+            gen_count = len(entry.new_generators)
+            how_many = "one generator" if gen_count == 1 else f"{gen_count} generators"
+            self.print("{0} ({1})", entry.name, how_many)
+
+    def do_list(self, _arg: str) -> None:
+        """List the generators in the current table."""
+        if len(self.table_entries) <= self.table_index:
+            self.print("Error: no table {0}", self.table_index)
+            return
+        g_entry = cast(GeneratorCmdTableEntry, self.table_entries[self.table_index])
+        table = self.table_metadata()
+        for gen in g_entry.new_generators:
+            old_gen = self._find_old_generator(g_entry, gen.columns)
+            old = "" if old_gen is None else old_gen.name()
+            if old_gen == gen.gen:
+                becomes = ""
+                if old == "":
+                    old = "(not set)"
+            elif gen.gen is None:
+                becomes = "(delete)"
+            else:
+                becomes = f"->{gen.gen.name()}"
+            primary = ""
+            if len(gen.columns) == 1 and table.columns[gen.columns[0]].primary_key:
+                primary = "[primary-key]"
+            self.print("{0}{1}{2} {3}", old, becomes, primary, gen.columns)
+
+    def do_columns(self, _arg: str) -> None:
+        """Report the column names and metadata."""
+        self.report_columns()
+
+    def do_info(self, _arg: str) -> None:
+        """Show information about the current column."""
+        for cm in self._column_metadata():
+            self.print(
+                "Column {0} in table {1} has type {2} ({3}).",
+                cm.name,
+                cm.table.name,
+                str(cm.type),
+                "nullable" if cm.nullable else "not nullable",
+            )
+            if cm.primary_key:
+                self.print(
+                    "It is a primary key, which usually does not"
+                    " need a generator (it will auto-increment)"
+                )
+            if cm.foreign_keys:
+                fk_names = [fk_column_name(fk) for fk in cm.foreign_keys]
+                self.print(
+                    "It is a foreign key referencing column {0}", ", ".join(fk_names)
+                )
+                if len(fk_names) == 1 and not cm.primary_key:
+                    self.print(
+                        "You do not need a generator if you just want"
+                        " a uniform choice over the referenced table's rows"
+                    )
+
+    def _get_table_index(self, table_name: str) -> int | None:
+        """Get the index of the named table in the table entries list."""
+        for n, entry in enumerate(self.table_entries):
+            if entry.name == table_name:
+                return n
+        return None
+
+    def _get_generator_index(self, table_index: int, column_name: str) -> int | None:
+        """
+        Get the index number of a column within the list of generators in this table.
+
+        :param table_index: The index of the table in which to search.
+        :param column_name: The name of the column to search for.
+        :return: The index in the ``new_generators`` attribute of the table entry
+            containing the specified column, or None if this does not exist.
+        """
+        entry = self.table_entries[table_index]
+        for n, gen in enumerate(entry.new_generators):
+            if column_name in gen.columns:
+                return n
+        return None
+
+    def go_to(self, target: str) -> bool:
+        """
+        Go to a particular column.
+
+        :return: True on success.
+        """
+        parts = target.split(".", 1)
+        table_index = self._get_table_index(parts[0])
+        if table_index is None:
+            if len(parts) == 1:
+                gen_index = self._get_generator_index(self.table_index, parts[0])
+                if gen_index is not None:
+                    self.generator_index = gen_index
+                    self.set_prompt()
+                    return True
+            self.print(self.ERROR_NO_SUCH_TABLE_OR_COLUMN, parts[0])
+            return False
+        gen_index = None
+        if 1 < len(parts) and parts[1]:
+            gen_index = self._get_generator_index(table_index, parts[1])
+            if gen_index is None:
+                self.print("we cannot set the generator for column {0}", parts[1])
+                return False
+        self._set_table_index(table_index)
+        if gen_index is not None:
+            self.generator_index = gen_index
+            self.set_prompt()
+        return True
+
+    def do_next(self, arg: str) -> None:
+        """
+        Go to the next generator, or a specified generator.
+
+        Go to a named table: 'next tablename',
+        go to a column: 'next tablename.columnname',
+        or go to a column within this table: 'next columnname'.
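+
+        Examples (table and column names are illustrative)::
+
+            next person
+            next person.name
+            next name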
+ """ + if arg: + self.go_to(arg) + else: + self._go_next() + + def do_n(self, arg: str) -> None: + """Go to the next generator, or a specified generator.""" + self.do_next(arg) + + def complete_n(self, text: str, line: str, begidx: int, endidx: int) -> list[str]: + """Complete the ``n`` command's arguments.""" + return self.complete_next(text, line, begidx, endidx) + + def _go_next(self) -> None: + """Go to the next column.""" + table = self.get_table() + if table is None: + self.print("No more tables") + return + next_gi = self.generator_index + 1 + if next_gi == len(table.new_generators): + self.next_table(self.INFO_NO_MORE_TABLES) + return + self.generator_index = next_gi + self.set_prompt() + + def complete_next( + self, text: str, _line: str, _begidx: int, _endidx: int + ) -> list[str]: + """Completions for the arguments of the ``next`` command.""" + parts = text.split(".", 1) + first_part = parts[0] + if 1 < len(parts): + column_name = parts[1] + table_index = self._get_table_index(first_part) + if table_index is None: + return [] + table_entry = self.table_entries[table_index] + return [ + f"{first_part}.{column}" + for gen in table_entry.new_generators + for column in gen.columns + if column.startswith(column_name) + ] + table_names = [ + entry.name + for entry in self.table_entries + if entry.name.startswith(first_part) + ] + if first_part in table_names: + table_names.append(f"{first_part}.") + current_table = self.get_table() + if current_table: + column_names = [ + col + for gen in current_table.new_generators + for col in gen.columns + if col.startswith(first_part) + ] + else: + column_names = [] + return table_names + column_names + + def do_previous(self, _arg: str) -> None: + """Go to the previous generator.""" + if self.generator_index == 0: + self._previous_table() + else: + self.generator_index -= 1 + self.set_prompt() + + def do_b(self, arg: str) -> None: + """Synonym for previous.""" + self.do_previous(arg) + + def _generators_valid(self) -> bool: + """Test if ``self.generators`` is still correct for the current columns.""" + return self.generators_valid_columns == ( + self.table_index, + self._get_column_names(), + ) + + def _get_generator_proposals(self) -> list[Generator]: + """Get a list of acceptable generators, sorted by decreasing fit to the actual data.""" + if not self._generators_valid(): + self.generators = None + if self.generators is None: + columns = self._column_metadata() + gens = everything_factory().get_generators(columns, self.sync_engine) + sorted_gens = sorted(gens, key=lambda g: g.fit(9999)) + self.generators = sorted_gens + self.generators_valid_columns = ( + self.table_index, + self._get_column_names().copy(), + ) + return self.generators + + def _print_privacy(self) -> None: + """Print the privacy status of the current table.""" + table = self.table_metadata() + if table is None: + return + if table_is_private(self.config, table.name): + self.print(self.PRIMARY_PRIVATE_TEXT) + return + pfks = primary_private_fks(self.config, table) + if not pfks: + self.print(self.NOT_PRIVATE_TEXT) + return + self.print(self.SECONDARY_PRIVATE_TEXT, pfks) + + def do_compare(self, arg: str) -> None: + """ + Compare the real data with some generators. + + 'compare': just look at some source data from this column. + 'compare 5 6 10': compare a sample of the source data with a sample + from generators 5, 6 and 10. You can find out which numbers + correspond to which generators using the 'propose' command. 
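+
+        A typical session first lists the proposals, then compares a few
+        of them (the numbers are illustrative)::
+
+            propose
+            compare 1 3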
+        """
+        self._print_privacy()
+        args = arg.split()
+        limit = 20
+        comparison = {
+            "source": [
+                x[0] if len(x) == 1 else ", ".join(x)
+                for x in self._get_column_data(limit, to_str=str)
+            ]
+        }
+        gens: list[Generator] = self._get_generator_proposals()
+        table_name = self.table_name()
+        for argument in args:
+            if argument.isdigit():
+                n = int(argument)
+                if 0 < n <= len(gens):
+                    gen = gens[n - 1]
+                    comparison[f"{n}. {gen.name()}"] = gen.generate_data(limit)
+                    self._print_values_queried(table_name, n, gen)
+        self.print_table_by_columns(comparison)
+
+    def do_c(self, arg: str) -> None:
+        """Synonym for compare."""
+        self.do_compare(arg)
+
+    def _print_values_queried(self, table_name: str, n: int, gen: Generator) -> None:
+        """
+        Print the values queried from the database for this generator.
+
+        :param table_name: The name of the table the generator applies to.
+        :param n: A number to print at the start of the output.
+        :param gen: The generator to report.
+        """
+        if not gen.select_aggregate_clauses() and not gen.custom_queries():
+            self.print(
+                "{0}. {1} requires no data from the source database.",
+                n,
+                gen.name(),
+            )
+        else:
+            self.print(
+                "{0}. {1} requires the following data from the source database:",
+                n,
+                gen.name(),
+            )
+            self._print_select_aggregate_query(table_name, gen)
+            self._print_custom_queries(gen)
+
+    def _print_custom_queries(self, gen: Generator) -> None:
+        """
+        Print all the custom queries and all the values they get in this case.
+
+        :param gen: The generator to print the custom queries for.
+        """
+        cqs = gen.custom_queries()
+        if not cqs:
+            return
+        cq_key2args: dict[str, Any] = {}
+        nominal = gen.nominal_kwargs()
+        actual = gen.actual_kwargs()
+        self._get_custom_queries_from(
+            cq_key2args,
+            nominal,
+            actual,
+        )
+        for cq_key, cq in cqs.items():
+            self.print(
+                "{0}; providing the following values: {1}",
+                cq["query"],
+                cq_key2args[cq_key],
+            )
+
+    def _get_custom_queries_from(
+        self, out: dict[str, Any], nominal: Any, actual: Any
+    ) -> None:
+        """Map custom query keys to the actual values they supplied."""
+        if isinstance(nominal, str):
+            src_stat_groups = self.SRC_STAT_RE.search(nominal)
+            # Do we have a SRC_STATS reference?
+            if src_stat_groups:
+                # Get its name
+                cq_key = src_stat_groups.group(1)
+                # Are we pulling a specific part of this result?
+                sub = src_stat_groups.group(3)
+                if sub:
+                    actual = {sub: actual}
+                # Record the actual value supplied for this query key
+                out[cq_key] = actual
+        elif isinstance(nominal, Sequence) and isinstance(actual, Sequence):
+            for i in range(min(len(nominal), len(actual))):
+                self._get_custom_queries_from(out, nominal[i], actual[i])
+        elif isinstance(nominal, Mapping) and isinstance(actual, Mapping):
+            for k, v in nominal.items():
+                if k in actual:
+                    self._get_custom_queries_from(out, v, actual[k])
+
+    def _get_aggregate_query(
+        self, gens: list[Generator], table_name: str
+    ) -> str | None:
+        """Build the combined aggregate query for these generators, if any."""
+        clauses = [
+            f'{q["clause"]} AS {n}'
+            for gen in gens
+            for n, q in or_default(gen.select_aggregate_clauses(), {}).items()
+        ]
+        if not clauses:
+            return None
+        return f"SELECT {', '.join(clauses)} FROM {table_name}"
+
+    def _print_select_aggregate_query(self, table_name: str, gen: Generator) -> None:
+        """
+        Print the select aggregate query and all the values it gets in this case.
+
+        This is not the entire query that will be executed, but only the part of it
+        that is required by a certain generator.
+
+        :param table_name: The table name.
+        :param gen: The generator to limit the aggregate query to.
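+
+        For example (column name and values are illustrative), the output
+        might look like::
+
+            SELECT AVG(height) AS mean FROM person; providing the following values: [172.4]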
+        """
+        sacs = gen.select_aggregate_clauses()
+        if not sacs:
+            return
+        kwa = gen.actual_kwargs()
+        vals = []
+        src_stat2kwarg = {v: k for k, v in gen.nominal_kwargs().items()}
+        for n in sacs.keys():
+            src_stat = f'SRC_STATS["auto__{table_name}"]["results"][0]["{n}"]'
+            if src_stat in src_stat2kwarg:
+                ak = src_stat2kwarg[src_stat]
+                if ak in kwa:
+                    vals.append(kwa[ak])
+                else:
+                    logger.warning(
+                        "actual_kwargs for %s does not report %s", gen.name(), ak
+                    )
+            else:
+                logger.warning(
+                    (
+                        "nominal_kwargs for %s does not have a value"
+                        ' SRC_STATS["auto__%s"]["results"][0]["%s"]'
+                    ),
+                    gen.name(),
+                    table_name,
+                    n,
+                )
+        select_q = self._get_aggregate_query([gen], table_name)
+        self.print("{0}; providing the following values: {1}", select_q, vals)
+
+    def _get_column_data(
+        self, count: int, to_str: Callable[[Any], str] = repr
+    ) -> list[list[str]]:
+        """Get a random sample of non-null data from the current columns."""
+        columns = self._get_column_names()
+        columns_string = ", ".join(columns)
+        pred = " AND ".join(f"{column} IS NOT NULL" for column in columns)
+        with self.sync_engine.connect() as connection:
+            result = connection.execute(
+                sqlalchemy.text(
+                    f"SELECT {columns_string} FROM {self.table_name()}"
+                    f" WHERE {pred} ORDER BY RANDOM() LIMIT {count}"
+                )
+            )
+            return [[to_str(x) for x in xs] for xs in result.all()]
+
+    def do_propose(self, _arg: str) -> None:
+        """
+        Display a list of possible generators for this column.
+
+        They will be listed in order of fit, the most likely matches first.
+        The results can be compared (against a sample of the real data in
+        the column and against each other) with the 'compare' command.
+        """
+        limit = 5
+        gens = self._get_generator_proposals()
+        sample = self._get_column_data(limit)
+        if sample:
+            rep = [x[0] if len(x) == 1 else ",".join(x) for x in sample]
+            self.print(self.PROPOSE_SOURCE_SAMPLE_TEXT, "; ".join(rep))
+        else:
+            self.print(self.PROPOSE_SOURCE_EMPTY_TEXT)
+        if not gens:
+            self.print(self.PROPOSE_NOTHING)
+        for index, gen in enumerate(gens):
+            fit = gen.fit(-1)
+            if fit == -1:
+                fit_s = "(no fit)"
+            elif fit < 100:
+                fit_s = f"(fit: {fit:.3g})"
+            else:
+                fit_s = f"(fit: {fit:.0f})"
+            self.print(
+                self.PROPOSE_GENERATOR_SAMPLE_TEXT,
+                index=index + 1,
+                name=gen.name(),
+                fit=fit_s,
+                sample="; ".join(map(repr, gen.generate_data(limit))),
+            )
+
+    def do_p(self, arg: str) -> None:
+        """Synonym for propose."""
+        self.do_propose(arg)
+
+    def get_proposed_generator_by_name(self, gen_name: str) -> Generator | None:
+        """Find a generator by name from the list of proposals."""
+        for gen in self._get_generator_proposals():
+            if gen.name() == gen_name:
+                return gen
+        return None
+
+    def do_set(self, arg: str) -> None:
+        """Set one of the proposals as a generator."""
+        if arg.isdigit() and not self._generators_valid():
+            self.print("Please run 'propose' before 'set <number>'")
+            return
+        gens = self._get_generator_proposals()
+        new_gen: Generator | None
+        if arg.isdigit():
+            index = int(arg)
+            if index < 1:
+                self.print("set's integer argument must be at least 1")
+                return
+            if len(gens) < index:
+                self.print(
+                    "There are currently only {0} generators proposed, please select one of them.",
+                    len(gens),
+                )
+                return
+            new_gen = gens[index - 1]
+        else:
+            new_gen = self.get_proposed_generator_by_name(arg)
+        if new_gen is None:
+            self.print("'{0}' is not an appropriate generator for this column", arg)
+            return
+        self.set_generator(new_gen)
+        self._go_next()
+
+    def set_generator(self, gen: Generator | None) -> None:
+        """Set the current column's generator."""
+        (table, gen_info) = self._get_table_and_generator()
+        if table is None:
+            self.print("Error: no table")
+            return
+        if gen_info is None:
+            self.print("Error: no column")
+            return
+        gen_info.gen = gen
+
+    def do_s(self, arg: str) -> None:
+        """Synonym for set."""
+        self.do_set(arg)
+
+    def do_unset(self, _arg: str) -> None:
+        """Remove any generator set for this column."""
+        self.set_generator(None)
+        self._go_next()
+
+    def do_merge(self, arg: str) -> None:
+        """
+        Merge the specified column(s) into this generator's columns.
+
+        After this, one generator will cover them all.
+        """
+        cols = arg.split()
+        if not cols:
+            self.print("Error: merge requires a column argument")
+            return
+        table_entry: GeneratorCmdTableEntry | None = self.get_table()
+        if table_entry is None:
+            self.print(self.ERROR_NO_SUCH_TABLE)
+            return
+        cols_available = functools.reduce(
+            lambda x, y: x | y,
+            [frozenset(gen.columns) for gen in table_entry.new_generators],
+        )
+        cols_to_merge = frozenset(cols)
+        unknown_cols = cols_to_merge - cols_available
+        if unknown_cols:
+            for uc in unknown_cols:
+                self.print(self.ERROR_NO_SUCH_COLUMN, uc)
+            return
+        gen_info = table_entry.new_generators[self.generator_index]
+        current_columns = frozenset(gen_info.columns)
+        stated_current_columns = cols_to_merge & current_columns
+        if stated_current_columns:
+            for c in stated_current_columns:
+                self.print(self.ERROR_COLUMN_ALREADY_MERGED, c)
+            return
+        # Remove cols_to_merge from each generator
+        new_new_generators: list[GeneratorInfo] = []
+        for gen in table_entry.new_generators:
+            if gen is gen_info:
+                # Add columns to this generator
+                self.generator_index = len(new_new_generators)
+                new_new_generators.append(
+                    GeneratorInfo(
+                        columns=gen.columns + cols,
+                        gen=None,
+                    )
+                )
+            else:
+                # Remove columns if applicable
+                new_columns = [c for c in gen.columns if c not in cols_to_merge]
+                is_changed = len(new_columns) != len(gen.columns)
+                if new_columns:
+                    # We have not removed this generator completely
+                    new_new_generators.append(
+                        GeneratorInfo(
+                            columns=new_columns,
+                            gen=None if is_changed else gen.gen,
+                        )
+                    )
+        table_entry.new_generators = new_new_generators
+        self.set_prompt()
+
+    def complete_merge(
+        self, text: str, _line: str, _begidx: int, _endidx: int
+    ) -> list[str]:
+        """Complete column names."""
+        args = text.split()
+        last_arg = args[-1] if args else ""
+        table_entry: GeneratorCmdTableEntry | None = self.get_table()
+        if table_entry is None:
+            return []
+        return [
+            column
+            for i, gen in enumerate(table_entry.new_generators)
+            if i != self.generator_index
+            for column in gen.columns
+            if column.startswith(last_arg)
+        ]
+
+    def do_unmerge(self, arg: str) -> None:
+        """Remove the specified column(s) from this generator, making them a separate generator."""
+        cols = arg.split()
+        if not cols:
+            self.print("Error: unmerge requires a column argument")
+            return
+        table_entry: GeneratorCmdTableEntry | None = self.get_table()
+        if table_entry is None:
+            self.print(self.ERROR_NO_SUCH_TABLE)
+            return
+        gen_info = table_entry.new_generators[self.generator_index]
+        current_columns = frozenset(gen_info.columns)
+        cols_to_unmerge = frozenset(cols)
+        cols_available = functools.reduce(
+            lambda x, y: x | y,
+            [frozenset(gen.columns) for gen in table_entry.new_generators],
+        )
+        unknown_cols = cols_to_unmerge - cols_available
+        if unknown_cols:
+            for uc in unknown_cols:
+                self.print(self.ERROR_NO_SUCH_COLUMN, uc)
+            return
+        stated_unmerged_columns = cols_to_unmerge - current_columns
+        if stated_unmerged_columns:
+            for c in stated_unmerged_columns:
+                self.print(self.ERROR_COLUMN_ALREADY_UNMERGED, c)
+            return
+        if cols_to_unmerge == current_columns:
+            self.print(self.ERROR_CANNOT_UNMERGE_ALL)
+            return
+        # Remove unmerged columns
+        for um in cols_to_unmerge:
+            gen_info.columns.remove(um)
+        # The existing generator will not work on the remaining columns
+        gen_info.gen = None
+        # And put the removed columns into a new (empty) generator
+        table_entry.new_generators.insert(
+            self.generator_index + 1,
+            GeneratorInfo(
+                columns=cols,
+                gen=None,
+            ),
+        )
+        self.set_prompt()
+
+    def complete_unmerge(
+        self, text: str, _line: str, _begidx: int, _endidx: int
+    ) -> list[str]:
+        """Complete column names to unmerge."""
+        args = text.split()
+        last_arg = args[-1] if args else ""
+        table_entry: GeneratorCmdTableEntry | None = self.get_table()
+        if table_entry is None:
+            return []
+        return [
+            column
+            for column in table_entry.new_generators[self.generator_index].columns
+            if column.startswith(last_arg)
+        ]
diff --git a/datafaker/interactive/missingness.py b/datafaker/interactive/missingness.py
new file mode 100644
index 0000000..2737e2a
--- /dev/null
+++ b/datafaker/interactive/missingness.py
@@ -0,0 +1,355 @@
+"""Missingness configuration shell."""
+import re
+from collections.abc import Iterable, Mapping, MutableMapping
+from dataclasses import dataclass
+from typing import cast
+
+from sqlalchemy import MetaData
+
+from datafaker.interactive.base import DbCmd, TableEntry
+
+
+@dataclass
+class MissingnessType:
+    """The information required for applying missingness."""
+
+    SAMPLED = "column_presence.sampled"
+    SAMPLED_QUERY = (
+        "SELECT COUNT(*) AS row_count, {result_names} FROM "
+        "(SELECT {column_is_nulls} FROM {table} ORDER BY RANDOM() LIMIT {count})"
+        " AS __t GROUP BY {result_names}"
+    )
+    name: str
+    query: str
+    comment: str
+    columns: list[str]
+
+    @classmethod
+    def sampled_query(cls, table: str, count: int, column_names: Iterable[str]) -> str:
+        """
+        Construct a query sampling the null patterns of the named columns of the table.
+
+        :param table: The name of the table to sample.
+        :param count: The number of samples to get.
+        :param column_names: The columns to fetch.
+        :return: The SQL query to do the sampling.
+        """
+        result_names = ", ".join([f"{c}__is_null" for c in column_names])
+        column_is_nulls = ", ".join(
+            [f"{c} IS NULL AS {c}__is_null" for c in column_names]
+        )
+        return cls.SAMPLED_QUERY.format(
+            result_names=result_names,
+            column_is_nulls=column_is_nulls,
+            table=table,
+            count=count,
+        )
+
+
+@dataclass
+class MissingnessCmdTableEntry(TableEntry):
+    """Table entry for the missingness command shell."""
+
+    old_type: MissingnessType
+    new_type: MissingnessType | None
+
+
+class MissingnessCmd(DbCmd):
+    """
+    Interactive shell for the user to set missingness.
+
+    Can only be used for Missingness Completely At Random.
+    """
+
+    intro = "Interactive missingness configuration. Type ? for help.\n"
+    doc_leader = """Use commands 'sampled' and 'none' to choose the missingness style for
+the current table. Use commands 'next' and 'previous' to change the
+current table. Use 'tables' to list the tables and 'count' to show
+how many NULLs exist in each column. Use 'peek' or 'select' to see
+data from the database. Use 'quit' to exit this tool."""
+    prompt = "(missingness) "
+    file = None
+    PATTERN_RE = re.compile(r'SRC_STATS\["([^"]*)"\]')
+
+    def find_missingness_query(
+        self, missingness_generator: Mapping
+    ) -> tuple[str, str] | None:
+        """Find query and comment from src-stats for the passed missingness generator."""
+        kwargs = missingness_generator.get("kwargs", {})
+        patterns = kwargs.get("patterns", "")
+        pattern_match = self.PATTERN_RE.match(patterns)
+        if pattern_match:
+            key = pattern_match.group(1)
+            for src_stat in self.config["src-stats"]:
+                if src_stat.get("name") == key:
+                    query = src_stat.get("query", None)
+                    if not isinstance(query, str):
+                        return None
+                    return (query, src_stat.get("comment", ""))
+        return None
+
+    def make_table_entry(
+        self, table_name: str, table_config: Mapping
+    ) -> MissingnessCmdTableEntry | None:
+        """
+        Make a table entry for a particular table.
+
+        :param table_name: The name of the table to make an entry for.
+        :param table_config: The part of ``config.yaml`` relating to this table.
+        :return: The newly-constructed table entry, or None if this table
+            should not appear in the shell.
+        """
+        if table_config.get("ignore", False):
+            return None
+        if table_config.get("vocabulary_table", False):
+            return None
+        if table_config.get("num_rows_per_pass", 1) == 0:
+            return None
+        mgs = table_config.get("missingness_generators", [])
+        old = None
+        nonnull_columns = self.get_nonnull_columns(table_name)
+        if not nonnull_columns:
+            return None
+        if not mgs:
+            old = MissingnessType(
+                name="none",
+                query="",
+                comment="",
+                columns=[],
+            )
+        elif len(mgs) == 1:
+            mg = mgs[0]
+            mg_name = mg.get("name", None)
+            if isinstance(mg_name, str):
+                query_comment = self.find_missingness_query(mg)
+                if query_comment is not None:
+                    (query, comment) = query_comment
+                    old = MissingnessType(
+                        name=mg_name,
+                        query=query,
+                        comment=comment,
+                        columns=mg.get("columns_assigned", []),
+                    )
+        if old is None:
+            return None
+        return MissingnessCmdTableEntry(
+            name=table_name,
+            old_type=old,
+            new_type=old,
+        )
+
+    def __init__(
+        self,
+        src_dsn: str,
+        src_schema: str | None,
+        metadata: MetaData,
+        config: MutableMapping,
+    ):
+        """
+        Initialise a MissingnessCmd.
+
+        :param src_dsn: connection string for the source database.
+        :param src_schema: schema name for the source database.
+        :param metadata: SQLAlchemy metadata for the source database.
+        :param config: Configuration from the ``config.yaml`` file.
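+
+        A sketch of the missingness configuration this shell writes back
+        into ``config.yaml`` (table and column names are illustrative)::
+
+            tables:
+              signature_model:
+                missingness_generators:
+                - name: column_presence.sampled
+                  kwargs:
+                    patterns: SRC_STATS["missing_auto__signature_model__0"]["results"]
+                  columns: [player_id, based_on]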
+        """
+        super().__init__(src_dsn, src_schema, metadata, config)
+        self.set_prompt()
+
+    @property
+    def table_entries(self) -> list[MissingnessCmdTableEntry]:
+        """Get the table entries list."""
+        return cast(list[MissingnessCmdTableEntry], self._table_entries)
+
+    def _find_entry_by_table_name(
+        self, table_name: str
+    ) -> MissingnessCmdTableEntry | None:
+        """Find the table entry given the table name."""
+        entry = super()._find_entry_by_table_name(table_name)
+        if entry is None:
+            return None
+        return cast(MissingnessCmdTableEntry, entry)
+
+    def set_prompt(self) -> None:
+        """Set the prompt according to the current table and missingness."""
+        if self.table_index < len(self.table_entries):
+            entry: MissingnessCmdTableEntry = self.table_entries[self.table_index]
+            nt = entry.new_type
+            if nt is None:
+                self.prompt = f"(missingness for {entry.name}) "
+            else:
+                self.prompt = f"(missingness for {entry.name}: {nt.name}) "
+        else:
+            self.prompt = "(missingness) "
+
+    def set_type(self, t_type: MissingnessType) -> None:
+        """Set the missingness of the current table."""
+        if self.table_index < len(self.table_entries):
+            entry = self.table_entries[self.table_index]
+            entry.new_type = t_type
+
+    def _copy_entries(self) -> None:
+        """Set the new missingness into the configuration."""
+        src_stats = self._remove_prefix_src_stats("missing_auto__")
+        for entry in self.table_entries:
+            table = self.get_table_config(entry.name)
+            if entry.new_type is None or entry.new_type.name == "none":
+                table.pop("missingness_generators", None)
+            else:
+                src_stat_key = f"missing_auto__{entry.name}__0"
+                table["missingness_generators"] = [
+                    {
+                        "name": entry.new_type.name,
+                        "kwargs": {
+                            "patterns": f'SRC_STATS["{src_stat_key}"]["results"]'
+                        },
+                        "columns": entry.new_type.columns,
+                    }
+                ]
+                src_stats.append(
+                    {
+                        "name": src_stat_key,
+                        "query": entry.new_type.query,
+                        "comments": []
+                        if entry.new_type.comment is None
+                        else [entry.new_type.comment],
+                    }
+                )
+            self.set_table_config(entry.name, table)
+
+    def do_quit(self, _arg: str) -> bool:
+        """Check the updates, save them if desired and quit the configurer."""
+        count = 0
+        for entry in self.table_entries:
+            if entry.old_type != entry.new_type:
+                count += 1
+                if entry.old_type is None:
+                    self.print(
+                        "Putting generator {1} on table {0}",
+                        entry.name,
+                        entry.new_type.name,
+                    )
+                elif entry.new_type is None:
+                    self.print(
+                        "Deleting generator {1} from table {0}",
+                        entry.name,
+                        entry.old_type.name,
+                    )
+                else:
+                    self.print(
+                        "Changing {0} from {1} to {2}",
+                        entry.name,
+                        entry.old_type.name,
+                        entry.new_type.name,
+                    )
+        if count == 0:
+            self.print("You have made no changes.")
+        reply = self.ask_save()
+        if reply == "yes":
+            self._copy_entries()
+            return True
+        if reply == "no":
+            return True
+        return False
+
+    def do_tables(self, _arg: str) -> None:
+        """List the tables with their missingness types."""
+        for entry in self.table_entries:
+            old = "-" if entry.old_type is None else entry.old_type.name
+            new = "-" if entry.new_type is None else entry.new_type.name
+            desc = new if old == new else f"{old}->{new}"
+            self.print("{0} {1}", entry.name, desc)
+
+    def do_next(self, arg: str) -> None:
+        """
+        Go to the next table, or a specified table.
+
+        'next' = go to the next table, 'next tablename' = go to table 'tablename'
+        """
+        if arg:
+            # Find the index of the table called arg, if any
+            index = next(
+                (i for i, entry in enumerate(self.table_entries) if entry.name == arg),
+                None,
+            )
+            if index is None:
+                self.print(self.ERROR_NO_SUCH_TABLE, arg)
+                return
+            self._set_table_index(index)
+            return
+        self.next_table(self.INFO_NO_MORE_TABLES)
+
+    def complete_next(
+        self, text: str, _line: str, _begidx: int, _endidx: int
+    ) -> list[str]:
+        """Get completions for table names."""
+        return [
+            entry.name for entry in self.table_entries if entry.name.startswith(text)
+        ]
+
+    def do_previous(self, _arg: str) -> None:
+        """Go to the previous table."""
+        if not self._set_table_index(self.table_index - 1):
+            self.print(self.ERROR_ALREADY_AT_START)
+
+    def _set_type(self, name: str, query: str, comment: str) -> None:
+        """Set the current table entry's missingness type."""
+        if len(self.table_entries) <= self.table_index:
+            return
+        entry: MissingnessCmdTableEntry = self.table_entries[self.table_index]
+        entry.new_type = MissingnessType(
+            name=name,
+            query=query,
+            comment=comment,
+            columns=self.get_nonnull_columns(entry.name),
+        )
+
+    def _set_none(self) -> None:
+        """Set the current table to have no missingness applied."""
+        if len(self.table_entries) <= self.table_index:
+            return
+        self.table_entries[self.table_index].new_type = None
+
+    def do_sampled(self, arg: str) -> None:
+        """
+        Set the current table missingness as 'sampled', and go to the next table.
+
+        'sampled 3000' means sample 3000 rows at random and choose the
+        missingness to be the same as one of those 3000 at random.
+        'sampled' means the same, but with a default number of rows sampled (1000).
+        """
+        if len(self.table_entries) <= self.table_index:
+            self.print("Error: not on a table")
+            return
+        entry = self.table_entries[self.table_index]
+        if arg == "":
+            count = 1000
+        elif arg.isdecimal():
+            count = int(arg)
+        else:
+            self.print(
+                (
+                    "Error: sampled can be used alone or with"
+                    " an integer argument. {0} is not permitted"
+                ),
+                arg,
+            )
+            return
+        self._set_type(
+            MissingnessType.SAMPLED,
+            MissingnessType.sampled_query(
+                entry.name,
+                count,
+                self.get_nonnull_columns(entry.name),
+            ),
+            (
+                "The missingness patterns and how often they appear in a"
+                f" sample of {count} from table {entry.name}"
+            ),
+        )
+        self.print("Table {} set to sampled missingness", self.table_name())
+        self.next_table()
+
+    def do_none(self, _arg: str) -> None:
+        """Set the current table to have no missingness, and go to the next table."""
+        self._set_none()
+        self.print("Table {} set to have no missingness", self.table_name())
+        self.next_table()
diff --git a/datafaker/interactive/table.py b/datafaker/interactive/table.py
new file mode 100644
index 0000000..c23bfdf
--- /dev/null
+++ b/datafaker/interactive/table.py
@@ -0,0 +1,376 @@
+"""Table configuration command shell."""
+from collections.abc import Mapping, MutableMapping, Sequence
+from dataclasses import dataclass
+from typing import Any, cast
+
+import sqlalchemy
+from sqlalchemy import MetaData
+
+from datafaker.interactive.base import (
+    TYPE_LETTER,
+    TYPE_PROMPT,
+    DbCmd,
+    TableEntry,
+    TableType,
+)
+
+
+@dataclass
+class TableCmdTableEntry(TableEntry):
+    """Table entry for the table command shell."""
+
+    old_type: TableType
+    new_type: TableType
+
+
+class TableCmd(DbCmd):
+    """Command shell allowing the user to set the type of each table."""
+
+    intro = (
+        "Interactive table configuration (ignore,"
+        " vocabulary, private, generate or empty). Type ? for help.\n"
+    )
+    doc_leader = """Use the commands 'ignore', 'vocabulary',
+'private', 'empty' or 'generate' to set the table's type. Use 'next' or
+'previous' to change table. Use 'tables' and 'columns' for
+information about the database. Use 'data', 'peek', 'select' or
+'count' to see some data contained in the current table. Use 'quit'
+to exit this program."""
+    prompt = "(tableconf) "
+    file = None
+    WARNING_TEXT_VOCAB_TO_NON_VOCAB = (
+        "Vocabulary table {0} references non-vocabulary table {1}"
+    )
+    WARNING_TEXT_NON_EMPTY_TO_EMPTY = (
+        "Empty table {1} referenced from non-empty table {0}. {1} will need stories."
+    )
+    WARNING_TEXT_PROBLEMS_EXIST = "WARNING: The following table types have problems:"
+    WARNING_TEXT_POTENTIAL_PROBLEMS = (
+        "NOTE: The following table types might cause problems later:"
+    )
+    NOTE_TEXT_NO_CHANGES = "You have made no changes."
+    NOTE_TEXT_CHANGING = "Changing {0} from {1} to {2}"
+
+    def make_table_entry(
+        self, table_name: str, table_config: Mapping
+    ) -> TableCmdTableEntry | None:
+        """
+        Make a table entry for the named table.
+
+        :param table_name: The name of the table.
+        :param table_config: The part of ``config.yaml`` corresponding to this table.
+        :return: The newly-constructed table entry.
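+
+        For example (illustrative), a config entry with
+        ``vocabulary_table: true`` becomes a ``VOCABULARY`` entry, one with
+        ``num_rows_per_pass: 0`` becomes ``EMPTY``, and a table with no
+        special keys defaults to ``GENERATE``.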
+        """
+        if table_config.get("ignore", False):
+            return TableCmdTableEntry(table_name, TableType.IGNORE, TableType.IGNORE)
+        if table_config.get("vocabulary_table", False):
+            return TableCmdTableEntry(
+                table_name, TableType.VOCABULARY, TableType.VOCABULARY
+            )
+        if table_config.get("primary_private", False):
+            return TableCmdTableEntry(table_name, TableType.PRIVATE, TableType.PRIVATE)
+        if table_config.get("num_rows_per_pass", 1) == 0:
+            return TableCmdTableEntry(table_name, TableType.EMPTY, TableType.EMPTY)
+        return TableCmdTableEntry(table_name, TableType.GENERATE, TableType.GENERATE)
+
+    def __init__(
+        self,
+        src_dsn: str,
+        src_schema: str | None,
+        metadata: MetaData,
+        config: MutableMapping[str, Any],
+    ) -> None:
+        """Initialise a TableCmd."""
+        super().__init__(src_dsn, src_schema, metadata, config)
+        self.set_prompt()
+
+    @property
+    def table_entries(self) -> list[TableCmdTableEntry]:
+        """Get the list of table entries."""
+        return cast(list[TableCmdTableEntry], self._table_entries)
+
+    def _find_entry_by_table_name(self, table_name: str) -> TableCmdTableEntry | None:
+        """Get the table entry of the table with the given name."""
+        entry = super()._find_entry_by_table_name(table_name)
+        if entry is None:
+            return None
+        return cast(TableCmdTableEntry, entry)
+
+    def set_prompt(self) -> None:
+        """Set the prompt according to the current table and its type."""
+        if self.table_index < len(self.table_entries):
+            entry = self.table_entries[self.table_index]
+            self.prompt = TYPE_PROMPT[entry.new_type].format(entry.name)
+        else:
+            self.prompt = "(table) "
+
+    def set_type(self, t_type: TableType) -> None:
+        """Set the type of the current table."""
+        if self.table_index < len(self.table_entries):
+            entry = self.table_entries[self.table_index]
+            entry.new_type = t_type
+
+    def _copy_entries(self) -> None:
+        """Alter the configuration to match the new table entries."""
+        for entry in self.table_entries:
+            if entry.old_type != entry.new_type:
+                table = self.get_table_config(entry.name)
+                if (
+                    entry.old_type == TableType.EMPTY
+                    and table.get("num_rows_per_pass", 1) == 0
+                ):
+                    table["num_rows_per_pass"] = 1
+                if entry.new_type == TableType.IGNORE:
+                    table["ignore"] = True
+                    table.pop("vocabulary_table", None)
+                    table.pop("primary_private", None)
+                elif entry.new_type == TableType.VOCABULARY:
+                    table.pop("ignore", None)
+                    table["vocabulary_table"] = True
+                    table.pop("primary_private", None)
+                elif entry.new_type == TableType.PRIVATE:
+                    table.pop("ignore", None)
+                    table.pop("vocabulary_table", None)
+                    table["primary_private"] = True
+                elif entry.new_type == TableType.EMPTY:
+                    table.pop("ignore", None)
+                    table.pop("vocabulary_table", None)
+                    table.pop("primary_private", None)
+                    table["num_rows_per_pass"] = 0
+                else:
+                    table.pop("ignore", None)
+                    table.pop("vocabulary_table", None)
+                    table.pop("primary_private", None)
+                self.set_table_config(entry.name, table)
+
+    def _get_referenced_tables(self, from_table_name: str) -> set[str]:
+        """Get all the tables referenced by this table's foreign keys."""
+        from_meta = self.metadata.tables[from_table_name]
+        return {
+            fk.column.table.name for col in from_meta.columns for fk in col.foreign_keys
+        }
+
+    def _sanity_check_failures(self) -> list[tuple[str, str, str]]:
+        """Find references between tables that are invalid given the tables' types."""
+        failures = []
+        for from_entry in self.table_entries:
+            from_t = from_entry.new_type
+            if from_t == TableType.VOCABULARY:
+                referenced = self._get_referenced_tables(from_entry.name)
+                for ref in referenced:
+                    to_entry = self._find_entry_by_table_name(ref)
+                    if (
+                        to_entry is not None
+                        and to_entry.new_type != TableType.VOCABULARY
+                    ):
+                        failures.append(
+                            (
+                                self.WARNING_TEXT_VOCAB_TO_NON_VOCAB,
+                                from_entry.name,
+                                to_entry.name,
+                            )
+                        )
+        return failures
+
+    def _sanity_check_warnings(self) -> list[tuple[str, str, str]]:
+        """Find references between tables that might cause problems given the tables' types."""
+        warnings = []
+        for from_entry in self.table_entries:
+            from_t = from_entry.new_type
+            if from_t in {TableType.GENERATE, TableType.PRIVATE}:
+                referenced = self._get_referenced_tables(from_entry.name)
+                for ref in referenced:
+                    to_entry = self._find_entry_by_table_name(ref)
+                    if to_entry is not None and to_entry.new_type in {
+                        TableType.EMPTY,
+                        TableType.IGNORE,
+                    }:
+                        warnings.append(
+                            (
+                                self.WARNING_TEXT_NON_EMPTY_TO_EMPTY,
+                                from_entry.name,
+                                to_entry.name,
+                            )
+                        )
+        return warnings
+
+    def do_quit(self, _arg: str) -> bool:
+        """Check the updates, save them if desired and quit the configurer."""
+        count = 0
+        for entry in self.table_entries:
+            if entry.old_type != entry.new_type:
+                count += 1
+                self.print(
+                    self.NOTE_TEXT_CHANGING,
+                    entry.name,
+                    entry.old_type.value,
+                    entry.new_type.value,
+                )
+        if count == 0:
+            self.print(self.NOTE_TEXT_NO_CHANGES)
+        failures = self._sanity_check_failures()
+        if failures:
+            self.print(self.WARNING_TEXT_PROBLEMS_EXIST)
+            for text, from_t, to_t in failures:
+                self.print(text, from_t, to_t)
+        warnings = self._sanity_check_warnings()
+        if warnings:
+            self.print(self.WARNING_TEXT_POTENTIAL_PROBLEMS)
+            for text, from_t, to_t in warnings:
+                self.print(text, from_t, to_t)
+        reply = self.ask_save()
+        if reply == "yes":
+            self._copy_entries()
+            return True
+        if reply == "no":
+            return True
+        return False
+
+    def do_tables(self, _arg: str) -> None:
+        """List the tables with their types."""
+        for entry in self.table_entries:
+            old = entry.old_type
+            new = entry.new_type
+            becomes = " " if old == new else "->" + TYPE_LETTER[new]
+            self.print("{0}{1} {2}", TYPE_LETTER[old], becomes, entry.name)
+
+    def do_next(self, arg: str) -> None:
+        """'next' = go to the next table, 'next tablename' = go to table 'tablename'."""
+        if arg:
+            # Find the index of the table called arg, if any
+            index = self.find_entry_index_by_table_name(arg)
+            if index is None:
+                self.print(self.ERROR_NO_SUCH_TABLE, arg)
+                return
+            self._set_table_index(index)
+            return
+        self.next_table(self.INFO_NO_MORE_TABLES)
+
+    def complete_next(
+        self, text: str, _line: str, _begidx: int, _endidx: int
+    ) -> list[str]:
+        """Get the completions for table names."""
+        return [
+            entry.name for entry in self.table_entries if entry.name.startswith(text)
+        ]
+
+    def do_previous(self, _arg: str) -> None:
+        """Go to the previous table."""
+        if not self._set_table_index(self.table_index - 1):
+            self.print(self.ERROR_ALREADY_AT_START)
+
+    def do_ignore(self, _arg: str) -> None:
+        """Set the current table as ignored, and go to the next table."""
+        self.set_type(TableType.IGNORE)
+        self.print("Table {} set as ignored", self.table_name())
+        self.next_table()
+
+    def do_vocabulary(self, _arg: str) -> None:
+        """Set the current table as a vocabulary table, and go to the next table."""
+        self.set_type(TableType.VOCABULARY)
+        self.print("Table {} set to be a vocabulary table", self.table_name())
+        self.next_table()
+
+    def do_private(self, _arg: str) -> None:
+        """Set the current table as a primary private table (such as the table of patients)."""
+        self.set_type(TableType.PRIVATE)
self.print("Table {} set to be a primary private table", self.table_name()) + self.next_table() + + def do_generate(self, _arg: str) -> None: + """Set the current table as to be generated, and go to the next table.""" + self.set_type(TableType.GENERATE) + self.print("Table {} generate", self.table_name()) + self.next_table() + + def do_empty(self, _arg: str) -> None: + """Set the current table as empty; no generators will be run for it.""" + self.set_type(TableType.EMPTY) + self.print("Table {} empty", self.table_name()) + self.next_table() + + def do_columns(self, _arg: str) -> None: + """Report the column names and metadata.""" + self.report_columns() + + def do_data(self, arg: str) -> None: + """ + Report some data. + + 'data' = report a random ten lines, + 'data 20' = report a random 20 lines, + 'data 20 ColumnName' = report a random twenty entries from ColumnName, + 'data 20 ColumnName 30' = report a random twenty entries from + ColumnName of length at least 30, + """ + args = arg.split() + column = None + number = None + arg_index = 0 + min_length = 0 + table_metadata = self.table_metadata() + if arg_index < len(args) and args[arg_index].isdigit(): + number = int(args[arg_index]) + arg_index += 1 + if arg_index < len(args) and args[arg_index] in table_metadata.columns: + column = args[arg_index] + arg_index += 1 + if arg_index < len(args) and args[arg_index].isdigit(): + min_length = int(args[arg_index]) + arg_index += 1 + if arg_index != len(args): + self.print( + """Did not understand these arguments +The format is 'data [entries] [column-name [minimum-length]]' where [] means optional text. +Type 'columns' to find out valid column names for this table. +Type 'help data' for examples.""" + ) + return + if column is None: + if number is None: + number = 10 + self.print_row_data(number) + else: + if number is None: + number = 48 + self.print_column_data(column, number, min_length) + + def complete_data( + self, text: str, line: str, begidx: int, _endidx: int + ) -> list[str]: + """Get completions for arguments to ``data``.""" + previous_parts = line[: begidx - 1].split() + if len(previous_parts) != 2: + return [] + table_metadata = self.table_metadata() + return [k for k in table_metadata.columns.keys() if k.startswith(text)] + + def print_column_data(self, column: str, count: int, min_length: int) -> None: + """ + Print a sample of data from a certain column of the current table. + + :param column: The name of the column to report on. + :param count: The number of rows to sample. + :param min_length: The minimum length of text to choose from (0 for any text). + """ + where = f"WHERE {column} IS NOT NULL" + if 0 < min_length: + where = f"WHERE LENGTH({column}) >= {min_length}" + with self.sync_engine.connect() as connection: + result = connection.execute( + sqlalchemy.text( + f"SELECT {column} FROM {self.table_name()}" + f" {where} ORDER BY RANDOM() LIMIT {count}" + ) + ) + self.columnize([str(x[0]) for x in result.all()]) + + def print_row_data(self, count: int) -> None: + """ + Print a sample or rows from the current table. + + :param count: The number of rows to report. 
+ """ + with self.sync_engine.connect() as connection: + result = connection.execute( + sqlalchemy.text( + f"SELECT * FROM {self.table_name()} ORDER BY RANDOM() LIMIT {count}" + ) + ) + if result is None: + self.print("No rows in this table!") + return + self.print_results(result) diff --git a/datafaker/main.py b/datafaker/main.py index cf7bd3b..22cf0ef 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -153,13 +153,10 @@ def create_data( config = read_config_file(config_file) if config_file is not None else {} orm_metadata = load_metadata_for_output(orm_file, config) df_module = import_file(df_file) - table_generator_dict = df_module.table_generator_dict - story_generator_list = df_module.story_generator_list try: row_counts = create_db_data( sorted_non_vocabulary_tables(orm_metadata, config), - table_generator_dict, - story_generator_list, + df_module, num_passes, ) logger.debug( diff --git a/datafaker/utils.py b/datafaker/utils.py index 6d0041d..3cc8c28 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -11,11 +11,11 @@ from types import ModuleType from typing import Any, Callable, Final, Generator, Iterable, Optional, TypeVar, Union +import psycopg2 import sqlalchemy import yaml from jsonschema.exceptions import ValidationError from jsonschema.validators import validate -from psycopg2.errors import UndefinedObject from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select from sqlalchemy.engine.interfaces import DBAPIConnection from sqlalchemy.exc import IntegrityError, ProgrammingError @@ -79,7 +79,7 @@ def import_file(file_path: str) -> ModuleType: """ spec = importlib.util.spec_from_file_location("df", file_path) if spec is None or spec.loader is None: - raise Exception(f"No loadable module at {file_path}") + raise ImportError(f"No loadable module at {file_path}") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module @@ -248,7 +248,7 @@ def emit(self, record: Any) -> None: sys.stdout.flush() except RecursionError: raise - except Exception: + except Exception: # pylint: disable=broad-exception-caught self.handleError(record) @@ -275,7 +275,7 @@ def emit(self, record: Any) -> None: sys.stderr.flush() except RecursionError: raise - except Exception: + except Exception: # pylint: disable=broad-exception-caught self.handleError(record) @@ -458,7 +458,8 @@ def remove_vocab_foreign_key_constraints( ) except ProgrammingError as e: session.rollback() - if isinstance(e.orig, UndefinedObject): + # pylint: disable=no-member + if isinstance(e.orig, psycopg2.errors.UndefinedObject): logger.debug("Constraint does not exist") else: raise e diff --git a/tests/test_interactive.py b/tests/test_interactive_generators.py similarity index 70% rename from tests/test_interactive.py rename to tests/test_interactive_generators.py index fcb5ced..26d08aa 100644 --- a/tests/test_interactive.py +++ b/tests/test_interactive_generators.py @@ -1,453 +1,18 @@ -""" Tests for the base module. """ +""" Tests for the configure-generators command. 
""" import copy -import random import re +from collections.abc import MutableMapping from dataclasses import dataclass -from typing import Any, Iterable, MutableMapping +from typing import Any, Iterable from unittest import TestCase from unittest.mock import MagicMock, Mock, patch from sqlalchemy import Connection, MetaData, insert, select from datafaker.generators import NullPartitionedNormalGeneratorFactory -from datafaker.interactive import ( - DbCmd, - GeneratorCmd, - MissingnessCmd, - TableCmd, - update_config_generators, -) -from tests.utils import GeneratesDBTestCase, RequiresDBTestCase - - -class TestDbCmdMixin(DbCmd): - """A mixin for capturing output from interactive commands.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - """Initialize a TestDbCmdMixin""" - super().__init__(*args, **kwargs) - self.reset() - - def reset(self) -> None: - """Reset all the debug messages collected so far.""" - self.messages: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = [] - self.headings: list[str] = [] - self.rows: list[list[str]] = [] - self.column_items: list[list[str]] = [] - self.columns: dict[str, list[Any]] = {} - - def print(self, text: str, *args: Any, **kwargs: Any) -> None: - """Capture the printed message.""" - self.messages.append((text, args, kwargs)) - - def print_table(self, headings: list[str], rows: list[list[str]]) -> None: - """Capture the printed table.""" - self.headings = headings - self.rows = rows - - def print_table_by_columns(self, columns: dict[str, list[str]]) -> None: - """Capture the printed table.""" - self.columns = columns - - # pylint: disable=arguments-renamed - def columnize(self, items: list[str] | None, _displaywidth: int = 80) -> None: - """Capture the printed table.""" - if items is not None: - self.column_items.append(items) - - def ask_save(self) -> str: - """Quitting always works without needing to ask the user.""" - return "yes" - - -class TestTableCmd(TableCmd, TestDbCmdMixin): - """TableCmd but mocked""" - - -class ConfigureTablesTests(RequiresDBTestCase): - """Testing configure-tables.""" - - def _get_cmd(self, config: MutableMapping[str, Any]) -> TestTableCmd: - return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) - - -class ConfigureTablesSrcTests(ConfigureTablesTests): - """Testing configure-tables with src.dump.""" - - dump_file_path = "src.dump" - database_name = "src" - schema_name = "public" - - def test_table_name_prompts(self) -> None: - """Test that the prompts follow the names of the tables.""" - config: MutableMapping[str, Any] = {} - with self._get_cmd(config) as tc: - table_names = list(self.metadata.tables.keys()) - for t in table_names: - self.assertIn(t, tc.prompt) - tc.do_next("") - self.assertListEqual(tc.messages, [(TableCmd.INFO_NO_MORE_TABLES, (), {})]) - tc.reset() - for t in reversed(table_names): - self.assertIn(t, tc.prompt) - tc.do_previous("") - self.assertListEqual( - tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})] - ) - tc.reset() - bad_table_name = "notarealtable" - tc.do_next(bad_table_name) - self.assertListEqual( - tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})] - ) - tc.reset() - good_table_name = table_names[2] - tc.do_next(good_table_name) - self.assertListEqual(tc.messages, []) - self.assertIn(good_table_name, tc.prompt) - - def test_column_display(self) -> None: - """Test that we can see the names of the columns.""" - config: MutableMapping[str, Any] = {} - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - 
tc.do_columns("") - self.assertListEqual( - tc.rows, - [ - ["id", "INTEGER", True, False, ""], - ["a", "BOOLEAN", False, False, ""], - ["b", "BOOLEAN", False, False, ""], - ["c", "TEXT", False, False, ""], - ], - ) - - def test_null_configuration(self) -> None: - """A table still works if its configuration is None.""" - config = { - "tables": None, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_private("") - tc.do_quit("") - tables = tc.config["tables"] - self.assertFalse( - tables["unique_constraint_test"].get("vocabulary_table", False) - ) - self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue( - tables["unique_constraint_test"].get("primary_private", False) - ) - - def test_null_table_configuration(self) -> None: - """A table still works if its configuration is None.""" - config = { - "tables": { - "unique_constraint_test": None, - }, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_private("") - tc.do_quit("") - tables = tc.config["tables"] - self.assertFalse( - tables["unique_constraint_test"].get("vocabulary_table", False) - ) - self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertTrue( - tables["unique_constraint_test"].get("primary_private", False) - ) - - def test_configure_tables(self) -> None: - """Test that we can change columns to ignore, vocab or generate.""" - config = { - "tables": { - "unique_constraint_test": { - "vocabulary_table": True, - }, - "no_pk_test": { - "ignore": True, - }, - "hospital_visit": { - "num_passes": 0, - }, - "empty_vocabulary": { - "private": True, - }, - }, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_generate("") - tc.do_next("person") - tc.do_vocabulary("") - tc.do_next("mitigation_type") - tc.do_ignore("") - tc.do_next("hospital_visit") - tc.do_private("") - tc.do_quit("") - tc.do_next("empty_vocabulary") - tc.do_empty("") - tc.do_quit("") - tables = tc.config["tables"] - self.assertFalse( - tables["unique_constraint_test"].get("vocabulary_table", False) - ) - self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) - self.assertFalse( - tables["unique_constraint_test"].get("primary_private", False) - ) - self.assertEqual(tables["unique_constraint_test"].get("num_passes", 1), 1) - self.assertFalse(tables["no_pk_test"].get("vocabulary_table", False)) - self.assertTrue(tables["no_pk_test"].get("ignore", False)) - self.assertFalse(tables["no_pk_test"].get("primary_private", False)) - self.assertEqual(tables["no_pk_test"].get("num_rows_per_pass", 1), 1) - self.assertTrue(tables["person"].get("vocabulary_table", False)) - self.assertFalse(tables["person"].get("ignore", False)) - self.assertFalse(tables["person"].get("primary_private", False)) - self.assertEqual(tables["person"].get("num_rows_per_pass", 1), 1) - self.assertFalse(tables["mitigation_type"].get("vocabulary_table", False)) - self.assertTrue(tables["mitigation_type"].get("ignore", False)) - self.assertFalse(tables["mitigation_type"].get("primary_private", False)) - self.assertEqual(tables["mitigation_type"].get("num_rows_per_pass", 1), 1) - self.assertFalse(tables["hospital_visit"].get("vocabulary_table", False)) - self.assertFalse(tables["hospital_visit"].get("ignore", False)) - self.assertTrue(tables["hospital_visit"].get("primary_private", False)) - self.assertEqual(tables["hospital_visit"].get("num_rows_per_pass", 1), 1) - 
self.assertFalse(tables["empty_vocabulary"].get("vocabulary_table", False)) - self.assertFalse(tables["empty_vocabulary"].get("ignore", False)) - self.assertFalse(tables["empty_vocabulary"].get("primary_private", False)) - self.assertEqual(tables["empty_vocabulary"].get("num_rows_per_pass", 1), 0) - - def test_print_data(self) -> None: - """Test that we can print random rows from the table and random data from columns.""" - person_table = self.metadata.tables["person"] - with self.sync_engine.connect() as conn: - person_rows = conn.execute(select(person_table)).mappings().fetchall() - person_data = {row["person_id"]: row for row in person_rows} - name_set = {row["name"] for row in person_rows} - person_headings = ["person_id", "name", "research_opt_out", "stored_from"] - with self._get_cmd({}) as tc: - tc.do_next("person") - tc.do_data("") - self.assertListEqual(tc.headings, person_headings) - self.assertEqual(len(tc.rows), 10) # default number of rows is 10 - for row in tc.rows: - expected = person_data[row[0]] - self.assertListEqual(row, [expected[h] for h in person_headings]) - tc.reset() - rows_to_get_count = 6 - tc.do_data(str(rows_to_get_count)) - self.assertListEqual(tc.headings, person_headings) - self.assertEqual(len(tc.rows), rows_to_get_count) - for row in tc.rows: - expected = person_data[row[0]] - self.assertListEqual(row, [expected[h] for h in person_headings]) - tc.reset() - to_get_count = 12 - tc.do_data(f"{to_get_count} name") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual(len(tc.column_items[0]), to_get_count) - self.assertLessEqual(set(tc.column_items[0]), name_set) - tc.reset() - tc.do_data(f"{to_get_count} name 12") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual(len(tc.column_items[0]), to_get_count) - tc.reset() - tc.do_data(f"{to_get_count} name 13") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual( - set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set)) - ) - tc.reset() - tc.do_data(f"{to_get_count} name 16") - self.assertEqual(len(tc.column_items), 1) - self.assertEqual( - set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set)) - ) - - def test_list_tables(self) -> None: - """Test that we can list the tables""" - config = { - "tables": { - "unique_constraint_test": { - "vocabulary_table": True, - }, - "no_pk_test": { - "ignore": True, - }, - }, - } - with self._get_cmd(config) as tc: - tc.do_next("unique_constraint_test") - tc.do_ignore("") - tc.do_next("person") - tc.do_vocabulary("") - tc.reset() - tc.do_tables("") - person_listed = False - unique_constraint_test_listed = False - no_pk_test_listed = False - for _text, args, _kwargs in tc.messages: - if args[2] == "person": - self.assertFalse(person_listed) - person_listed = True - self.assertEqual(args[0], "G") - self.assertEqual(args[1], "->V") - elif args[2] == "unique_constraint_test": - self.assertFalse(unique_constraint_test_listed) - unique_constraint_test_listed = True - self.assertEqual(args[0], "V") - self.assertEqual(args[1], "->I") - elif args[2] == "no_pk_test": - self.assertFalse(no_pk_test_listed) - no_pk_test_listed = True - self.assertEqual(args[0], "I") - self.assertEqual(args[1], " ") - else: - self.assertEqual(args[0], "G") - self.assertEqual(args[1], " ") - self.assertTrue(person_listed) - self.assertTrue(unique_constraint_test_listed) - self.assertTrue(no_pk_test_listed) - - -class ConfigureTablesInstrumentsTests(ConfigureTablesTests): - """Testing configure-tables with the instrument.sql database.""" - - 
dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def test_sanity_checks_both(self) -> None: - """ - Test ``configure-tables`` sanity checks. - """ - config = { - "tables": { - "model": { - "vocabulary_table": True, - }, - "manufacturer": { - "ignore": True, - }, - "player": { - "num_rows_per_pass": 0, - }, - }, - } - with self._get_cmd(config) as tc: - tc.reset() - tc.do_quit("") - self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_NO_CHANGES, (), {})) - self.assertEqual( - tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) - ) - self.assertEqual( - tc.messages[2], - ( - TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, - ("model", "manufacturer"), - {}, - ), - ) - self.assertEqual( - tc.messages[3], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) - ) - self.assertEqual( - tc.messages[4], - ( - TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, - ("signature_model", "player"), - {}, - ), - ) - - def test_sanity_checks_warnings_only(self) -> None: - """ - Test ``configure-tables`` sanity checks. - """ - config = { - "tables": { - "model": { - "vocabulary_table": True, - }, - "manufacturer": { - "ignore": True, - }, - "player": { - "num_rows_per_pass": 0, - }, - }, - } - with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: - tc.do_next("manufacturer") - tc.do_vocabulary("") - tc.reset() - tc.do_quit("") - self.assertEqual( - tc.messages[0], - ( - TableCmd.NOTE_TEXT_CHANGING, - ("manufacturer", "ignore", "vocabulary"), - {}, - ), - ) - self.assertEqual( - tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) - ) - self.assertEqual( - tc.messages[2], - ( - TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, - ("signature_model", "player"), - {}, - ), - ) - - def test_sanity_checks_errors_only(self) -> None: - """ - Test ``configure-tables`` sanity checks. 
- """ - config = { - "tables": { - "model": { - "vocabulary_table": True, - }, - "manufacturer": { - "ignore": True, - }, - "player": { - "num_rows_per_pass": 0, - }, - }, - } - with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: - tc.do_next("signature_model") - tc.do_empty("") - tc.reset() - tc.do_quit("") - self.assertEqual( - tc.messages[0], - ( - TableCmd.NOTE_TEXT_CHANGING, - ("signature_model", "generate", "empty"), - {}, - ), - ) - self.assertEqual( - tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) - ) - self.assertEqual( - tc.messages[2], - ( - TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, - ("model", "manufacturer"), - {}, - ), - ) +from datafaker.interactive import update_config_generators +from datafaker.interactive.generators import GeneratorCmd +from tests.utils import GeneratesDBTestCase, RequiresDBTestCase, TestDbCmdMixin class TestGeneratorCmd(GeneratorCmd, TestDbCmdMixin): @@ -1169,97 +734,6 @@ def test_create_with_weighted_choice(self) -> None: self.assertSetEqual(threes, {1, 2, 3, 4, 5}) -class TestMissingnessCmd(MissingnessCmd, TestDbCmdMixin): - """MissingnessCmd but mocked""" - - -class ConfigureMissingnessTests(RequiresDBTestCase): - """Testing configure-missing.""" - - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: - """We are using configure-missingness.""" - return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_set_missingness_to_sampled(self) -> None: - """Test that we can set one table to sampled missingness.""" - with self._get_cmd({}) as mc: - table = "signature_model" - mc.do_next(table) - mc.do_counts("") - self.assertListEqual( - mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (10,), {})] - ) - # Check the counts of NULLs in each column - self.assertListEqual(mc.rows, [["player_id", 4], ["based_on", 3]]) - mc.do_sampled("") - mc.do_quit("") - self.assertListEqual( - mc.config["tables"][table]["missingness_generators"], - [ - { - "columns": ["player_id", "based_on"], - "kwargs": { - "patterns": 'SRC_STATS["missing_auto__signature_model__0"]["results"]' - }, - "name": "column_presence.sampled", - } - ], - ) - self.assertEqual( - mc.config["src-stats"][0]["name"], - "missing_auto__signature_model__0", - ) - self.assertEqual( - mc.config["src-stats"][0]["query"], - ( - "SELECT COUNT(*) AS row_count," - " player_id__is_null, based_on__is_null FROM" - " (SELECT player_id IS NULL AS player_id__is_null," - " based_on IS NULL AS based_on__is_null FROM" - " signature_model ORDER BY RANDOM() LIMIT 1000)" - " AS __t GROUP BY player_id__is_null, based_on__is_null" - ), - ) - - -class ConfigureMissingnessTestsWithGeneration(GeneratesDBTestCase): - """Testing configure-missing with generation.""" - - dump_file_path = "instrument.sql" - database_name = "instrument" - schema_name = "public" - - def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd: - return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_create_with_missingness(self) -> None: - """Test that we can sample real missingness and reproduce it.""" - random.seed(45) - # Configure the missingness - table_name = "signature_model" - with self._get_cmd({}) as mc: - mc.do_next(table_name) - mc.do_sampled("") - mc.do_quit("") - config = mc.config - self.generate_data(config, num_passes=100) - # Test that each missingness pattern is present in the database - with 
self.sync_engine.connect() as conn:
-            stmt = select(self.metadata.tables[table_name])
-            rows = conn.execute(stmt).mappings().fetchall()
-            patterns: set[int] = set()
-            for row in rows:
-                p = 0 if row["player_id"] is None else 1
-                b = 0 if row["based_on"] is None else 2
-                patterns.add(p + b)
-            # all pattern possibilities should be present
-            self.assertSetEqual(patterns, {0, 1, 2, 3})
-
-
 class GeneratorTests(GeneratesDBTestCase):
     """Testing configure-generators with generation."""
 
@@ -1445,6 +919,7 @@ def covar(self) -> float:
         return (self.xy - self.x * self.y / self.n) / (self.n - 1)
 
 
+# pylint: disable=too-many-instance-attributes
 class EavMeasurementTableStats:
     """The statistics for the Measurement table of eav.sql."""
 
diff --git a/tests/test_interactive_missingness.py b/tests/test_interactive_missingness.py
new file mode 100644
index 0000000..7a63ea5
--- /dev/null
+++ b/tests/test_interactive_missingness.py
@@ -0,0 +1,100 @@
+"""Tests for the configure-missingness command."""
+import random
+from collections.abc import MutableMapping
+from typing import Any
+
+from sqlalchemy import select
+
+from datafaker.interactive import MissingnessCmd
+from tests.utils import GeneratesDBTestCase, RequiresDBTestCase, TestDbCmdMixin
+
+
+class TestMissingnessCmd(MissingnessCmd, TestDbCmdMixin):
+    """MissingnessCmd but mocked"""
+
+
+class ConfigureMissingnessTests(RequiresDBTestCase):
+    """Testing configure-missing."""
+
+    dump_file_path = "instrument.sql"
+    database_name = "instrument"
+    schema_name = "public"
+
+    def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd:
+        """We are using configure-missingness."""
+        return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config)
+
+    def test_set_missingness_to_sampled(self) -> None:
+        """Test that we can set one table to sampled missingness."""
+        with self._get_cmd({}) as mc:
+            table = "signature_model"
+            mc.do_next(table)
+            mc.do_counts("")
+            self.assertSequenceEqual(
+                mc.messages, [(MissingnessCmd.ROW_COUNT_MSG, (10,), {})]
+            )
+            # Check the counts of NULLs in each column
+            self.assertSequenceEqual(mc.rows, [["player_id", 4], ["based_on", 3]])
+            mc.do_sampled("")
+            mc.do_quit("")
+            self.assertListEqual(
+                mc.config["tables"][table]["missingness_generators"],
+                [
+                    {
+                        "columns": ["player_id", "based_on"],
+                        "kwargs": {
+                            "patterns": 'SRC_STATS["missing_auto__signature_model__0"]["results"]'
+                        },
+                        "name": "column_presence.sampled",
+                    }
+                ],
+            )
+            self.assertEqual(
+                mc.config["src-stats"][0]["name"],
+                "missing_auto__signature_model__0",
+            )
+            self.assertEqual(
+                mc.config["src-stats"][0]["query"],
+                (
+                    "SELECT COUNT(*) AS row_count,"
+                    " player_id__is_null, based_on__is_null FROM"
+                    " (SELECT player_id IS NULL AS player_id__is_null,"
+                    " based_on IS NULL AS based_on__is_null FROM"
+                    " signature_model ORDER BY RANDOM() LIMIT 1000)"
+                    " AS __t GROUP BY player_id__is_null, based_on__is_null"
+                ),
+            )
+
+
+class ConfigureMissingnessTestsWithGeneration(GeneratesDBTestCase):
+    """Testing configure-missing with generation."""
+
+    dump_file_path = "instrument.sql"
+    database_name = "instrument"
+    schema_name = "public"
+
+    def _get_cmd(self, config: MutableMapping[str, Any]) -> TestMissingnessCmd:
+        return TestMissingnessCmd(self.dsn, self.schema_name, self.metadata, config)
+
+    def test_create_with_missingness(self) -> None:
+        """Test that we can sample real missingness and reproduce it."""
+        random.seed(45)
+        # Configure the missingness
+        table_name = "signature_model"
+        with self._get_cmd({}) as mc:
+            
mc.do_next(table_name) + mc.do_sampled("") + mc.do_quit("") + config = mc.config + self.generate_data(config, num_passes=100) + # Test that each missingness pattern is present in the database + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).mappings().fetchall() + patterns: set[int] = set() + for row in rows: + p = 0 if row["player_id"] is None else 1 + b = 0 if row["based_on"] is None else 2 + patterns.add(p + b) + # all pattern possibilities should be present + self.assertSetEqual(patterns, {0, 1, 2, 3}) diff --git a/tests/test_interactive_table.py b/tests/test_interactive_table.py new file mode 100644 index 0000000..04b157e --- /dev/null +++ b/tests/test_interactive_table.py @@ -0,0 +1,398 @@ +""" Tests for the configure-tables command. """ +from collections.abc import MutableMapping +from typing import Any + +from sqlalchemy import select + +from datafaker.interactive import TableCmd +from tests.utils import RequiresDBTestCase, TestDbCmdMixin + + +class TestTableCmd(TableCmd, TestDbCmdMixin): + """TableCmd but mocked""" + + +class ConfigureTablesTests(RequiresDBTestCase): + """Testing configure-tables.""" + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestTableCmd: + return TestTableCmd(self.dsn, self.schema_name, self.metadata, config) + + +class ConfigureTablesSrcTests(ConfigureTablesTests): + """Testing configure-tables with src.dump.""" + + dump_file_path = "src.dump" + database_name = "src" + schema_name = "public" + + def test_table_name_prompts(self) -> None: + """Test that the prompts follow the names of the tables.""" + config: MutableMapping[str, Any] = {} + with self._get_cmd(config) as tc: + table_names = list(self.metadata.tables.keys()) + for t in table_names: + self.assertIn(t, tc.prompt) + tc.do_next("") + self.assertListEqual(tc.messages, [(TableCmd.INFO_NO_MORE_TABLES, (), {})]) + tc.reset() + for t in reversed(table_names): + self.assertIn(t, tc.prompt) + tc.do_previous("") + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_ALREADY_AT_START, (), {})] + ) + tc.reset() + bad_table_name = "notarealtable" + tc.do_next(bad_table_name) + self.assertListEqual( + tc.messages, [(TableCmd.ERROR_NO_SUCH_TABLE, (bad_table_name,), {})] + ) + tc.reset() + good_table_name = table_names[2] + tc.do_next(good_table_name) + self.assertSequenceEqual(tc.messages, []) + self.assertIn(good_table_name, tc.prompt) + + def test_column_display(self) -> None: + """Test that we can see the names of the columns.""" + config: MutableMapping[str, Any] = {} + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_columns("") + self.assertSequenceEqual( + tc.rows, + [ + ["id", "INTEGER", True, False, ""], + ["a", "BOOLEAN", False, False, ""], + ["b", "BOOLEAN", False, False, ""], + ["c", "TEXT", False, False, ""], + ], + ) + + def test_null_configuration(self) -> None: + """A table still works if its configuration is None.""" + config = { + "tables": None, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_private("") + tc.do_quit("") + tables = tc.config["tables"] + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) + self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) + + def test_null_table_configuration(self) -> None: + """A table still works if its configuration is None.""" + config = { + "tables": { + 
"unique_constraint_test": None, + }, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_private("") + tc.do_quit("") + tables = tc.config["tables"] + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) + self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) + self.assertTrue( + tables["unique_constraint_test"].get("primary_private", False) + ) + + def test_configure_tables(self) -> None: + """Test that we can change columns to ignore, vocab or generate.""" + config = { + "tables": { + "unique_constraint_test": { + "vocabulary_table": True, + }, + "no_pk_test": { + "ignore": True, + }, + "hospital_visit": { + "num_passes": 0, + }, + "empty_vocabulary": { + "private": True, + }, + }, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_generate("") + tc.do_next("person") + tc.do_vocabulary("") + tc.do_next("mitigation_type") + tc.do_ignore("") + tc.do_next("hospital_visit") + tc.do_private("") + tc.do_quit("") + tc.do_next("empty_vocabulary") + tc.do_empty("") + tc.do_quit("") + tables = tc.config["tables"] + self.assertFalse( + tables["unique_constraint_test"].get("vocabulary_table", False) + ) + self.assertFalse(tables["unique_constraint_test"].get("ignore", False)) + self.assertFalse( + tables["unique_constraint_test"].get("primary_private", False) + ) + self.assertEqual(tables["unique_constraint_test"].get("num_passes", 1), 1) + self.assertFalse(tables["no_pk_test"].get("vocabulary_table", False)) + self.assertTrue(tables["no_pk_test"].get("ignore", False)) + self.assertFalse(tables["no_pk_test"].get("primary_private", False)) + self.assertEqual(tables["no_pk_test"].get("num_rows_per_pass", 1), 1) + self.assertTrue(tables["person"].get("vocabulary_table", False)) + self.assertFalse(tables["person"].get("ignore", False)) + self.assertFalse(tables["person"].get("primary_private", False)) + self.assertEqual(tables["person"].get("num_rows_per_pass", 1), 1) + self.assertFalse(tables["mitigation_type"].get("vocabulary_table", False)) + self.assertTrue(tables["mitigation_type"].get("ignore", False)) + self.assertFalse(tables["mitigation_type"].get("primary_private", False)) + self.assertEqual(tables["mitigation_type"].get("num_rows_per_pass", 1), 1) + self.assertFalse(tables["hospital_visit"].get("vocabulary_table", False)) + self.assertFalse(tables["hospital_visit"].get("ignore", False)) + self.assertTrue(tables["hospital_visit"].get("primary_private", False)) + self.assertEqual(tables["hospital_visit"].get("num_rows_per_pass", 1), 1) + self.assertFalse(tables["empty_vocabulary"].get("vocabulary_table", False)) + self.assertFalse(tables["empty_vocabulary"].get("ignore", False)) + self.assertFalse(tables["empty_vocabulary"].get("primary_private", False)) + self.assertEqual(tables["empty_vocabulary"].get("num_rows_per_pass", 1), 0) + + def test_print_data(self) -> None: + """Test that we can print random rows from the table and random data from columns.""" + person_table = self.metadata.tables["person"] + with self.sync_engine.connect() as conn: + person_rows = conn.execute(select(person_table)).mappings().fetchall() + person_data = {row["person_id"]: row for row in person_rows} + name_set = {row["name"] for row in person_rows} + person_headings = ["person_id", "name", "research_opt_out", "stored_from"] + with self._get_cmd({}) as tc: + tc.do_next("person") + tc.do_data("") + self.assertSequenceEqual(tc.headings, person_headings) + self.assertEqual(len(tc.rows), 10) # 
default number of rows is 10 + for row in tc.rows: + expected = person_data[row[0]] + self.assertSequenceEqual(row, [expected[h] for h in person_headings]) + tc.reset() + rows_to_get_count = 6 + tc.do_data(str(rows_to_get_count)) + self.assertSequenceEqual(tc.headings, person_headings) + self.assertEqual(len(tc.rows), rows_to_get_count) + for row in tc.rows: + expected = person_data[row[0]] + self.assertSequenceEqual(row, [expected[h] for h in person_headings]) + tc.reset() + to_get_count = 12 + tc.do_data(f"{to_get_count} name") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual(len(tc.column_items[0]), to_get_count) + self.assertLessEqual(set(tc.column_items[0]), name_set) + tc.reset() + tc.do_data(f"{to_get_count} name 12") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual(len(tc.column_items[0]), to_get_count) + tc.reset() + tc.do_data(f"{to_get_count} name 13") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 13 <= len(n), name_set)) + ) + tc.reset() + tc.do_data(f"{to_get_count} name 16") + self.assertEqual(len(tc.column_items), 1) + self.assertEqual( + set(tc.column_items[0]), set(filter(lambda n: 16 <= len(n), name_set)) + ) + + def test_list_tables(self) -> None: + """Test that we can list the tables""" + config = { + "tables": { + "unique_constraint_test": { + "vocabulary_table": True, + }, + "no_pk_test": { + "ignore": True, + }, + }, + } + with self._get_cmd(config) as tc: + tc.do_next("unique_constraint_test") + tc.do_ignore("") + tc.do_next("person") + tc.do_vocabulary("") + tc.reset() + tc.do_tables("") + person_listed = False + unique_constraint_test_listed = False + no_pk_test_listed = False + for _text, args, _kwargs in tc.messages: + if args[2] == "person": + self.assertFalse(person_listed) + person_listed = True + self.assertEqual(args[0], "G") + self.assertEqual(args[1], "->V") + elif args[2] == "unique_constraint_test": + self.assertFalse(unique_constraint_test_listed) + unique_constraint_test_listed = True + self.assertEqual(args[0], "V") + self.assertEqual(args[1], "->I") + elif args[2] == "no_pk_test": + self.assertFalse(no_pk_test_listed) + no_pk_test_listed = True + self.assertEqual(args[0], "I") + self.assertEqual(args[1], " ") + else: + self.assertEqual(args[0], "G") + self.assertEqual(args[1], " ") + self.assertTrue(person_listed) + self.assertTrue(unique_constraint_test_listed) + self.assertTrue(no_pk_test_listed) + + +class ConfigureTablesInstrumentsTests(ConfigureTablesTests): + """Testing configure-tables with the instrument.sql database.""" + + dump_file_path = "instrument.sql" + database_name = "instrument" + schema_name = "public" + + def test_sanity_checks_both(self) -> None: + """ + Test ``configure-tables`` sanity checks. 
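+
+        With this configuration, quitting should report both a definite
+        problem (vocabulary table "model" referring to "manufacturer",
+        which is no longer treated as a vocabulary table) and a potential
+        problem (non-empty "signature_model" referring to "player", which
+        is configured to produce no rows).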
+ """ + config = { + "tables": { + "model": { + "vocabulary_table": True, + }, + "manufacturer": { + "ignore": True, + }, + "player": { + "num_rows_per_pass": 0, + }, + }, + } + with self._get_cmd(config) as tc: + tc.reset() + tc.do_quit("") + self.assertEqual(tc.messages[0], (TableCmd.NOTE_TEXT_NO_CHANGES, (), {})) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) + self.assertEqual( + tc.messages[3], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[4], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) + + def test_sanity_checks_warnings_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. + """ + config = { + "tables": { + "model": { + "vocabulary_table": True, + }, + "manufacturer": { + "ignore": True, + }, + "player": { + "num_rows_per_pass": 0, + }, + }, + } + with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: + tc.do_next("manufacturer") + tc.do_vocabulary("") + tc.reset() + tc.do_quit("") + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("manufacturer", "ignore", "vocabulary"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_POTENTIAL_PROBLEMS, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_NON_EMPTY_TO_EMPTY, + ("signature_model", "player"), + {}, + ), + ) + + def test_sanity_checks_errors_only(self) -> None: + """ + Test ``configure-tables`` sanity checks. + """ + config = { + "tables": { + "model": { + "vocabulary_table": True, + }, + "manufacturer": { + "ignore": True, + }, + "player": { + "num_rows_per_pass": 0, + }, + }, + } + with TestTableCmd(self.dsn, self.schema_name, self.metadata, config) as tc: + tc.do_next("signature_model") + tc.do_empty("") + tc.reset() + tc.do_quit("") + self.assertEqual( + tc.messages[0], + ( + TableCmd.NOTE_TEXT_CHANGING, + ("signature_model", "generate", "empty"), + {}, + ), + ) + self.assertEqual( + tc.messages[1], (TableCmd.WARNING_TEXT_PROBLEMS_EXIST, (), {}) + ) + self.assertEqual( + tc.messages[2], + ( + TableCmd.WARNING_TEXT_VOCAB_TO_NON_VOCAB, + ("model", "manufacturer"), + {}, + ), + ) diff --git a/tests/test_main.py b/tests/test_main.py index e318f1e..1d5f59d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -398,9 +398,7 @@ def test_make_stats( self.assertSuccess(result) with open(example_conf_path, "r", encoding="utf8") as f: config = yaml.safe_load(f) - mock_make.assert_called_once_with( - get_test_settings().src_dsn, config, None - ) + mock_make.assert_called_once_with(get_test_settings().src_dsn, config, None) mock_path.return_value.write_text.assert_called_once_with( "a: 1\n", encoding="utf-8" ) diff --git a/tests/utils.py b/tests/utils.py index 78df87c..ab6f1d2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,6 +3,7 @@ import os import shutil import traceback +from collections.abc import MutableSequence, Sequence from functools import lru_cache from pathlib import Path from subprocess import run @@ -16,6 +17,7 @@ from datafaker import settings from datafaker.create import create_db_data_into +from datafaker.interactive.base import DbCmd from datafaker.make import make_src_stats, make_table_generators, make_tables_file from datafaker.remove import remove_db_data_from from datafaker.utils import ( @@ -264,12 +266,9 @@ def create_data(self, 
config: Mapping[str, Any], num_passes: int = 1) -> None: """Create fake data in the DB.""" # `create-data` with all this stuff datafaker_module = import_file(self.generators_file_path) - table_generator_dict = datafaker_module.table_generator_dict - story_generator_list = datafaker_module.story_generator_list create_db_data_into( sorted_non_vocabulary_tables(self.metadata, config), - table_generator_dict, - story_generator_list, + datafaker_module, num_passes, self.dsn, self.schema_name, @@ -288,3 +287,45 @@ def generate_data( self.remove_data(config) self.create_data(config, num_passes) return src_stats + + +class TestDbCmdMixin(DbCmd): + """A mixin for capturing output from interactive commands.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize a TestDbCmdMixin""" + super().__init__(*args, **kwargs) + self.reset() + + def reset(self) -> None: + """Reset all the debug messages collected so far.""" + self.messages: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = [] + self.headings: Sequence[str] = [] + self.rows: Sequence[Sequence[str]] = [] + self.column_items: MutableSequence[Sequence[str]] = [] + self.columns: Mapping[str, Sequence[Any]] = {} + + def print(self, text: str, *args: Any, **kwargs: Any) -> None: + """Capture the printed message.""" + self.messages.append((text, args, kwargs)) + + def print_table( + self, headings: Sequence[str], rows: Sequence[Sequence[str]] + ) -> None: + """Capture the printed table.""" + self.headings = headings + self.rows = rows + + def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None: + """Capture the printed table.""" + self.columns = columns + + # pylint: disable=arguments-renamed + def columnize(self, items: Sequence[str] | None, _displaywidth: int = 80) -> None: + """Capture the printed table.""" + if items is not None: + self.column_items.append(items) + + def ask_save(self) -> str: + """Quitting always works without needing to ask the user.""" + return "yes" From 2a4982f9ba9dca563e69bdb9a4f836c1c996d743 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 15 Oct 2025 18:15:12 +0100 Subject: [PATCH 21/35] Pre-commit cleaned. --- datafaker/generators/base.py | 7 +- datafaker/generators/choice.py | 1 + datafaker/generators/continuous.py | 106 ++- datafaker/generators/mimesis.py | 11 +- datafaker/generators/partitioned.py | 289 +++---- datafaker/interactive/__init__.py | 6 +- datafaker/interactive/base.py | 26 +- datafaker/interactive/generators.py | 90 +-- datafaker/interactive/table.py | 2 +- datafaker/make.py | 22 +- datafaker/utils.py | 60 +- tests/test_interactive_generators.py | 708 ++---------------- ...test_interactive_generators_partitioned.py | 419 +++++++++++ tests/test_noninteractive_generators.py | 179 +++++ 14 files changed, 1025 insertions(+), 901 deletions(-) create mode 100644 tests/test_interactive_generators_partitioned.py create mode 100644 tests/test_noninteractive_generators.py diff --git a/datafaker/generators/base.py b/datafaker/generators/base.py index 1adcb9c..f2a1459 100644 --- a/datafaker/generators/base.py +++ b/datafaker/generators/base.py @@ -13,7 +13,7 @@ from typing_extensions import Self from datafaker.base import DistributionGenerator -from datafaker.utils import T, logger +from datafaker.utils import logger NumericType = Union[int, float] @@ -22,6 +22,10 @@ generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) +class GeneratorError(Exception): + """Error thrown from Datafaker Generators.""" + + class Generator(ABC): """ Random data generator. 
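+
+    Concrete generators such as those below implement ``generate_data`` to
+    draw sample values, and report the summary statistics they were fitted
+    with through ``actual_kwargs``.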
@@ -264,6 +268,7 @@ class Buckets: the fit of generators against it. """ + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, engine: Engine, diff --git a/datafaker/generators/choice.py b/datafaker/generators/choice.py index 140b686..54f69d3 100644 --- a/datafaker/generators/choice.py +++ b/datafaker/generators/choice.py @@ -49,6 +49,7 @@ class ChoiceGenerator(Generator): STORE_COUNTS = False + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, table_name: str, diff --git a/datafaker/generators/continuous.py b/datafaker/generators/continuous.py index 42e4bcd..fc50c7f 100644 --- a/datafaker/generators/continuous.py +++ b/datafaker/generators/continuous.py @@ -1,8 +1,11 @@ """Generator factories for making generators of continuous distributions.""" -from typing import Any, Sequence +import itertools +from collections.abc import Iterable, Sequence +from typing import Any from sqlalchemy import Column, Engine, RowMapping, text +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.types import Integer, Numeric from datafaker.generators.base import ( @@ -13,7 +16,7 @@ dist_gen, get_column_type, ) -from datafaker.utils import logger +from datafaker.utils import Empty, logger class ContinuousDistributionGenerator(Generator): @@ -166,20 +169,26 @@ def get_generators( class LogNormalGenerator(Generator): """Generator producing numbers in a log-normal distribution.""" - # TODO: figure out the real buckets here (this was from a random sample in R) + # R: + # > xs<-seq(-2,2,0.5)*sqrt((exp(1)-1)*exp(1))+exp(0.5) + # > ys <- plnorm(xs) + # > c(ys, 1) - c(0,ys) + # [1] 0.00000000 0.00000000 0.00000000 0.28589471 0.40556775 0.15086088 + # [7] 0.06716451 0.03428958 0.01924848 0.03697409 expected_buckets = [ 0, 0, 0, - 0.28627, - 0.40607, - 0.14937, - 0.06735, - 0.03492, - 0.01918, - 0.03684, + 0.28589471, + 0.40556775, + 0.15086088, + 0.06716451, + 0.03428958, + 0.01924848, + 0.03697409, ] + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, table_name: str, @@ -290,6 +299,7 @@ def _get_generators_from_buckets( class MultivariateNormalGenerator(Generator): """Generator of multiple values drawn from a multivariate normal distribution.""" + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, table_name: str, @@ -359,11 +369,12 @@ def query_var(self, column: str) -> str: """Get the SQL expression of the value to query for this column.""" return column + # pylint: disable=too-many-arguments too-many-positional-arguments def query( self, table: str, - columns: list[Column], - predicates: list[str] = [], + columns: Sequence[Column], + predicates: Iterable[str] = Empty.iterable(), group_by_clause: str = "", constant_clauses: str = "", constants: str = "", @@ -385,17 +396,6 @@ def query( :param suppress_count: a group smaller than this will be suppressed. :param sample_count: this many samples will be taken from each partition. 
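+        :return: A SQL string, roughly of the shape (illustrative only, for a
+            single hypothetical column ``x`` in table ``mytable`` with
+            ``suppress_count=1``): ``SELECT 1 AS rank, _q.count AS count,
+            _q.m0, <covariance terms> FROM (SELECT COUNT(*) AS count,
+            SUM(x * x) AS s0_0, AVG(x) AS m0 FROM mytable) AS _q
+            WHERE 1 < _q.count``.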
""" - preds = [self.query_predicate(col) for col in columns] + predicates - where = " WHERE " + " AND ".join(preds) if preds else "" - avgs = "".join( - f", AVG({self.query_var(col.name)}) AS m{i}" - for i, col in enumerate(columns) - ) - multiples = "".join( - f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" - for iy, coly in enumerate(columns) - for ix, colx in enumerate(columns[: iy + 1]) - ) means = "".join(f", _q.m{i}" for i in range(len(columns))) covs = "".join( ( @@ -405,20 +405,58 @@ def query( for iy in range(len(columns)) for ix in range(iy + 1) ) - if sample_count is None: - subquery = table + where - else: - subquery = ( - f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" - f" LIMIT {sample_count}) AS _sampled" - ) - # if there are any numeric columns we need at least# + subquery = self._inner_query(table, columns, predicates, sample_count) + # if there are any numeric columns we need at least # two rows to make any (co)variances at all suppress_clause = f" WHERE {suppress_count} < _q.count" if columns else "" return ( f"SELECT {len(columns)} AS rank{constant_clauses}, _q.count AS count{means}{covs}" - f" FROM (SELECT COUNT(*) AS count{multiples}{avgs}{constants}" - f" FROM {subquery}{group_by_clause}) AS _q{suppress_clause}" + f" FROM ({self._middle_query(columns, constants, subquery, group_by_clause)})" + f" AS _q{suppress_clause}" + ) + + def _inner_query( + self, + table: str, + columns: Sequence[Column], + predicates: Iterable[str], + sample_count: int | None, + ) -> str: + """Get the rows from the table that we are interested in.""" + preds = itertools.chain( + (self.query_predicate(col) for col in columns), + predicates, + ) + where = " AND ".join(preds) if preds else "" + if where: + where = " WHERE " + where + if sample_count is None: + return table + where + return ( + f"(SELECT * FROM {table}{where} ORDER BY RANDOM()" + f" LIMIT {sample_count}) AS _sampled" + ) + + def _middle_query( + self, + columns: Sequence[Column], + constants: str, + inner_query: str, + group_by_clause: str, + ) -> str: + """Get the basic statistics (and constants) from the inner query.""" + multiples = "".join( + f", SUM({self.query_var(colx.name)} * {self.query_var(coly.name)}) AS s{ix}_{iy}" + for iy, coly in enumerate(columns) + for ix, colx in enumerate(columns[: iy + 1]) + ) + avgs = "".join( + f", AVG({self.query_var(col.name)}) AS m{i}" + for i, col in enumerate(columns) + ) + return ( + f"SELECT COUNT(*) AS count{multiples}{avgs}{constants}" + f" FROM {inner_query}{group_by_clause}" ) def get_generators( @@ -439,7 +477,7 @@ def get_generators( with engine.connect() as connection: try: covariates = connection.execute(text(query)).mappings().first() - except Exception as e: + except SQLAlchemyError as e: logger.debug("SQL query %s failed with error %s", query, e) return [] if not covariates or covariates["c0_0"] is None: diff --git a/datafaker/generators/mimesis.py b/datafaker/generators/mimesis.py index 65c5d98..b300335 100644 --- a/datafaker/generators/mimesis.py +++ b/datafaker/generators/mimesis.py @@ -5,12 +5,14 @@ import mimesis import mimesis.locales from sqlalchemy import Column, Engine, text +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.types import Date, DateTime, Integer, Numeric, String, Time -from datafaker.base import DistributionGenerator from datafaker.generators.base import ( Buckets, + DistributionGenerator, Generator, + GeneratorError, GeneratorFactory, get_column_type, ) @@ -41,12 +43,12 @@ def __init__( f = generic for 
part in function_name.split("."): if not hasattr(f, part): - raise Exception( + raise GeneratorError( f"Mimesis does not have a function {function_name}: {part} not found" ) f = getattr(f, part) if not callable(f): - raise Exception( + raise GeneratorError( f"Mimesis object {function_name} is not a callable," " so cannot be used as a generator" ) @@ -152,6 +154,7 @@ def generate_data(self, count: int) -> list[Any]: class MimesisDateTimeGenerator(MimesisGeneratorBase): """DateTime generator using Mimesis.""" + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, column: Column, @@ -306,7 +309,7 @@ def get_generators( f"LENGTH({column.name})", ) fitness_fn = len - except Exception: + except SQLAlchemyError: # Some column types that appear to be strings (such as enums) # cannot have their lengths measured. In this case we cannot # detect fitness using lengths. diff --git a/datafaker/generators/partitioned.py b/datafaker/generators/partitioned.py index 93a7880..f14af73 100644 --- a/datafaker/generators/partitioned.py +++ b/datafaker/generators/partitioned.py @@ -80,6 +80,43 @@ def comment(self) -> str: ) +@dataclass +class NullableColumn: + """A reference to a nullable column whose nullability is part of a partitioning.""" + + column: Column + # The bit (power of two) of the number of the partition in the partition sizes list + bitmask: int + + +class PartitionCountQuery: + """Query, result and comment for the row counts of the null pattern partitions.""" + + def __init__( + self, + connection: Connection, + query: str, + table_name: str, + nullable_columns: Iterable[NullableColumn], + ) -> None: + """ + Initialise the partition count query. + + :param connection: Database connection. + :param query: The query getting the row counts of the null pattern partitions. + :param table_name: The name of the table being queried. + :param nullable_columns: The columns that are being checked for nullness. + """ + self.query = query + rows = connection.execute(text(query)).mappings().fetchall() + self.results = [dict(row) for row in rows] + self.comment = ( + "Number of rows for each combination of the columns" + f" { {nc.column.name for nc in nullable_columns} }" + f" of the table {table_name} being null" + ) + + class NullPartitionedNormalGenerator(Generator): """ A generator of mixed numeric and non-numeric data. @@ -97,23 +134,20 @@ class NullPartitionedNormalGenerator(Generator): rows are used because no covariate matrix is required for this). 
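+
+    Partitions are indexed by a bitmask over the nullable columns: with two
+    nullable columns carrying bits 1 and 2, partition 0 collects the rows in
+    which both are NULL, and partition 3 those in which both are present.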
""" + # pylint: disable=too-many-arguments too-many-positional-arguments def __init__( self, query_name: str, partitions: dict[int, RowPartition], function_name: str = "grouped_multivariate_lognormal", name_suffix: str | None = None, - partition_count_query: str | None = None, - partition_counts: Iterable[RowMapping] = [], - partition_count_comment: str | None = None, + partition_count_query: PartitionCountQuery | None = None, ): """Initialise a NullPartitionedNormalGenerator.""" self._query_name = query_name self._partitions = partitions self._function_name = function_name self._partition_count_query = partition_count_query - self._partition_counts = [dict(pc) for pc in partition_counts] - self._partition_count_comment = partition_count_comment if name_suffix: self._name = f"null-partitioned {function_name} [{name_suffix}]" else: @@ -186,8 +220,8 @@ def custom_queries(self) -> dict[str, Any]: return partitions return { self._count_query_name(): { - "comment": self._partition_count_comment, - "query": self._partition_count_query, + "comment": self._partition_count_query.comment, + "query": self._partition_count_query.query, }, **partitions, } @@ -223,12 +257,16 @@ def _actual_kwargs_with_combinations( def actual_kwargs(self) -> dict[str, Any]: """Get the kwargs (summary statistics) this generator was instantiated with.""" + if self._partition_count_query is None: + counts = None + else: + counts = self._partition_count_query.results return { "alternative_configs": [ self._actual_kwargs_with_combinations(self._partitions[index]) for index in range(len(self._partitions)) ], - "counts": self._partition_counts, + "counts": counts, } def generate_data(self, count: int) -> list[Any]: @@ -252,15 +290,7 @@ def powerset(xs: list[T]) -> Iterable[Iterable[T]]: return chain.from_iterable(combinations(xs, n) for n in range(len(xs) + 1)) -@dataclass -class NullableColumn: - """A reference to a nullable column whose nullability is part of a partitioning.""" - - column: Column - # The bit (power of two) of the number of the partition in the partition sizes list - bitmask: int - - +# pylint: disable=too-many-instance-attributes class NullPatternPartition: """Get the definition of a partition (in other words, what makes it not another partition).""" @@ -362,6 +392,70 @@ def get_partition_count_query( f' FROM {table} GROUP BY "index") AS _q {where}' ) + def _get_row_partition( + self, + table: str, + partition: NullPatternPartition, + suppress_count: int = 1, + sample_count: int | None = None, + ) -> RowPartition: + """Get the RowPartition from a NullPatternPartition.""" + query = self.query( + table=table, + columns=partition.included_numeric, + predicates=partition.predicates, + group_by_clause=partition.group_by_clause, + constants=partition.constants, + constant_clauses=partition.constant_clauses, + suppress_count=suppress_count, + sample_count=sample_count, + ) + return RowPartition( + query, + partition.included_numeric, + partition.included_choice, + partition.excluded, + partition.nones, + [], + ) + + # pylint: disable=too-many-arguments too-many-positional-arguments + def _get_generator( + self, + connection: Connection, + table_name: str, + columns: list[Column], + nullable_columns: list[NullableColumn], + where: str | None = None, + name_suffix: str | None = None, + suppress_count: int = 1, + sample_count: int | None = None, + ) -> NullPartitionedNormalGenerator | None: + query = self.get_partition_count_query(nullable_columns, table_name, where) + partitions: dict[int, RowPartition] = {} + for 
partition_nonnulls in powerset(nullable_columns): + partition_def = NullPatternPartition(columns, partition_nonnulls) + partitions[partition_def.index] = self._get_row_partition( + table_name, + partition_def, + suppress_count=suppress_count, + sample_count=sample_count, + ) + if not self._execute_partition_queries(connection, partitions): + return None + return NullPartitionedNormalGenerator( + f"{table_name}__{columns[0].name}", + partitions, + self.function_name(), + name_suffix=name_suffix, + partition_count_query=PartitionCountQuery( + connection, + query, + table_name, + nullable_columns, + ), + ) + def get_generators( self, columns: list[Column], engine: Engine ) -> Sequence[Generator]: @@ -372,139 +466,54 @@ def get_generators( if not nullable_columns: return [] table = columns[0].table.name - query_name = f"{table}__{columns[0].name}" - # Partitions for minimal suppression and no sampling - row_partitions_maximal: dict[int, RowPartition] = {} - # Partitions for minimal suppression but sampling - row_partitions_sampled: dict[int, RowPartition] = {} - # Partitions for normal suppression and severe sampling - row_partitions_ss: dict[int, RowPartition] = {} - for partition_nonnulls in powerset(nullable_columns): - partition_def = NullPatternPartition(columns, partition_nonnulls) - query_all = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - ) - row_partitions_maximal[partition_def.index] = RowPartition( - query_all, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - [], - ) - query_sampled = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - sample_count=self.SAMPLE_COUNT, - ) - row_partitions_sampled[partition_def.index] = RowPartition( - query_sampled, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - {}, - ) - query_ss = self.query( - table=table, - columns=partition_def.included_numeric, - predicates=partition_def.predicates, - group_by_clause=partition_def.group_by_clause, - constants=partition_def.constants, - constant_clauses=partition_def.constant_clauses, - suppress_count=self.SUPPRESS_COUNT, - sample_count=self.SAMPLE_COUNT, - ) - row_partitions_ss[partition_def.index] = RowPartition( - query_ss, - partition_def.included_numeric, - partition_def.included_choice, - partition_def.excluded, - partition_def.nones, - [], - ) - gens: list[Generator] = [] + gens: list[Generator | None] = [] try: with engine.connect() as connection: - partition_query_max = self.get_partition_count_query( - nullable_columns, table - ) - partition_count_max_results = ( - connection.execute(text(partition_query_max)).mappings().fetchall() - ) - count_comment = ( - "Number of rows for each combination of the columns" - f" { {nc.column.name for nc in nullable_columns} }" - f" of the table {table} being null" - ) - if self._execute_partition_queries(connection, row_partitions_maximal): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_maximal, - self.function_name(), - partition_count_query=partition_query_max, - 
partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - ) - ) - if self._execute_partition_queries(connection, row_partitions_sampled): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_sampled, - self.function_name(), - name_suffix="sampled", - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - ) + gens.append( + self._get_generator( + connection, + table, + columns, + nullable_columns, ) - if self._execute_partition_queries(connection, row_partitions_sampled): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_sampled, - self.function_name(), - name_suffix="sampled", - partition_count_query=partition_query_max, - partition_counts=partition_count_max_results, - partition_count_comment=count_comment, - ) + ) + gens.append( + self._get_generator( + connection, + table, + columns, + nullable_columns, + name_suffix="sampled", + sample_count=self.SAMPLE_COUNT, ) - partition_query_ss = self.get_partition_count_query( - nullable_columns, - table, - where=f"WHERE {self.SUPPRESS_COUNT} < count", ) - partition_count_ss_results = ( - connection.execute(text(partition_query_ss)).mappings().fetchall() + gens.append( + self._get_generator( + connection, + table, + columns, + nullable_columns, + where=f"WHERE {self.SUPPRESS_COUNT} < count", + name_suffix="sampled and suppressed", + suppress_count=self.SUPPRESS_COUNT, + sample_count=self.SAMPLE_COUNT, + ) ) - if self._execute_partition_queries(connection, row_partitions_ss): - gens.append( - NullPartitionedNormalGenerator( - query_name, - row_partitions_ss, - self.function_name(), - name_suffix="sampled and suppressed", - partition_count_query=partition_query_ss, - partition_counts=partition_count_ss_results, - partition_count_comment=count_comment, - ) + gens.append( + self._get_generator( + connection, + table, + columns, + nullable_columns, + where=f"WHERE {self.SUPPRESS_COUNT} < count", + name_suffix="suppressed", + suppress_count=self.SUPPRESS_COUNT, ) + ) except sqlalchemy.exc.DatabaseError as exc: logger.debug("SQL query failed with error %s [%s]", exc, exc.statement) return [] - return gens + return [gen for gen in gens if gen] def _execute_partition_queries( self, diff --git a/datafaker/interactive/__init__.py b/datafaker/interactive/__init__.py index 952eadf..c279720 100644 --- a/datafaker/interactive/__init__.py +++ b/datafaker/interactive/__init__.py @@ -20,7 +20,7 @@ if not hasattr(readline, "backend"): setattr(readline, "backend", "readline") -except: +except ImportError: pass @@ -86,7 +86,7 @@ def update_config_generators( if line: if len(line) < 3: logger.error( - "line {0} of file {1} has fewer than three values", + "line %d of file %s has fewer than three values", line_no, spec_path, ) @@ -95,6 +95,6 @@ def update_config_generators( if len(cols) == 1 or gc.set_merged_columns(cols[0], cols[1]): try_setting_generator(gc, itertools.islice(line, 2, None)) else: - logger.warning("no such column {0}[{1}]", line[0], line[1]) + logger.warning("no such column %s[%s]", line[0], line[1]) gc.do_quit("yes") return gc.config diff --git a/datafaker/interactive/base.py b/datafaker/interactive/base.py index 51793fe..9d612a7 100644 --- a/datafaker/interactive/base.py +++ b/datafaker/interactive/base.py @@ -100,7 +100,10 @@ class DbCmd(ABC, cmd.Cmd): INFO_NO_MORE_TABLES = "There are no more tables" ERROR_ALREADY_AT_START = "Error: Already at the start" ERROR_NO_SUCH_TABLE = "Error: 
'{0}' is not the name of a table in this database" - ERROR_NO_SUCH_TABLE_OR_COLUMN = "Error: '{0}' is not the name of a table in this database or a column in this table" + ERROR_NO_SUCH_TABLE_OR_COLUMN = ( + "Error: '{0}' is not the name of a table" + " in this database or a column in this table" + ) ROW_COUNT_MSG = "Total row count: {}" @abstractmethod @@ -185,7 +188,7 @@ def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None: :param columns: Dict of column names to the values in the column. """ output = PrettyTable() - row_count = max([len(col) for col in columns.values()]) + row_count = max(len(col) for col in columns.values()) for field_name, data in columns.items(): output.add_column(field_name, list(data) + [None] * (row_count - len(data))) print(output) @@ -207,7 +210,6 @@ def ask_save(self) -> str: @abstractmethod def set_prompt(self) -> None: """Set the prompt according to the current state.""" - ... def _set_table_index(self, index: int) -> bool: """ @@ -325,21 +327,25 @@ def do_counts(self, _arg: str) -> None: nonnull_columns = self.get_nonnull_columns(table_name) colcounts = [f", COUNT({nnc}) AS {nnc}" for nnc in nonnull_columns] with self.sync_engine.connect() as connection: - result = connection.execute( - sqlalchemy.text( - f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}" + result = ( + connection.execute( + sqlalchemy.text( + f"SELECT COUNT(*) AS row_count{''.join(colcounts)} FROM {table_name}" + ) ) - ).first() + .mappings() + .first() + ) if result is None: self.print("Could not count rows in table {0}", table_name) return - row_count = result.row_count + row_count = result.get("row_count", 0) self.print(self.ROW_COUNT_MSG, row_count) self.print_table( ["Column", "NULL count"], [ [name, row_count - count] - for name, count in result._mapping.items() + for name, count in result.items() if name != "row_count" ], ) @@ -388,7 +394,7 @@ def do_peek(self, arg: str) -> None: ) try: result = connection.execute(query) - except Exception as exc: + except sqlalchemy.exc.SQLAlchemyError as exc: self.print(f'SQL query "{query}" caused exception {exc}') return self.print_table(list(result.keys()), result.fetchmany(max_peek_rows)) diff --git a/datafaker/interactive/generators.py b/datafaker/interactive/generators.py index 544e0cf..24ccb52 100644 --- a/datafaker/interactive/generators.py +++ b/datafaker/interactive/generators.py @@ -1,4 +1,4 @@ -"""Generator configuration shell.""" +"""Generator configuration shell.""" # pylint: disable=too-many-lines import functools import re from collections.abc import Iterable, Mapping, MutableMapping, Sequence @@ -11,7 +11,13 @@ from datafaker.generators import everything_factory from datafaker.generators.base import Generator, PredefinedGenerator from datafaker.interactive.base import DbCmd, TableEntry, fk_column_name, or_default -from datafaker.utils import logger, primary_private_fks, table_is_private +from datafaker.utils import ( + get_columns_assigned, + get_row_generators, + logger, + primary_private_fks, + table_is_private, +) @dataclass @@ -35,6 +41,7 @@ class GeneratorCmdTableEntry(TableEntry): new_generators: list[GeneratorInfo] +# pylint: disable=too-many-public-methods class GeneratorCmd(DbCmd): """Interactive command shell for setting generators.""" @@ -85,50 +92,41 @@ def make_table_entry( return None if table_config.get("num_rows_per_pass", 1) == 0: return None - metadata_table = self.metadata.tables[table_name] - columns = [str(colname) for colname in metadata_table.columns.keys()] + 
columns = [
+            str(colname) for colname in self.metadata.tables[table_name].columns.keys()
+        ]
         column_set = frozenset(columns)
         columns_assigned_so_far: set[str] = set()
         new_generator_infos: list[GeneratorInfo] = []
-        old_generator_infos: list[GeneratorInfo] = []
-        for rg in table_config.get("row_generators", []):
-            gen_name = rg.get("name", None)
-            if gen_name:
-                ca = rg.get("columns_assigned", [])
-                collist: list[str] = (
-                    [ca] if isinstance(ca, str) else [str(c) for c in ca]
+        for gen_name, rg in get_row_generators(table_config):
+            colset: set[str] = set(get_columns_assigned(rg))
+            for unknown in colset - column_set:
+                logger.warning(
+                    "table '%s' has '%s' assigned to column '%s' which is not in this table",
+                    table_name,
+                    gen_name,
+                    unknown,
                 )
-                colset: set[str] = set(collist)
-                for unknown in colset - column_set:
-                    logger.warning(
-                        "table '%s' has '%s' assigned to column '%s' which is not in this table",
-                        table_name,
-                        gen_name,
-                        unknown,
-                    )
-                for mult in columns_assigned_so_far & colset:
-                    logger.warning(
-                        "table '%s' has column '%s' assigned to multiple times",
-                        table_name,
-                        mult,
-                    )
-                actual_collist = [c for c in collist if c in columns]
-                if actual_collist:
-                    gen = PredefinedGenerator(table_name, rg, self.config)
-                    new_generator_infos.append(
-                        GeneratorInfo(
-                            columns=actual_collist.copy(),
-                            gen=gen,
-                        )
-                    )
-                    old_generator_infos.append(
-                        GeneratorInfo(
-                            columns=actual_collist.copy(),
-                            gen=gen,
-                        )
+            for mult in columns_assigned_so_far & colset:
+                logger.warning(
+                    "table '%s' has column '%s' assigned multiple times",
+                    table_name,
+                    mult,
+                )
+            actual_collist = [c for c in columns if c in colset]
+            if actual_collist:
+                new_generator_infos.append(
+                    GeneratorInfo(
+                        columns=actual_collist.copy(),
+                        gen=PredefinedGenerator(table_name, rg, self.config),
                     )
-                columns_assigned_so_far |= colset
+                )
+            columns_assigned_so_far |= colset
+        old_generator_infos = [
+            GeneratorInfo(columns=gi.columns.copy(), gen=gi.gen)
+            for gi in new_generator_infos
+        ]
         for colname in columns:
             if colname not in columns_assigned_so_far:
                 new_generator_infos.append(
@@ -139,6 +137,7 @@ def make_table_entry(
                 )
         if len(new_generator_infos) == 0:
             return None
+
         return GeneratorCmdTableEntry(
             name=table_name,
             old_generators=old_generator_infos,
@@ -853,7 +852,7 @@ def do_unset(self, _arg: str) -> None:
             self.set_generator(None)
         self._go_next()
 
-    def merge_columns(self, arg: str) -> None:
+    def merge_columns(self, arg: str) -> bool:
         """
         Add this column(s) to the specified column(s).
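+
+        For example, ``merge b c`` (for hypothetical columns ``b`` and ``c``
+        of the current table) leaves a single generator covering the current
+        column together with ``b`` and ``c``.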
@@ -877,8 +876,7 @@ def merge_columns(self, arg: str) -> None: self.print(self.ERROR_NO_SUCH_COLUMN, uc) return False gen_info = table_entry.new_generators[self.generator_index] - current_columns = frozenset(gen_info.columns) - stated_current_columns = cols_to_merge & current_columns + stated_current_columns = cols_to_merge & frozenset(gen_info.columns) if stated_current_columns: for c in stated_current_columns: self.print(self.ERROR_COLUMN_ALREADY_MERGED, c) @@ -911,7 +909,7 @@ def merge_columns(self, arg: str) -> None: self.set_prompt() return True - def do_merge(self, arg: str): + def do_merge(self, arg: str) -> None: """Add this column(s) to the specified column(s), so one generator covers them all.""" self.merge_columns(arg) @@ -987,7 +985,9 @@ def complete_unmerge( def get_current_columns(self) -> set[str]: """Get the current colums.""" - table_entry: GeneratorCmdTableEntry = self.get_table() + table_entry: GeneratorCmdTableEntry | None = self.get_table() + if table_entry is None: + return set() gen_info = table_entry.new_generators[self.generator_index] return set(gen_info.columns) diff --git a/datafaker/interactive/table.py b/datafaker/interactive/table.py index 40301b0..d763a14 100644 --- a/datafaker/interactive/table.py +++ b/datafaker/interactive/table.py @@ -1,5 +1,5 @@ """Table configuration command shell.""" -from collections.abc import Mapping, MutableMapping, Sequence +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass from typing import Any, cast diff --git a/datafaker/make.py b/datafaker/make.py index bca7a2b..6f4cc9b 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -28,9 +28,11 @@ MaybeAsyncEngine, create_db_engine, download_table, + get_columns_assigned, get_flag, get_property, get_related_table_names, + get_row_generators, get_sync_engine, get_vocabulary_table_names, logger, @@ -176,27 +178,15 @@ def _get_row_generator( ) -> tuple[list[RowGeneratorInfo], list[str]]: """Get the row generators information, for the given table.""" row_gen_info: list[RowGeneratorInfo] = [] - config: list[Mapping[str, Any]] = get_property(table_config, "row_generators", []) columns_covered = [] - for gen_conf in config: - name: str = gen_conf["name"] - columns_assigned = gen_conf["columns_assigned"] + for name, gen_conf in get_row_generators(table_config): + columns_assigned = list(get_columns_assigned(gen_conf)) keyword_arguments: Mapping[str, Any] = gen_conf.get("kwargs", {}) positional_arguments: Sequence[str] = gen_conf.get("args", []) - - if isinstance(columns_assigned, str): - columns_assigned = [columns_assigned] - - variable_names: list[str] = columns_assigned - try: - columns_covered += columns_assigned - except TypeError: - # Might be a single string, rather than a list of strings. 
-            columns_covered.append(columns_assigned)
-
+        columns_covered += columns_assigned
         row_gen_info.append(
             RowGeneratorInfo(
-                variable_names=variable_names,
+                variable_names=columns_assigned,
                 function_call=_get_function_call(
                     name, positional_arguments, keyword_arguments
                 ),
diff --git a/datafaker/utils.py b/datafaker/utils.py
index 3cc8c28..7ef91bf 100644
--- a/datafaker/utils.py
+++ b/datafaker/utils.py
@@ -9,7 +9,17 @@
 from collections.abc import Mapping, Sequence
 from pathlib import Path
 from types import ModuleType
-from typing import Any, Callable, Final, Generator, Iterable, Optional, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    Final,
+    Generator,
+    Generic,
+    Iterable,
+    Optional,
+    TypeVar,
+    Union,
+)
 
 import psycopg2
 import sqlalchemy
@@ -43,6 +53,16 @@
 T = TypeVar("T")
 
 
+class Empty(Generic[T]):
+    """Generic empty sequences for default arguments."""
+
+    @classmethod
+    def iterable(cls) -> Iterable[T]:
+        """Get an empty iterable."""
+        e: list[T] = []
+        return (x for x in e)
+
+
 def read_config_file(path: str) -> dict:
     """Read a config file, warning if it is invalid.
 
@@ -417,6 +437,44 @@ def get_vocabulary_table_names(config: Mapping) -> set[str]:
     }
 
 
+def get_columns_assigned(
+    row_generator_config: Mapping[str, Any]
+) -> Generator[str, None, None]:
+    """
+    Get the columns assigned in a ``row_generators[n]`` stanza.
+
+    :param row_generator_config: The ``row_generators[n]`` stanza itself.
+    """
+    ca = row_generator_config.get("columns_assigned", None)
+    if ca is None:
+        return
+    if isinstance(ca, str):
+        yield ca
+        return
+    if not hasattr(ca, "__iter__"):
+        return
+    for c in ca:
+        yield str(c)
+
+
+def get_row_generators(
+    table_config: Mapping[str, Any],
+) -> Generator[tuple[str, Mapping[str, Any]], None, None]:
+    """
+    Get the row generators from a table configuration.
+
+    :param table_config: The element from the ``tables:`` stanza of ``config.yaml``.
+    :return: Pairs of (name, row generator config).
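+
+    Example (hypothetical table stanza):
+
+    >>> list(get_row_generators({"row_generators": [
+    ...     {"name": "dist_gen.choice", "columns_assigned": "one"}]}))
+    [('dist_gen.choice', {'name': 'dist_gen.choice', 'columns_assigned': 'one'})]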
+ """ + rgs = table_config.get("row_generators", None) + if isinstance(rgs, str) or not hasattr(rgs, "__iter__"): + return + for rg in rgs: + name = rg.get("name", None) + if name: + yield (name, rg) + + def make_foreign_key_name(table_name: str, col_name: str) -> str: """Make a suitable foreign key name.""" return f"{table_name}_{col_name}_fkey" diff --git a/tests/test_interactive_generators.py b/tests/test_interactive_generators.py index 6b4da13..a7eb757 100644 --- a/tests/test_interactive_generators.py +++ b/tests/test_interactive_generators.py @@ -2,16 +2,11 @@ import copy import re from collections.abc import MutableMapping -from dataclasses import dataclass from typing import Any, Iterable -from unittest import TestCase -from unittest.mock import MagicMock, Mock, patch -from sqlalchemy import Connection, MetaData, insert, select +from sqlalchemy import Connection, MetaData, select -from datafaker.generators import NullPartitionedNormalGeneratorFactory from datafaker.generators.choice import ChoiceGeneratorFactory -from datafaker.interactive import update_config_generators from datafaker.interactive.generators import GeneratorCmd from tests.utils import GeneratesDBTestCase, RequiresDBTestCase, TestDbCmdMixin @@ -565,6 +560,22 @@ def test_empty_tables_are_not_configured(self) -> None: self.assertNotIn("string", table_names) +class ChoiceMeasurementTableStats: + """Measure the data in the ``choice.sql`` schema.""" + + def __init__(self, metadata: MetaData, connection: Connection): + """Get the data and do the analysis.""" + stmt = select(metadata.tables["number_table"]) + rows = connection.execute(stmt).fetchall() + self.ones: set[int] = set() + self.twos: set[int] = set() + self.threes: set[int] = set() + for row in rows: + self.ones.add(row.one) + self.twos.add(row.two) + self.threes.add(row.three) + + class GeneratorsOutputTests(GeneratesDBTestCase): """Testing choice generation.""" @@ -580,14 +591,16 @@ def setUp(self) -> None: def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + def _propose(self, gc: TestGeneratorCmd) -> dict[str, tuple[int, str, list[str]]]: + gc.reset() + gc.do_propose("") + return gc.get_proposals() + def test_create_with_sampled_choice(self) -> None: """Test that suppression works for choice and zipf_choice.""" - table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) self.assertIn("dist_gen.choice", proposals) self.assertIn("dist_gen.zipf_choice", proposals) self.assertIn("dist_gen.choice [sampled]", proposals) @@ -596,9 +609,7 @@ def test_create_with_sampled_choice(self) -> None: self.assertIn("dist_gen.zipf_choice [sampled and suppressed]", proposals) gc.do_set(str(proposals["dist_gen.choice [sampled and suppressed]"][0])) gc.do_next("number_table.two") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) self.assertIn("dist_gen.choice", proposals) self.assertIn("dist_gen.zipf_choice", proposals) self.assertIn("dist_gen.choice [sampled]", proposals) @@ -609,9 +620,7 @@ def test_create_with_sampled_choice(self) -> None: str(proposals["dist_gen.zipf_choice [sampled and suppressed]"][0]) ) gc.do_next("number_table.three") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) self.assertIn("dist_gen.choice", proposals) 
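+            # Each proposal maps a display name to (index, description,
+            # sample values); do_set takes the index as a string.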
self.assertIn("dist_gen.zipf_choice", proposals) self.assertIn("dist_gen.choice [sampled]", proposals) @@ -621,34 +630,22 @@ def test_create_with_sampled_choice(self) -> None: gc.do_set(str(proposals["dist_gen.choice [sampled]"][0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) - with self.sync_engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - ones = set() - twos = set() - threes = set() - for row in rows: - ones.add(row.one) - twos.add(row.two) - threes.add(row.three) # all generation possibilities should be present - self.assertSetEqual(ones, {1, 4}) - self.assertSetEqual(twos, {2, 3}) - self.assertSetEqual(threes, {1, 2, 3, 4, 5}) + with self.sync_engine.connect() as conn: + stats = ChoiceMeasurementTableStats(self.metadata, conn) + self.assertSetEqual(stats.ones, {1, 4}) + self.assertSetEqual(stats.twos, {2, 3}) + self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5}) def test_create_with_choice(self) -> None: """Smoke test normal choice works.""" table_name = "number_table" with self._get_cmd({}) as gc: gc.do_next("number_table.one") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) gc.do_set(str(proposals["dist_gen.choice"][0])) gc.do_next("number_table.two") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() + proposals = self._propose(gc) gc.do_set(str(proposals["dist_gen.zipf_choice"][0])) gc.do_quit("") self.generate_data(gc.config, num_passes=200) @@ -668,13 +665,14 @@ def test_create_with_weighted_choice(self) -> None: """Smoke test weighted choice.""" with self._get_cmd({}) as gc: gc.do_next("number_table.one") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.weighted_choice", proposals) - self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn( - "dist_gen.weighted_choice [sampled and suppressed]", proposals + proposals = self._propose(gc) + self.assert_subset( + { + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + "dist_gen.weighted_choice [sampled and suppressed]", + }, + set(proposals), ) prop = proposals["dist_gen.weighted_choice [sampled and suppressed]"] self.assert_subset(set(prop[2]), {"1", "4"}) @@ -688,13 +686,14 @@ def test_create_with_weighted_choice(self) -> None: self.assert_subset(col_set, {1, 4}) gc.do_set(str(prop[0])) gc.do_next("number_table.two") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.weighted_choice", proposals) - self.assertIn("dist_gen.weighted_choice [sampled]", proposals) - self.assertIn( - "dist_gen.weighted_choice [sampled and suppressed]", proposals + proposals = self._propose(gc) + self.assert_subset( + { + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + "dist_gen.weighted_choice [sampled and suppressed]", + }, + set(proposals), ) prop = proposals["dist_gen.weighted_choice"] self.assert_subset(set(prop[2]), {"1", "2", "3", "4", "5"}) @@ -706,11 +705,14 @@ def test_create_with_weighted_choice(self) -> None: self.assert_subset(col_set2, {1, 2, 3, 4, 5}) gc.do_set(str(prop[0])) gc.do_next("number_table.three") - gc.reset() - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("dist_gen.weighted_choice", proposals) - self.assertIn("dist_gen.weighted_choice [sampled]", proposals) + proposals = self._propose(gc) + self.assert_subset( + { + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + }, + set(proposals), + ) 
             self.assertNotIn(
                 "dist_gen.weighted_choice [sampled and suppressed]", proposals
             )
@@ -725,19 +727,11 @@ def test_create_with_weighted_choice(self) -> None:
             gc.do_quit("")
         self.generate_data(gc.config, num_passes=200)
         with self.sync_engine.connect() as conn:
-            ones = set()
-            twos = set()
-            threes = set()
-            for row in conn.execute(
-                select(self.metadata.tables["number_table"])
-            ).fetchall():
-                ones.add(row.one)
-                twos.add(row.two)
-                threes.add(row.three)
-            # all generation possibilities should be present
-            self.assertSetEqual(ones, {1, 4})
-            self.assertSetEqual(twos, {1, 2, 3, 4, 5})
-            self.assertSetEqual(threes, {1, 2, 3, 4, 5})
+            stats = ChoiceMeasurementTableStats(self.metadata, conn)
+            # all generation possibilities should be present
+            self.assertSetEqual(stats.ones, {1, 4})
+            self.assertSetEqual(stats.twos, {1, 2, 3, 4, 5})
+            self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5})
 
 
 class GeneratorTests(GeneratesDBTestCase):
@@ -864,582 +859,3 @@ def test_varchar_ns_are_truncated(self) -> None:
             stmt = select(self.metadata.tables[table].c[column])
             rows = conn.execute(stmt).scalars().fetchall()
             self.assert_are_truncated_to(rows, 20)
-
-
-@dataclass
-class Stat:
-    """Mean and variance calculator."""
-
-    n: int = 0
-    x: float = 0
-    x2: float = 0
-
-    def add(self, x: float) -> None:
-        """Add one datum."""
-        self.n += 1
-        self.x += x
-        self.x2 += x * x
-
-    def count(self) -> int:
-        """Get the number of data added."""
-        return self.n
-
-    def x_mean(self) -> float:
-        """Get the mean of the added data."""
-        return self.x / self.n
-
-    def x_var(self) -> float:
-        """Get the variance of the added data."""
-        x = self.x
-        return (self.x2 - x * x / self.n) / (self.n - 1)
-
-
-@dataclass
-class Correlation(Stat):
-    """Mean, variance and covariance."""
-
-    y: float = 0
-    y2: float = 0
-    xy: float = 0
-
-    def add2(self, x: float, y: float) -> None:
-        """Add a 2D data point."""
-        self.n += 1
-        self.x += x
-        self.x2 += x * x
-        self.y += y
-        self.y2 += y * y
-        self.xy += x * y
-
-    def y_mean(self) -> float:
-        """Get the mean of the second parts of the added points."""
-        return self.y / self.n
-
-    def y_var(self) -> float:
-        """Get the variance of the second parts of the added points."""
-        y = self.y
-        return (self.y2 - y * y / self.n) / (self.n - 1)
-
-    def covar(self) -> float:
-        """Get the covariance of the two parts of the added points."""
-        return (self.xy - self.x * self.y / self.n) / (self.n - 1)
-
-
-# pylint disable: too-many-instance-attributes
-class EavMeasurementTableStats:
-    """The statistics for the Measurement table of eav.sql."""
-
-    def __init__(self, conn: Connection, metadata: MetaData, test: TestCase) -> None:
-        stmt = select(metadata.tables["measurement"])
-        rows = conn.execute(stmt).fetchall()
-        self.types: set[int] = set()
-        self.one_count = 0
-        self.one_yes_count = 0
-        self.two = Correlation()
-        self.three = Correlation()
-        self.four = Correlation()
-        self.fish = Stat()
-        self.fowl = Stat()
-        for row in rows:
-            self.types.add(row.type)
-            if row.type == 1:
-                # yes or no
-                test.assertIsNone(row.first_value)
-                test.assertIsNone(row.second_value)
-                test.assertIn(row.third_value, {"yes", "no"})
-                self.one_count += 1
-                if row.third_value == "yes":
-                    self.one_yes_count += 1
-            elif row.type == 2:
-                # positive correlation around 1.4, 1.8
-                test.assertIsNotNone(row.first_value)
-                test.assertIsNotNone(row.second_value)
-                test.assertIsNone(row.third_value)
-                self.two.add2(row.first_value, row.second_value)
-            elif row.type == 3:
-                # negative correlation
around 11.8, 12.1 - test.assertIsNotNone(row.first_value) - test.assertIsNotNone(row.second_value) - test.assertIsNone(row.third_value) - self.three.add2(row.first_value, row.second_value) - elif row.type == 4: - # positive correlation around 21.4, 23.4 - test.assertIsNotNone(row.first_value) - test.assertIsNotNone(row.second_value) - test.assertIsNone(row.third_value) - self.four.add2(row.first_value, row.second_value) - elif row.type == 5: - test.assertIn(row.third_value, {"fish", "fowl"}) - test.assertIsNotNone(row.first_value) - test.assertIsNone(row.second_value) - if row.third_value == "fish": - # mean 8.1 and sd 0.755 - self.fish.add(row.first_value) - else: - # mean 11.2 and sd 1.114 - self.fowl.add(row.first_value) - - -class NullPartitionedTests(GeneratesDBTestCase): - """Testing null-partitioned grouped multivariate generation.""" - - dump_file_path = "eav.sql" - database_name = "eav" - schema_name = "public" - - def setUp(self) -> None: - """Set up the test with specific sample and suppress counts.""" - super().setUp() - NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 8 - NullPartitionedNormalGeneratorFactory.SUPPRESS_COUNT = 2 - - def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: - """Get the configure-generators object as our command.""" - return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) - - def test_create_with_null_partitioned_grouped_multivariate(self) -> None: - """Test EAV for all columns.""" - generate_count = 800 - with self._get_cmd({}) as gc: - self.merge_columns( - gc, - "measurement", - [ - "type", - "first_value", - "second_value", - "third_value", - ], - ) - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) - dist_to_choose = "null-partitioned grouped_multivariate_normal" - self.assertIn(dist_to_choose, proposals) - prop = proposals[dist_to_choose] - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. 
{dist_to_choose}" - self.assertIn(col_heading, set(gc.columns.keys())) - gc.do_set(str(prop[0])) - gc.reset() - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.sync_engine.connect() as conn: - stats = EavMeasurementTableStats(conn, self.metadata, self) - # type 1 - self.assertAlmostEqual( - stats.one_count, generate_count * 5 / 20, delta=generate_count * 0.4 - ) - # about 40% are yes - self.assertAlmostEqual( - stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 - ) - # type 2 - self.assertAlmostEqual( - stats.two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 - ) - self.assertAlmostEqual(stats.two.x_mean(), 1.4, delta=0.4) - self.assertAlmostEqual(stats.two.x_var(), 0.315, delta=0.18) - self.assertAlmostEqual(stats.two.y_mean(), 1.8, delta=0.8) - self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.06) - self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.07) - # type 3 - self.assertAlmostEqual( - stats.three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.three.covar(), -2.085, delta=1.1) - # type 4 - self.assertAlmostEqual( - stats.four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1) - # type 5/fish - self.assertAlmostEqual( - stats.fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.6) - # type 5/fowl - self.assertAlmostEqual( - stats.fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) - - def populate_measurement_type_vocab(self): - """Add a vocab table without messing around with files""" - table = self.metadata.tables["measurement_type"] - with self.engine.connect() as conn: - conn.execute(insert(table).values({"id": 1, "name": "agreement"})) - conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) - conn.execute(insert(table).values({"id": 3, "name": "velocity"})) - conn.execute(insert(table).values({"id": 4, "name": "position"})) - conn.execute(insert(table).values({"id": 5, "name": "matter"})) - conn.commit() - - def merge_columns( - self, gc: TestGeneratorCmd, table: str, columns: list[str] - ) -> None: - """Merge columns in a table""" - gc.do_next(f"{table}.{columns[0]}") - for col in columns[1:]: - gc.do_merge(col) - gc.reset() - - def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> None: - """Test EAV for all columns with sampled and suppressed generation.""" - generate_count = 800 - with self._get_cmd({}) as gc: - self.merge_columns( - gc, - "measurement", - [ - "type", - "first_value", - "second_value", - "third_value", - ], - ) - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) - self.assertIn("null-partitioned grouped_multivariate_normal", proposals) - self.assertIn( - "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", - proposals, - ) - dist_to_choose = ( - "null-partitioned grouped_multivariate_normal [sampled and suppressed]" - ) - self.assertIn(dist_to_choose, proposals) - prop 
= proposals[dist_to_choose] - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. {dist_to_choose}" - self.assertIn(col_heading, set(gc.columns.keys())) - gc.do_set(str(prop[0])) - self.merge_columns( - gc, - "observation", - [ - "type", - "first_value", - "second_value", - "third_value", - ], - ) - gc.do_propose("") - proposals = gc.get_proposals() - prop = proposals[dist_to_choose] - gc.do_set(str(prop[0])) - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.sync_engine.connect() as conn: - stats = EavMeasurementTableStats(conn, self.metadata, self) - stmt = select(self.metadata.tables["observation"]) - rows = conn.execute(stmt).fetchall() - firsts = Stat() - for row in rows: - stats.types.add(row.type) - self.assertEqual(row.type, 1) - self.assertIsNotNone(row.first_value) - self.assertIsNone(row.second_value) - self.assertIn(row.third_value, {"ham", "eggs"}) - firsts.add(row.first_value) - self.assertEqual(firsts.count(), 800) - self.assertAlmostEqual(firsts.x_mean(), 1.3, delta=generate_count * 0.3) - self.assert_subset(stats.types, {1, 2, 3, 4, 5}) - self.assertEqual(len(stats.types), 4) - self.assert_subset({1, 5}, stats.types) - # type 1 - self.assertAlmostEqual( - stats.one_count, generate_count * 5 / 11, delta=generate_count * 0.4 - ) - # about 40% are yes - self.assertAlmostEqual( - stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 - ) - # type 5/fish - self.assertAlmostEqual( - stats.fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) - self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.5) - # type 5/fowl - self.assertAlmostEqual( - stats.fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 - ) - self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) - - def test_create_with_null_partitioned_grouped_sampled_only(self): - """Test EAV for all columns with sampled generation but no suppression.""" - table_name = "measurement" - table2_name = "observation" - generate_count = 800 - with self._get_cmd({}) as gc: - self.merge_columns( - gc, table_name, ["type", "first_value", "second_value", "third_value"] - ) - gc.do_propose("") - proposals = gc.get_proposals() - self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) - self.assertIn("null-partitioned grouped_multivariate_normal", proposals) - self.assertIn( - "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", - proposals, - ) - self.assertIn( - "null-partitioned grouped_multivariate_normal [sampled and suppressed]", - proposals, - ) - self.assertIn( - "null-partitioned grouped_multivariate_lognormal [sampled]", proposals - ) - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" - self.assertIn(dist_to_choose, proposals) - prop = proposals[dist_to_choose] - gc.reset() - gc.do_compare(str(prop[0])) - col_heading = f"{prop[0]}. 
{dist_to_choose}" - self.assertIn(col_heading, set(gc.columns.keys())) - gc.do_set(str(prop[0])) - self.merge_columns( - gc, table2_name, ["type", "first_value", "second_value", "third_value"] - ) - gc.do_propose("") - proposals = gc.get_proposals() - prop = proposals[dist_to_choose] - gc.do_set(str(prop[0])) - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - self.assert_subset({row.type for row in rows}, {1, 2, 3, 4, 5}) - stmt = select(self.metadata.tables[table2_name]) - rows = conn.execute(stmt).fetchall() - self.assertEqual( - {row.third_value for row in rows}, {"ham", "eggs", "cheese"} - ) - - def test_create_with_null_partitioned_grouped_sampled_tiny(self): - """ - Test EAV for all columns with sampled generation that only gets a tiny sample. - """ - # five will ensure that at least one group will have two elements in it, - # but all three cannot. - NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 5 - table_name = "observation" - generate_count = 100 - with self._get_cmd({}) as gc: - dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" - self.merge_columns( - gc, table_name, ["type", "first_value", "second_value", "third_value"] - ) - gc.do_propose("") - proposals = gc.get_proposals() - # breakpoint() - prop = proposals[dist_to_choose] - gc.do_set(str(prop[0])) - gc.do_quit("") - self.set_configuration(gc.config) - self.get_src_stats(gc.config) - self.create_generators(gc.config) - self.remove_data(gc.config) - self.populate_measurement_type_vocab() - self.create_data(gc.config, num_passes=generate_count) - with self.engine.connect() as conn: - stmt = select(self.metadata.tables[table_name]) - rows = conn.execute(stmt).fetchall() - # we should only have one or two of "ham", "eggs" and "cheese" represented - foods = {row.third_value for row in rows} - self.assert_subset(foods, {"ham", "eggs", "cheese"}) - self.assertLess(len(foods), 3) - - -class NonInteractiveTests(RequiresDBTestCase): - """ - Test the --spec SPEC_FILE option of configure-generators - """ - - dump_file_path = "eav.sql" - database_name = "eav" - schema_name = "public" - - @patch("datafaker.interactive.Path") - @patch( - "datafaker.interactive.csv.reader", - return_value=iter( - [ - ["observation", "type", "dist_gen.weighted_choice [sampled]"], - [ - "observation", - "first_value", - "dist_gen.weighted_choice", - "dist_gen.constant", - ], - [ - "observation", - "second_value", - "dist_gen.weighted_choice", - "dist_gen.weighted_choice [sampled]", - "dist_gen.constant", - ], - ["observation", "third_value", "dist_gen.weighted_choice"], - ] - ), - ) - def test_non_interactive_configure_generators( - self, _mock_csv_reader: MagicMock, _mock_path: MagicMock - ) -> None: - """ - test that we can set generators from a CSV file - """ - config: MutableMapping[str, Any] = {} - spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv - ) - row_gens = { - f"{table}{sorted(rg['columns_assigned'])}": rg["name"] - for table, tables in config.get("tables", {}).items() - for rg in tables.get("row_generators", []) - } - self.assertEqual(row_gens["observation['type']"], "dist_gen.weighted_choice") - 
self.assertEqual( - row_gens["observation['first_value']"], "dist_gen.weighted_choice" - ) - self.assertEqual(row_gens["observation['second_value']"], "dist_gen.constant") - self.assertEqual( - row_gens["observation['third_value']"], "dist_gen.weighted_choice" - ) - - @patch("datafaker.interactive.Path") - @patch( - "datafaker.interactive.csv.reader", - return_value=iter( - [ - [ - "observation", - "type first_value second_value third_value", - "null-partitioned grouped_multivariate_lognormal", - ], - ] - ), - ) - def test_non_interactive_configure_null_partitioned( - self, mock_csv_reader: MagicMock, mock_path: MagicMock - ): - """ - test that we can set multi-column generators from a CSV file - """ - config = {} - spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv - ) - row_gens = { - f"{table}{sorted(rg['columns_assigned'])}": rg - for table, tables in config.get("tables", {}).items() - for rg in tables.get("row_generators", []) - } - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["name"], - "dist_gen.alternatives", - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["name"], - '"with_constants_at"', - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], - '"grouped_multivariate_lognormal"', - ) - - @patch("datafaker.interactive.Path") - @patch( - "datafaker.interactive.csv.reader", - return_value=iter( - [ - [ - "observation", - "type first_value second_value third_value", - "null-partitioned grouped_multivariate_lognormal", - ], - ] - ), - ) - def test_non_interactive_configure_null_partitioned_where_existing_merges( - self, _mock_csv_reader: MagicMock, _mock_path: MagicMock - ) -> None: - """ - test that we can set multi-column generators from a CSV file, - but where there are already multi-column generators configured - that will have to be unmerged. 
- """ - config = { - "tables": { - "observation": { - "row_generators": [ - { - "name": "arbitrary_gen", - "columns_assigned": [ - "type", - "second_value", - "first_value", - ], - } - ], - }, - }, - } - spec_csv = Mock(return_value="mock spec.csv file") - update_config_generators( - self.dsn, self.schema_name, self.metadata, config, spec_csv - ) - row_gens = { - f"{table}{sorted(rg['columns_assigned'])}": rg - for table, tables in config.get("tables", {}).items() - for rg in tables.get("row_generators", []) - } - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["name"], - "dist_gen.alternatives", - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["name"], - '"with_constants_at"', - ) - self.assertEqual( - row_gens[ - "observation['first_value', 'second_value', 'third_value', 'type']" - ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], - '"grouped_multivariate_lognormal"', - ) diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py new file mode 100644 index 0000000..3d5c312 --- /dev/null +++ b/tests/test_interactive_generators_partitioned.py @@ -0,0 +1,419 @@ +"""Tests for null-partitioned generators.""" +from collections.abc import MutableMapping +from dataclasses import dataclass +from typing import Any +from unittest import TestCase + +from sqlalchemy import Connection, MetaData, insert, select + +from datafaker.generators import NullPartitionedNormalGeneratorFactory +from tests.test_interactive_generators import TestGeneratorCmd +from tests.utils import GeneratesDBTestCase + + +@dataclass +class Stat: + """Mean and variance calculator.""" + + n: int = 0 + x: float = 0 + x2: float = 0 + + def add(self, x: float) -> None: + """Add one datum.""" + self.n += 1 + self.x += x + self.x2 += x * x + + def count(self) -> int: + """Get the number of data added.""" + return self.n + + def x_mean(self) -> float: + """Get the mean of the added data.""" + return self.x / self.n + + def x_var(self) -> float: + """Get the variance of the added data.""" + x = self.x + return (self.x2 - x * x / self.n) / (self.n - 1) + + +@dataclass +class Correlation(Stat): + """Mean, variance and covariance.""" + + y: float = 0 + y2: float = 0 + xy: float = 0 + + def add2(self, x: float, y: float) -> None: + """Add a 2D data point.""" + self.n += 1 + self.x += x + self.x2 += x * x + self.y += y + self.y2 += y * y + self.xy += x * y + + def y_mean(self) -> float: + """Get the mean of the second parts of the added points.""" + return self.y / self.n + + def y_var(self) -> float: + """Get the variance of the second parts of the added points.""" + y = self.y + return (self.y2 - y * y / self.n) / (self.n - 1) + + def covar(self) -> float: + """Get the covariance of the two parts of the added points.""" + return (self.xy - self.x * self.y / self.n) / (self.n - 1) + + +# pylint: disable=too-many-instance-attributes +class EavMeasurementTableStats: + """The statistics for the Measurement table of eav.sql.""" + + def __init__(self, conn: Connection, metadata: MetaData, test: TestCase) -> None: + stmt = select(metadata.tables["measurement"]) + rows = conn.execute(stmt).fetchall() + self.types: set[int] = set() + self.one_count = 0 + self.one_yes_count = 0 + self.two = Correlation() + self.three = Correlation() + self.four = Correlation() + self.fish = Stat() + self.fowl = Stat() + for row in rows: + 
self.types.add(row.type) + if row.type == 1: + # yes or no + test.assertIsNone(row.first_value) + test.assertIsNone(row.second_value) + test.assertIn(row.third_value, {"yes", "no"}) + self.one_count += 1 + if row.third_value == "yes": + self.one_yes_count += 1 + elif row.type == 2: + # positive correlation around 1.4, 1.8 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.two.add2(row.first_value, row.second_value) + elif row.type == 3: + # negative correlation around 11.8, 12.1 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.three.add2(row.first_value, row.second_value) + elif row.type == 4: + # positive correlation around 21.4, 23.4 + test.assertIsNotNone(row.first_value) + test.assertIsNotNone(row.second_value) + test.assertIsNone(row.third_value) + self.four.add2(row.first_value, row.second_value) + elif row.type == 5: + test.assertIn(row.third_value, {"fish", "fowl"}) + test.assertIsNotNone(row.first_value) + test.assertIsNone(row.second_value) + if row.third_value == "fish": + # mean 8.1 and sd 0.755 + self.fish.add(row.first_value) + else: + # mean 11.2 and sd 1.114 + self.fowl.add(row.first_value) + + +class NullPartitionedTests(GeneratesDBTestCase): + """Testing null-partitioned grouped multivariate generation.""" + + dump_file_path = "eav.sql" + database_name = "eav" + schema_name = "public" + + def setUp(self) -> None: + """Set up the test with specific sample and suppress counts.""" + super().setUp() + NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 8 + NullPartitionedNormalGeneratorFactory.SUPPRESS_COUNT = 2 + + def _get_cmd(self, config: MutableMapping[str, Any]) -> TestGeneratorCmd: + """Get the configure-generators object as our command.""" + return TestGeneratorCmd(self.dsn, self.schema_name, self.metadata, config) + + def _propose(self, gc: TestGeneratorCmd) -> dict[str, tuple[int, str, list[str]]]: + gc.reset() + gc.do_propose("") + return gc.get_proposals() + + def test_create_with_null_partitioned_grouped_multivariate(self) -> None: + """Test EAV for all columns.""" + generate_count = 800 + with self._get_cmd({}) as gc: + self.merge_columns( + gc, + "measurement", + [ + "type", + "first_value", + "second_value", + "third_value", + ], + ) + proposals = self._propose(gc) + self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) + dist_to_choose = "null-partitioned grouped_multivariate_normal" + self.assertIn(dist_to_choose, proposals) + prop = proposals[dist_to_choose] + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. 
{dist_to_choose}" + self.assertIn(col_heading, set(gc.columns.keys())) + gc.do_set(str(prop[0])) + gc.reset() + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stats = EavMeasurementTableStats(conn, self.metadata, self) + # type 1 + self.assertAlmostEqual( + stats.one_count, generate_count * 5 / 20, delta=generate_count * 0.4 + ) + # about 40% are yes + self.assertAlmostEqual( + stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 + ) + # type 2 + self.assertAlmostEqual( + stats.two.count(), generate_count * 3 / 20, delta=generate_count * 0.5 + ) + self.assertAlmostEqual(stats.two.x_mean(), 1.4, delta=0.4) + self.assertAlmostEqual(stats.two.x_var(), 0.315, delta=0.18) + self.assertAlmostEqual(stats.two.y_mean(), 1.8, delta=0.8) + self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.06) + self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.07) + # type 3 + self.assertAlmostEqual( + stats.three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.three.covar(), -2.085, delta=1.1) + # type 4 + self.assertAlmostEqual( + stats.four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1) + # type 5/fish + self.assertAlmostEqual( + stats.fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) + self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.6) + # type 5/fowl + self.assertAlmostEqual( + stats.fowl.count(), generate_count * 3 / 20, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) + + def populate_measurement_type_vocab(self) -> None: + """Add a vocab table without messing around with files""" + table = self.metadata.tables["measurement_type"] + with self.sync_engine.connect() as conn: + conn.execute(insert(table).values({"id": 1, "name": "agreement"})) + conn.execute(insert(table).values({"id": 2, "name": "acceleration"})) + conn.execute(insert(table).values({"id": 3, "name": "velocity"})) + conn.execute(insert(table).values({"id": 4, "name": "position"})) + conn.execute(insert(table).values({"id": 5, "name": "matter"})) + conn.commit() + + def merge_columns( + self, gc: TestGeneratorCmd, table: str, columns: list[str] + ) -> None: + """Merge columns in a table""" + gc.do_next(f"{table}.{columns[0]}") + for col in columns[1:]: + gc.do_merge(col) + gc.reset() + + def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> None: + """Test EAV for all columns with sampled and suppressed generation.""" + generate_count = 800 + with self._get_cmd({}) as gc: + self.merge_columns( + gc, + "measurement", + [ + "type", + "first_value", + "second_value", + "third_value", + ], + ) + proposals = self._propose(gc) + self.assert_subset( + { + "null-partitioned grouped_multivariate_lognormal", + "null-partitioned grouped_multivariate_normal", + "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", + }, + set(proposals), + ) + dist_to_choose = ( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]" + ) + self.assertIn(dist_to_choose, proposals) + prop = proposals[dist_to_choose] + 
gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. {dist_to_choose}" + self.assertIn(col_heading, set(gc.columns.keys())) + gc.do_set(str(prop[0])) + self.merge_columns( + gc, + "observation", + [ + "type", + "first_value", + "second_value", + "third_value", + ], + ) + proposals = self._propose(gc) + prop = proposals[dist_to_choose] + gc.do_set(str(prop[0])) + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stats = EavMeasurementTableStats(conn, self.metadata, self) + stmt = select(self.metadata.tables["observation"]) + rows = conn.execute(stmt).fetchall() + firsts = Stat() + for row in rows: + stats.types.add(row.type) + self.assertEqual(row.type, 1) + self.assertIsNotNone(row.first_value) + self.assertIsNone(row.second_value) + self.assertIn(row.third_value, {"ham", "eggs"}) + firsts.add(row.first_value) + self.assertEqual(firsts.count(), 800) + self.assertAlmostEqual(firsts.x_mean(), 1.3, delta=generate_count * 0.3) + self.assert_subset(stats.types, {1, 2, 3, 4, 5}) + self.assertEqual(len(stats.types), 4) + self.assert_subset({1, 5}, stats.types) + # type 1 + self.assertAlmostEqual( + stats.one_count, generate_count * 5 / 11, delta=generate_count * 0.4 + ) + # about 40% are yes + self.assertAlmostEqual( + stats.one_yes_count / stats.one_count, 0.4, delta=generate_count * 0.4 + ) + # type 5/fish + self.assertAlmostEqual( + stats.fish.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fish.x_mean(), 8.1, delta=3.0) + self.assertAlmostEqual(stats.fish.x_var(), 0.855, delta=0.5) + # type 5/fowl + self.assertAlmostEqual( + stats.fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 + ) + self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) + self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) + + def test_create_with_null_partitioned_grouped_sampled_only(self) -> None: + """Test EAV for all columns with sampled generation but no suppression.""" + table_name = "measurement" + table2_name = "observation" + generate_count = 800 + with self._get_cmd({}) as gc: + self.merge_columns( + gc, table_name, ["type", "first_value", "second_value", "third_value"] + ) + proposals = self._propose(gc) + self.assertIn("null-partitioned grouped_multivariate_lognormal", proposals) + self.assertIn("null-partitioned grouped_multivariate_normal", proposals) + self.assertIn( + "null-partitioned grouped_multivariate_lognormal [sampled and suppressed]", + proposals, + ) + self.assertIn( + "null-partitioned grouped_multivariate_normal [sampled and suppressed]", + proposals, + ) + self.assertIn( + "null-partitioned grouped_multivariate_lognormal [sampled]", proposals + ) + dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" + self.assertIn(dist_to_choose, proposals) + prop = proposals[dist_to_choose] + gc.reset() + gc.do_compare(str(prop[0])) + col_heading = f"{prop[0]}. 
{dist_to_choose}" + self.assertIn(col_heading, set(gc.columns.keys())) + gc.do_set(str(prop[0])) + self.merge_columns( + gc, table2_name, ["type", "first_value", "second_value", "third_value"] + ) + proposals = self._propose(gc) + prop = proposals[dist_to_choose] + gc.do_set(str(prop[0])) + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).fetchall() + self.assert_subset({row.type for row in rows}, {1, 2, 3, 4, 5}) + stmt = select(self.metadata.tables[table2_name]) + rows = conn.execute(stmt).fetchall() + self.assertEqual( + {row.third_value for row in rows}, {"ham", "eggs", "cheese"} + ) + + def test_create_with_null_partitioned_grouped_sampled_tiny(self) -> None: + """ + Test EAV for all columns with sampled generation that only gets a tiny sample. + """ + # five will ensure that at least one group will have two elements in it, + # but all three cannot. + NullPartitionedNormalGeneratorFactory.SAMPLE_COUNT = 5 + table_name = "observation" + generate_count = 100 + with self._get_cmd({}) as gc: + dist_to_choose = "null-partitioned grouped_multivariate_normal [sampled]" + self.merge_columns( + gc, table_name, ["type", "first_value", "second_value", "third_value"] + ) + proposals = self._propose(gc) + prop = proposals[dist_to_choose] + gc.do_set(str(prop[0])) + gc.do_quit("") + self.set_configuration(gc.config) + self.get_src_stats(gc.config) + self.create_generators(gc.config) + self.remove_data(gc.config) + self.populate_measurement_type_vocab() + self.create_data(gc.config, num_passes=generate_count) + with self.sync_engine.connect() as conn: + stmt = select(self.metadata.tables[table_name]) + rows = conn.execute(stmt).fetchall() + # we should only have one or two of "ham", "eggs" and "cheese" represented + foods = {row.third_value for row in rows} + self.assert_subset(foods, {"ham", "eggs", "cheese"}) + self.assertLess(len(foods), 3) diff --git a/tests/test_noninteractive_generators.py b/tests/test_noninteractive_generators.py new file mode 100644 index 0000000..9343114 --- /dev/null +++ b/tests/test_noninteractive_generators.py @@ -0,0 +1,179 @@ +""" Tests for the configure-generators command with the --spec option. 
""" + +from collections.abc import Mapping, MutableMapping +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +from datafaker.interactive import update_config_generators +from tests.utils import RequiresDBTestCase + + +class NonInteractiveTests(RequiresDBTestCase): + """ + Test the --spec SPEC_FILE option of configure-generators + """ + + dump_file_path = "eav.sql" + database_name = "eav" + schema_name = "public" + + @patch("datafaker.interactive.Path") + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + ["observation", "type", "dist_gen.weighted_choice [sampled]"], + [ + "observation", + "first_value", + "dist_gen.weighted_choice", + "dist_gen.constant", + ], + [ + "observation", + "second_value", + "dist_gen.weighted_choice", + "dist_gen.weighted_choice [sampled]", + "dist_gen.constant", + ], + ["observation", "third_value", "dist_gen.weighted_choice"], + ] + ), + ) + def test_non_interactive_configure_generators( + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock + ) -> None: + """ + test that we can set generators from a CSV file + """ + config: MutableMapping[str, Any] = {} + spec_csv = Mock(return_value="mock spec.csv file") + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) + row_gens = { + f"{table}{sorted(rg['columns_assigned'])}": rg["name"] + for table, tables in config.get("tables", {}).items() + for rg in tables.get("row_generators", []) + } + self.assertEqual(row_gens["observation['type']"], "dist_gen.weighted_choice") + self.assertEqual( + row_gens["observation['first_value']"], "dist_gen.weighted_choice" + ) + self.assertEqual(row_gens["observation['second_value']"], "dist_gen.constant") + self.assertEqual( + row_gens["observation['third_value']"], "dist_gen.weighted_choice" + ) + + @patch("datafaker.interactive.Path") + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + [ + "observation", + "type first_value second_value third_value", + "null-partitioned grouped_multivariate_lognormal", + ], + ] + ), + ) + def test_non_interactive_configure_null_partitioned( + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock + ) -> None: + """ + test that we can set multi-column generators from a CSV file + """ + config: MutableMapping[str, Any] = {} + spec_csv = Mock(return_value="mock spec.csv file") + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) + row_gens = { + f"{table}{sorted(rg['columns_assigned'])}": rg + for table, tables in config.get("tables", {}).items() + for rg in tables.get("row_generators", []) + } + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["name"], + "dist_gen.alternatives", + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["name"], + '"with_constants_at"', + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], + '"grouped_multivariate_lognormal"', + ) + + @patch("datafaker.interactive.Path") + @patch( + "datafaker.interactive.csv.reader", + return_value=iter( + [ + [ + "observation", + "type first_value second_value third_value", + "null-partitioned grouped_multivariate_lognormal", + ], + ] + ), + ) + def test_non_interactive_configure_null_partitioned_where_existing_merges( + self, _mock_csv_reader: MagicMock, _mock_path: MagicMock 
+ ) -> None: + """ + test that we can set multi-column generators from a CSV file, + but where there are already multi-column generators configured + that will have to be unmerged. + """ + config = { + "tables": { + "observation": { + "row_generators": [ + { + "name": "arbitrary_gen", + "columns_assigned": [ + "type", + "second_value", + "first_value", + ], + } + ], + }, + }, + } + spec_csv = Mock(return_value="mock spec.csv file") + update_config_generators( + self.dsn, self.schema_name, self.metadata, config, spec_csv + ) + row_gens: Mapping[str, Any] = { + f"{table}{sorted(rg['columns_assigned'])}": rg + for table, tables in config.get("tables", {}).items() + for rg in tables.get("row_generators", []) + } + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["name"], + "dist_gen.alternatives", + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["name"], + '"with_constants_at"', + ) + self.assertEqual( + row_gens[ + "observation['first_value', 'second_value', 'third_value', 'type']" + ]["kwargs"]["alternative_configs"][0]["params"]["subgen"], + '"grouped_multivariate_lognormal"', + ) From 05ea3780c67eeefe6a744b1c8127b2087934b3d8 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 15 Oct 2025 18:18:52 +0100 Subject: [PATCH 22/35] Add running tests to pre-commit.yml --- .github/workflows/pre-commit.yml | 4 ++++ CONTRIBUTING.md | 11 ----------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index b07de4d..3895121 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -59,3 +59,7 @@ jobs: shell: bash run: | pre-commit run --all-files + - name: Run tests + shell: bash + run: | + python -m unittest diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8d7b879..234f8fb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -47,17 +47,6 @@ Executing unit tests is straightforward: python -m unittest discover --verbose tests/ ``` -for tests that are currently maintained. - -## Running functional tests - -These tests do not currently work, and will be replaced by unit tests. - -Functional tests require PostgreSQL to be installed. - -..warning:: - Some MacOS systems [do not recognise the 'en_US.utf8' locale](https://apple.stackexchange.com/questions/206495/load-a-locale-from-usr-local-share-locale-in-os-x). As a workaround, replace `en_US.utf8` with `en_US.UTF-8` on every `*.dump` file. 
- ## Building documentation locally ```bash From 91036ceff87bfe5b82f904dfdd0563e27ee8d70e Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 15 Oct 2025 18:25:37 +0100 Subject: [PATCH 23/35] Github actions starting PostgreSQL --- .github/workflows/pre-commit.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 3895121..5773314 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -59,6 +59,10 @@ jobs: shell: bash run: | pre-commit run --all-files + - name: Start PostgreSQL + shell: bash + run: | + sudo systemctl start postgresql.service - name: Run tests shell: bash run: | From 69e0933bc059f586a6650b4a96de778748f9948f Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 16 Oct 2025 18:08:53 +0100 Subject: [PATCH 24/35] Fixed test_unique_constraint_fails --- .pre-commit-config.yaml | 5 +---- .pylintrc | 3 +-- tests/test_functional.py | 9 ++++++++- tests/test_utils.py | 12 +++++++++--- tests/workspace/.gitignore | 1 - tests/workspace/README.md | 3 --- 6 files changed, 19 insertions(+), 14 deletions(-) delete mode 100644 tests/workspace/.gitignore delete mode 100644 tests/workspace/README.md diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04464f9..a99928f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,8 +40,7 @@ repos: language: system types: ['python'] exclude: (?x)( - tests/examples| - tests/workspace + tests/examples ) - id: isort name: isort @@ -50,7 +49,6 @@ repos: types: ['python'] exclude: (?x)( tests/examples| - tests/workspace| examples ) - id: pylint @@ -77,7 +75,6 @@ repos: language: system exclude: (?x)( tests/examples| - tests/workspace| examples ) types: ['python'] diff --git a/.pylintrc b/.pylintrc index 22a92bd..cb276e2 100644 --- a/.pylintrc +++ b/.pylintrc @@ -24,8 +24,7 @@ ignore=CVS # Add files or directories matching the regex patterns to the ignore-list. The # regex matches against paths. -ignore-paths=tests/examples, - tests/workspace +ignore-paths=tests/examples # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. diff --git a/tests/test_functional.py b/tests/test_functional.py index ac7e51a..a7374fb 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2,6 +2,7 @@ import os import shutil from pathlib import Path +import tempfile from typing import Any, Mapping from sqlalchemy import create_engine, inspect @@ -20,7 +21,6 @@ class DBFunctionalTestCase(RequiresDBTestCase): database_name = "src" schema_name = "public" - test_dir = Path("tests/workspace") examples_dir = Path("tests/examples") orm_file_path = Path("orm.yaml") @@ -67,6 +67,7 @@ def setUp(self) -> None: ) # Copy some of the example files over to the workspace. 
+ self.test_dir = Path(tempfile.mkdtemp(prefix="df-")) for file in self.generator_file_paths + (self.config_file_path,): src = self.examples_dir / file dst = self.test_dir / file @@ -501,6 +502,12 @@ def test_unique_constraint_fail(self) -> None: f"--orm-file={self.alt_orm_file_path}", "--force", ) + self.invoke( + "make-vocab", + f"--orm-file={self.alt_orm_file_path}", + f"--config-file={self.config_file_path}", + "--force", + ) self.invoke( "make-stats", f"--stats-file={self.stats_file_path}", diff --git a/tests/test_utils.py b/tests/test_utils.py index 54c167a..808186f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,9 @@ """Tests for the utils module.""" +import importlib import os import sys from pathlib import Path +import tempfile from unittest.mock import MagicMock, call, patch from sqlalchemy import Column, Integer, insert @@ -14,6 +16,7 @@ read_config_file, ) from tests.utils import DatafakerTestCase, RequiresDBTestCase +from . import examples # pylint: disable=invalid-name Base = declarative_base() @@ -60,7 +63,6 @@ class TestDownload(RequiresDBTestCase): dump_file_path = "providers.dump" mytable_file_path = Path("mytable.yaml") - test_dir = Path("tests/workspace") start_dir = os.getcwd() def setUp(self) -> None: @@ -69,6 +71,7 @@ def setUp(self) -> None: metadata.create_all(self.engine) + self.test_dir = Path(tempfile.mkdtemp(prefix="df-")) os.chdir(self.test_dir) self.mytable_file_path.unlink(missing_ok=True) @@ -90,8 +93,11 @@ def test_download_table(self) -> None: ) # The .strip() gets rid of any possible empty lines at the end of the file. - with Path("../examples/expected.yaml").open(encoding="utf-8") as yamlfile: - expected = yamlfile.read().strip() + with importlib.resources.as_file( + importlib.resources.files(examples) / "expected.yaml" + ) as yamlpath: + with yamlpath.open() as yamlfile: + expected = yamlfile.read().strip() with self.mytable_file_path.open(encoding="utf-8") as yamlfile: actual = yamlfile.read().strip() diff --git a/tests/workspace/.gitignore b/tests/workspace/.gitignore deleted file mode 100644 index 72e8ffc..0000000 --- a/tests/workspace/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* diff --git a/tests/workspace/README.md b/tests/workspace/README.md deleted file mode 100644 index 8165a69..0000000 --- a/tests/workspace/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Test Workspace - -A workspace for the functional tests to run in. 
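The recurring pattern in this commit and the cleanups that follow it is worth spelling out: each test gets its own throwaway working directory instead of the shared `tests/workspace`, and fixture files are located through `importlib.resources` rather than paths relative to the current directory. Below is a minimal sketch of the combined idea — `TempWorkspaceTestCase`, `read_expected`, and the `tearDown` cleanup are illustrative assumptions, not part of the patches, which only show the `setUp` side:

```python
import os
import shutil
import tempfile
import unittest
from importlib import resources
from pathlib import Path


class TempWorkspaceTestCase(unittest.TestCase):
    """Hypothetical base class: each test runs in a private scratch directory."""

    def setUp(self) -> None:
        self.start_dir = os.getcwd()
        # A fresh directory per test avoids the cross-test contamination that
        # a shared, git-ignored workspace directory invites.
        self.test_dir = Path(tempfile.mkdtemp(prefix="df-"))
        os.chdir(self.test_dir)

    def tearDown(self) -> None:
        os.chdir(self.start_dir)
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def read_expected(self, name: str) -> str:
        """Read a fixture from the tests package, wherever it is installed."""
        # resources.files() resolves against the package, not the cwd, which
        # no longer points into the repository tree after the chdir above.
        with resources.as_file(resources.files("tests") / "examples" / name) as path:
            return path.read_text(encoding="utf-8").strip()
```

Restoring the working directory before removing the tree matters: calling `shutil.rmtree` on the process's own cwd would leave later tests stranded in a deleted directory.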
From 820700e637dbcca58ae4553ae1cd7fe100524ffc Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 16 Oct 2025 18:18:41 +0100 Subject: [PATCH 25/35] Cleaned up --- tests/test_functional.py | 2 +- tests/test_utils.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index a7374fb..dc7ef48 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,8 +1,8 @@ """Tests for the CLI.""" import os import shutil -from pathlib import Path import tempfile +from pathlib import Path from typing import Any, Mapping from sqlalchemy import create_engine, inspect diff --git a/tests/test_utils.py b/tests/test_utils.py index 808186f..aab9ee0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,9 @@ """Tests for the utils module.""" -import importlib import os import sys -from pathlib import Path import tempfile +from importlib import resources +from pathlib import Path from unittest.mock import MagicMock, call, patch from sqlalchemy import Column, Integer, insert @@ -16,6 +16,7 @@ read_config_file, ) from tests.utils import DatafakerTestCase, RequiresDBTestCase + from . import examples # pylint: disable=invalid-name @@ -93,10 +94,8 @@ def test_download_table(self) -> None: ) # The .strip() gets rid of any possible empty lines at the end of the file. - with importlib.resources.as_file( - importlib.resources.files(examples) / "expected.yaml" - ) as yamlpath: - with yamlpath.open() as yamlfile: + with resources.as_file(resources.files(examples) / "expected.yaml") as yamlpath: + with yamlpath.open(encoding="utf-8") as yamlfile: expected = yamlfile.read().strip() with self.mytable_file_path.open(encoding="utf-8") as yamlfile: From 433bb19e073cc4f4863b219de9d0839c6212138b Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 16 Oct 2025 18:46:06 +0100 Subject: [PATCH 26/35] Fixed tests --- tests/test_interactive_generators_partitioned.py | 4 ++-- tests/test_utils.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py index 3d5c312..f544be8 100644 --- a/tests/test_interactive_generators_partitioned.py +++ b/tests/test_interactive_generators_partitioned.py @@ -196,8 +196,8 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None: self.assertAlmostEqual(stats.two.x_mean(), 1.4, delta=0.4) self.assertAlmostEqual(stats.two.x_var(), 0.315, delta=0.18) self.assertAlmostEqual(stats.two.y_mean(), 1.8, delta=0.8) - self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.06) - self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.07) + self.assertAlmostEqual(stats.two.y_var(), 0.105, delta=0.08) + self.assertAlmostEqual(stats.two.covar(), 0.105, delta=0.08) # type 3 self.assertAlmostEqual( stats.three.count(), generate_count * 3 / 20, delta=generate_count * 0.2 diff --git a/tests/test_utils.py b/tests/test_utils.py index aab9ee0..0d2427e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ """Tests for the utils module.""" +import importlib.util import os import sys import tempfile @@ -17,8 +18,6 @@ ) from tests.utils import DatafakerTestCase, RequiresDBTestCase -from . import examples - # pylint: disable=invalid-name Base = declarative_base() # pylint: enable=invalid-name @@ -94,7 +93,10 @@ def test_download_table(self) -> None: ) # The .strip() gets rid of any possible empty lines at the end of the file. 
- with resources.as_file(resources.files(examples) / "expected.yaml") as yamlpath: + tests_module = sys.modules["tests"] + with resources.as_file( + resources.files(tests_module) / "examples" / "expected.yaml" + ) as yamlpath: with yamlpath.open(encoding="utf-8") as yamlfile: expected = yamlfile.read().strip() From d1b07dc7d35cb3012ce36bbb387d1bf645351aee Mon Sep 17 00:00:00 2001 From: Tim Band Date: Thu, 16 Oct 2025 18:52:29 +0100 Subject: [PATCH 27/35] cleaned --- tests/test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0d2427e..ac82d12 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,4 @@ """Tests for the utils module.""" -import importlib.util import os import sys import tempfile From 79990a1c08252b06f19fbdbffa1bf4df5ab71599 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 17 Oct 2025 10:06:37 +0100 Subject: [PATCH 28/35] Move real test runner to tests.yml, overwriting bad test runner --- .github/workflows/pre-commit.yml | 10 +------- .github/workflows/tests.yml | 44 ++++---------------------------- 2 files changed, 6 insertions(+), 48 deletions(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 5773314..ec27ef8 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -11,7 +11,7 @@ env: PRE_COMMIT_HOME: ~/.caches/pre-commit PYTHON_VERSION: "3.10" jobs: - the_job: + clean-check: runs-on: ubuntu-latest steps: - name: Checkout Code @@ -59,11 +59,3 @@ jobs: shell: bash run: | pre-commit run --all-files - - name: Start PostgreSQL - shell: bash - run: | - sudo systemctl start postgresql.service - - name: Run tests - shell: bash - run: | - python -m unittest diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 75f45f0..9d79e81 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,48 +10,14 @@ env: # This should be the default but we'll be explicit PYTHON_VERSION: "3.10" jobs: - the_job: + unit-tests: runs-on: ubuntu-latest - services: - postgres: - image: postgres:15 - env: - POSTGRES_PASSWORD: password - ports: - - 5432:5432 steps: - - name: Checkout Code - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ env.PYTHON_VERSION }} - - name: Bootstrap poetry + - name: Start PostgreSQL shell: bash run: | - python -m ensurepip - python -m pip install --upgrade pip - python -m pip install poetry - - name: Configure poetry + sudo systemctl start postgresql.service + - name: Run tests shell: bash run: | - python -m poetry config virtualenvs.in-project true - # - name: Cache Poetry dependencies - # uses: actions/cache@v3 - # id: poetry-cache - # with: - # path: .venv - # key: venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }} - - name: Install dependencies - shell: bash - if: steps.poetry-cache.outputs.cache-hit != 'true' - run: | - python -m poetry install --all-extras - - name: Create src database - shell: bash - run: | - PGPASSWORD=password psql --host=localhost --username=postgres --set="ON_ERROR_STOP=1" --file=tests/examples/src.dump - - name: Run Unit Tests - shell: bash - run: | - REQUIRES_DB=1 poetry run python -m unittest discover --verbose tests + python -m unittest From 9d0ec4775d60154de9ecd036ea3ae314d75068f5 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 17 Oct 2025 10:10:57 +0100 Subject: [PATCH 29/35] Added poetry initialisation to test runner --- .github/workflows/tests.yml | 20 
++++++++++++++++++++ tests/test_functional.py | 3 --- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9d79e81..8a9b310 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,6 +17,26 @@ jobs: shell: bash run: | sudo systemctl start postgresql.service + - name: Bootstrap poetry + shell: bash + run: | + python -m ensurepip + python -m pip install --upgrade pip + python -m pip install poetry + - name: Configure poetry + shell: bash + run: | + python -m poetry config virtualenvs.in-project true + # - name: Cache Poetry dependencies uses: actions/cache@v3 + # id: poetry-cache + # with: + # path: .venv + # key: venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }} + - name: Install dependencies + shell: bash + if: steps.poetry-cache.outputs.cache-hit != 'true' + run: | + python -m poetry install --all-extras - name: Run tests shell: bash run: | diff --git a/tests/test_functional.py b/tests/test_functional.py index dc7ef48..bfb2f09 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -29,9 +29,6 @@ class DBFunctionalTestCase(RequiresDBTestCase): alt_orm_file_path = Path("my_orm.yaml") alt_datafaker_file_path = Path("my_df.py") - vocabulary_file_paths = tuple( - map(Path, ("concept.yaml", "concept_type.yaml", "mitigation_type.yaml")), - ) generator_file_paths = tuple( map(Path, ("story_generators.py", "row_generators.py")), ) From 83438bb56a276a91578adf8e49f907086e8e6684 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 17 Oct 2025 10:47:43 +0100 Subject: [PATCH 30/35] More test fixes --- .github/workflows/tests.yml | 13 +++---------- tests/test_interactive_generators_partitioned.py | 6 +++--- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8a9b310..964d943 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,21 +17,14 @@ jobs: shell: bash run: | sudo systemctl start postgresql.service - - name: Bootstrap poetry + - name: Install poetry shell: bash run: | - python -m ensurepip - python -m pip install --upgrade pip - python -m pip install poetry + sudo apt install python3-poetry - name: Configure poetry shell: bash run: | python -m poetry config virtualenvs.in-project true - # - name: Cache Poetry dependencies uses: actions/cache@v3 - # id: poetry-cache - # with: - # path: .venv - # key: venv-${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }} - name: Install dependencies shell: bash if: steps.poetry-cache.outputs.cache-hit != 'true' @@ -40,4 +33,4 @@ jobs: - name: Run tests shell: bash run: | - python -m unittest + poetry run python -m unittest diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py index f544be8..3b21b53 100644 --- a/tests/test_interactive_generators_partitioned.py +++ b/tests/test_interactive_generators_partitioned.py @@ -207,7 +207,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None: self.assertAlmostEqual( stats.four.count(), generate_count * 3 / 20, delta=generate_count * 0.2 ) - self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1) + self.assertAlmostEqual(stats.four.covar(), 3.33, delta=1.3) # type 5/fish self.assertAlmostEqual( stats.fish.count(), generate_count * 3 / 20, delta=generate_count * 0.2 @@ -219,7 +219,7 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None: stats.fowl.count(), 
generate_count * 3 / 20, delta=generate_count * 0.2 ) self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) + self.assertAlmostEqual(stats.fowl.x_var(), 1.24, delta=0.6) def populate_measurement_type_vocab(self) -> None: """Add a vocab table without messing around with files""" @@ -330,7 +330,7 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> No stats.fowl.count(), generate_count * 3 / 11, delta=generate_count * 0.2 ) self.assertAlmostEqual(stats.fowl.x_mean(), 11.2, delta=8.0) - self.assertAlmostEqual(stats.fowl.x_var(), 1.86, delta=1) + self.assertAlmostEqual(stats.fowl.x_var(), 1.24, delta=0.6) def test_create_with_null_partitioned_grouped_sampled_only(self) -> None: """Test EAV for all columns with sampled generation but no suppression.""" From 2306b12dd0490a0231b35ff9f4a9b318ad3e64aa Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 17 Oct 2025 10:55:54 +0100 Subject: [PATCH 31/35] Another attempt to get tests.yml working --- .github/workflows/tests.yml | 2 ++ .pre-commit-config.yaml | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 964d943..389e073 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,6 +13,8 @@ jobs: unit-tests: runs-on: ubuntu-latest steps: + - name: Checkout Code + uses: actions/checkout@v3 - name: Start PostgreSQL shell: bash run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a99928f..085ce54 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude: docs/(source|build/html)/_static/ @@ -13,12 +13,12 @@ repos: - id: check-added-large-files - repo: https://github.com/markdownlint/markdownlint # Note the "v" - rev: v0.11.0 + rev: v0.12.0 hooks: - id: markdownlint args: [--style=mdl_style.rb] - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.8.0.4 + rev: v0.11.0.1 hooks: - id: shellcheck - repo: local From 079598f875beafe02528d4ecbfc4c017fb020f30 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 22 Oct 2025 15:06:00 +0100 Subject: [PATCH 32/35] Fixed choice variant proposals --- datafaker/generators/choice.py | 39 ++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/datafaker/generators/choice.py b/datafaker/generators/choice.py index 54f69d3..5579d88 100644 --- a/datafaker/generators/choice.py +++ b/datafaker/generators/choice.py @@ -325,7 +325,31 @@ def get_generators( table_name, column_name, vg.cvs, vg.counts ), ] - results = connection.execute( + if vg.counts_not_suppressed: + generators += [ + ZipfChoiceGenerator( + table_name, + column_name, + vg.values_not_suppressed, + vg.counts_not_suppressed, + suppress_count=self.SUPPRESS_COUNT, + ), + UniformChoiceGenerator( + table_name, + column_name, + vg.values_not_suppressed, + vg.counts_not_suppressed, + suppress_count=self.SUPPRESS_COUNT, + ), + WeightedChoiceGenerator( + table_name=table_name, + column_name=column_name, + values=vg.cvs_not_suppressed, + counts=vg.counts_not_suppressed, + suppress_count=self.SUPPRESS_COUNT, + ), + ] + sampled_results = connection.execute( text( f"SELECT v, COUNT(v) AS f FROM" f" (SELECT {column_name} as v FROM {table_name}" @@ -333,20 +357,9 @@ def get_generators( f" AS _inner 
GROUP BY v ORDER BY f DESC" ) ) - if results is not None: + if sampled_results is not None: vg = ValueGatherer(results, self.SUPPRESS_COUNT) if vg.counts: - generators += [ - ZipfChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - UniformChoiceGenerator( - table_name, column_name, vg.values, vg.counts - ), - WeightedChoiceGenerator( - table_name, column_name, vg.cvs, vg.counts - ), - ] generators += [ ZipfChoiceGenerator( table_name, From 89a5d57f06e209eddb40c46d4bed7e4547ffd616 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 22 Oct 2025 18:42:24 +0100 Subject: [PATCH 33/35] MetaData in df.py fixed Allow remove-tables --all Fixed sampled choices Removed surplus df.py stuff --- datafaker/base.py | 4 ++-- datafaker/create.py | 16 +++++++++++----- datafaker/generators/choice.py | 8 ++++---- datafaker/main.py | 17 +++++++++++++---- datafaker/remove.py | 5 ++++- datafaker/templates/df.py.j2 | 16 +--------------- tests/test_create.py | 8 ++++++-- tests/test_interactive_generators.py | 2 ++ tests/test_main.py | 4 ++++ tests/utils.py | 1 + 10 files changed, 48 insertions(+), 33 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index f75591c..9d30422 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -15,7 +15,7 @@ import yaml from sqlalchemy import Connection, insert from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.schema import Table +from sqlalchemy.schema import MetaData, Table from datafaker.utils import ( MAKE_VOCAB_PROGRESS_REPORT_EVERY, @@ -389,7 +389,7 @@ class TableGenerator(ABC): num_rows_per_pass: int = 1 @abstractmethod - def __call__(self, dst_db_conn: Connection) -> dict[str, Any]: + def __call__(self, dst_db_conn: Connection, metadata: MetaData) -> dict[str, Any]: """Return, as a dictionary, a new row for the table that we are generating. The only argument, `dst_db_conn`, should be a database connection to the diff --git a/datafaker/create.py b/datafaker/create.py index a877320..5cddeb2 100644 --- a/datafaker/create.py +++ b/datafaker/create.py @@ -100,6 +100,7 @@ def create_db_data( sorted_tables: Sequence[Table], df_module: ModuleType, num_passes: int, + metadata: MetaData, ) -> RowCounts: """Connect to a database and populate it with data.""" settings = get_settings() @@ -112,15 +113,18 @@ def create_db_data( num_passes, dst_dsn, settings.dst_schema, + metadata, ) +# pylint: disable=too-many-arguments too-many-positional-arguments def create_db_data_into( sorted_tables: Sequence[Table], df_module: ModuleType, num_passes: int, db_dsn: str, schema_name: str | None, + metadata: MetaData, ) -> RowCounts: """ Populate the database. @@ -145,6 +149,7 @@ def create_db_data_into( sorted_tables, df_module.table_generator_dict, df_module.story_generator_list, + metadata, ) return row_counts @@ -195,7 +200,7 @@ def table_name(self) -> str | None: """ return self._table_name - def insert(self) -> None: + def insert(self, metadata: MetaData) -> None: """ Put the row in the table. 
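A minimal sketch of a custom table generator under the new two-argument signature that this patch introduces; the class name, the table name and the column handling are illustrative only, not taken from the repository:

    from typing import Any

    from sqlalchemy import Connection
    from sqlalchemy.schema import MetaData

    from datafaker.base import TableGenerator

    class PetGenerator(TableGenerator):  # hypothetical generator
        num_rows_per_pass = 1

        def __call__(self, dst_db_conn: Connection, metadata: MetaData) -> dict[str, Any]:
            # metadata is now passed in by populate() and the story
            # iterator's insert() rather than loaded at import time by the
            # generated df.py module.
            pet = metadata.tables["pet"]  # hypothetical table name
            # A real generator would compute a value per column; this just
            # shows that the schema is available for inspection.
            return {column.name: None for column in pet.columns}
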
@@ -207,7 +212,7 @@ def insert(self) -> None: table = self._table_dict[self._table_name] if table.name in self._table_generator_dict: table_generator = self._table_generator_dict[table.name] - default_values = table_generator(self._dst_conn) + default_values = table_generator(self._dst_conn, metadata) else: default_values = {} insert_values = {**default_values, **self._provided_values} @@ -253,6 +258,7 @@ def populate( tables: Sequence[Table], table_generator_dict: Mapping[str, TableGenerator], story_generator_list: Sequence[Mapping[str, Any]], + metadata: MetaData, ) -> RowCounts: """Populate a database schema with synthetic data.""" row_counts: Counter[str] = Counter() @@ -277,7 +283,7 @@ def populate( for table in tables: # Do we have a story row to enter into this table? if story_iterator.has_table(table.name): - story_iterator.insert() + story_iterator.insert(metadata) row_counts[table.name] = row_counts.get(table.name, 0) + 1 story_iterator.next() if table.name not in table_generator_dict: @@ -291,7 +297,7 @@ def populate( try: with dst_conn.begin(): for _ in range(table_generator.num_rows_per_pass): - stmt = insert(table).values(table_generator(dst_conn)) + stmt = insert(table).values(table_generator(dst_conn, metadata)) dst_conn.execute(stmt) row_counts[table.name] = row_counts.get(table.name, 0) + 1 dst_conn.commit() @@ -301,7 +307,7 @@ def populate( # Insert any remaining stories while not story_iterator.is_ended(): - story_iterator.insert() + story_iterator.insert(metadata) t = story_iterator.table_name() if t is None: raise AssertionError( diff --git a/datafaker/generators/choice.py b/datafaker/generators/choice.py index 5579d88..aa1d838 100644 --- a/datafaker/generators/choice.py +++ b/datafaker/generators/choice.py @@ -306,8 +306,8 @@ def get_generators( with engine.connect() as connection: results = connection.execute( text( - f"SELECT {column_name} AS v, COUNT({column_name})" - f" AS f FROM {table_name} GROUP BY v" + f'SELECT "{column_name}" AS v, COUNT("{column_name}")' + f' AS f FROM "{table_name}" GROUP BY v' f" ORDER BY f DESC LIMIT {MAXIMUM_CHOICES + 1}" ) ) @@ -352,13 +352,13 @@ def get_generators( sampled_results = connection.execute( text( f"SELECT v, COUNT(v) AS f FROM" - f" (SELECT {column_name} as v FROM {table_name}" + f' (SELECT "{column_name}" as v FROM "{table_name}"' f" ORDER BY RANDOM() LIMIT {self.SAMPLE_COUNT})" f" AS _inner GROUP BY v ORDER BY f DESC" ) ) if sampled_results is not None: - vg = ValueGatherer(results, self.SUPPRESS_COUNT) + vg = ValueGatherer(sampled_results, self.SUPPRESS_COUNT) if vg.counts: generators += [ ZipfChoiceGenerator( diff --git a/datafaker/main.py b/datafaker/main.py index 22cf0ef..5a89d5e 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -158,6 +158,7 @@ def create_data( sorted_non_vocabulary_tables(orm_metadata, config), df_module, num_passes, + orm_metadata, ) logger.debug( "Data created in %s %s.", @@ -543,7 +544,12 @@ def remove_vocab( @app.command() def remove_tables( orm_file: str = Option(ORM_FILENAME, help="The name of the ORM yaml file"), - config_file: Optional[str] = Option(CONFIG_FILENAME, help="The configuration file"), + config_file: str = Option(CONFIG_FILENAME, help="The configuration file"), + # pylint: disable=redefined-builtin + all: bool = Option( + False, + help="Don't use the ORM file, delete all tables in the destination schema", + ), yes: bool = Option( False, "--yes", prompt="Are you sure?", help="Just remove, don't ask first" ), @@ -554,9 +560,12 @@ def remove_tables( """ if yes: 
logger.debug("Dropping tables.") - config = read_config_file(config_file) if config_file is not None else {} - metadata = load_metadata_for_output(orm_file, config) - remove_db_tables(metadata) + if all: + remove_db_tables(None) + else: + config = read_config_file(config_file) + metadata = load_metadata_for_output(orm_file, config) + remove_db_tables(metadata) logger.debug("Tables dropped.") else: logger.info("Would remove tables if called with --yes.") diff --git a/datafaker/remove.py b/datafaker/remove.py index a316619..3924cda 100644 --- a/datafaker/remove.py +++ b/datafaker/remove.py @@ -56,11 +56,14 @@ def remove_db_vocab( reinstate_vocab_foreign_key_constraints(metadata, meta_dict, config, dst_conn) -def remove_db_tables(metadata: MetaData) -> None: +def remove_db_tables(metadata: MetaData | None) -> None: """Drop the tables in the destination schema.""" settings = get_settings() assert settings.dst_dsn, "Missing destination database settings" dst_engine = get_sync_engine( create_db_engine(settings.dst_dsn, schema_name=settings.dst_schema) ) + if metadata is None: + metadata = MetaData() + metadata.reflect(dst_engine) metadata.drop_all(dst_engine) diff --git a/datafaker/templates/df.py.j2 b/datafaker/templates/df.py.j2 index 87c84e4..3a2da41 100644 --- a/datafaker/templates/df.py.j2 +++ b/datafaker/templates/df.py.j2 @@ -4,7 +4,6 @@ from mimesis.locales import Locale import sqlalchemy import sys from datafaker.base import FileUploader, TableGenerator, DistributionGenerator, ColumnPresence -from datafaker.main import load_metadata generic = Generic(locale=Locale.EN_GB) numeric = Numeric() @@ -23,8 +22,6 @@ from datafaker.providers import ( generic.add_provider({{ provider_import }}) {% endfor %} -metadata = load_metadata("{{ orm_file_name }}", "{{ config_file_name }}") - {% if row_generator_module_name is not none %} import {{ row_generator_module_name }} {% endif %} @@ -44,10 +41,6 @@ with open("{{ src_stats_filename }}", "r", encoding="utf-8") as f: SRC_STATS = yaml.unsafe_load(f) {% endif %} -{% for table_data in vocabulary_tables %} -{{ table_data.variable_name }} = FileUploader(metadata.tables["{{ table_data.table_name }}"]) -{% endfor %} - {% for table_data in tables %} class {{ table_data.class_name }}(TableGenerator): num_rows_per_pass = {{ table_data.rows_per_pass }} @@ -55,7 +48,7 @@ class {{ table_data.class_name }}(TableGenerator): def __init__(self): self.initialized = False - def __call__(self, dst_db_conn): + def __call__(self, dst_db_conn, metadata): if not self.initialized: {% for constraint in table_data.unique_constraints %} query_text = f"SELECT {% @@ -123,13 +116,6 @@ table_generator_dict = { {% endfor %} } - -vocab_dict = { -{% for table_data in vocabulary_tables %} - "{{ table_data.dictionary_entry }}": {{ table_data.variable_name }}, -{% endfor %} -} - {% for gen_data in story_generators %} def {{ gen_data.wrapper_name }}(dst_db_conn): return {{ gen_data.function_call.function_name }}({{ gen_data.function_call.argument_values| join(", ") }}) diff --git a/tests/test_create.py b/tests/test_create.py index b175f07..933fbe2 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -8,7 +8,7 @@ from unittest.mock import MagicMock, call, patch from sqlalchemy import Connection, select -from sqlalchemy.schema import Table +from sqlalchemy.schema import MetaData, Table from datafaker.base import TableGenerator from datafaker.create import create_db_vocab, populate @@ -90,6 +90,7 @@ def test_make_table_generators(self) -> None: class 
TestPopulate(DatafakerTestCase): """Test create.populate.""" + # pylint: disable=too-many-locals def test_populate(self) -> None: """Test the populate function.""" table_name = "table_name" @@ -111,6 +112,7 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]: mock_dst_conn.execute.return_value.returned_defaults = {} mock_table = MagicMock(spec=Table) mock_table.name = table_name + mock_metadata = MagicMock(spec=MetaData) mock_gen = MagicMock(spec=TableGenerator) mock_gen.num_rows_per_pass = num_rows_per_pass mock_gen.return_value = {} @@ -134,6 +136,7 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]: [mock_table], {table_name: mock_gen}, story_generators, + mock_metadata, ) expected_row_count = ( @@ -165,13 +168,14 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None: mock_table_two.name = "two" mock_table_three = MagicMock(spec=Table) mock_table_three.name = "three" + mock_metadata = MagicMock(spec=MetaData) tables: list[Table] = [mock_table_one, mock_table_two, mock_table_three] row_generators: dict[str, TableGenerator] = { "two": mock_gen_two, "three": mock_gen_three, } - row_counts = populate(mock_dst_conn, tables, row_generators, []) + row_counts = populate(mock_dst_conn, tables, row_generators, [], mock_metadata) self.assertEqual(row_counts, {"two": 1, "three": 1}) self.assertListEqual( [call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list diff --git a/tests/test_interactive_generators.py b/tests/test_interactive_generators.py index a7eb757..104f0d3 100644 --- a/tests/test_interactive_generators.py +++ b/tests/test_interactive_generators.py @@ -670,6 +670,7 @@ def test_create_with_weighted_choice(self) -> None: { "dist_gen.weighted_choice", "dist_gen.weighted_choice [sampled]", + "dist_gen.weighted_choice [suppressed]", "dist_gen.weighted_choice [sampled and suppressed]", }, set(proposals), @@ -691,6 +692,7 @@ def test_create_with_weighted_choice(self) -> None: { "dist_gen.weighted_choice", "dist_gen.weighted_choice [sampled]", + "dist_gen.weighted_choice [suppressed]", "dist_gen.weighted_choice [sampled and suppressed]", }, set(proposals), diff --git a/tests/test_main.py b/tests/test_main.py index fc388fb..2167207 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -216,8 +216,11 @@ def test_create_tables( @patch("datafaker.main.logger") @patch("datafaker.main.import_file") @patch("datafaker.main.create_db_data") + @patch("datafaker.main.load_metadata_for_output") + # pylint: disable=too-many-arguments too-many-positional-arguments def test_create_data( self, + mock_load_metadata: MagicMock, mock_create: MagicMock, mock_import: MagicMock, mock_logger: MagicMock, @@ -241,6 +244,7 @@ def test_create_data( mock_tables.return_value, mock_import.return_value, 1, + mock_load_metadata.return_value, ) self.assertSuccess(result) diff --git a/tests/utils.py b/tests/utils.py index ab6f1d2..2810c41 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -272,6 +272,7 @@ def create_data(self, config: Mapping[str, Any], num_passes: int = 1) -> None: num_passes, self.dsn, self.schema_name, + self.metadata, ) def generate_data( From b9b23537b90f0100a1718d06981ca299e08f4800 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Fri, 24 Oct 2025 18:55:11 +0100 Subject: [PATCH 34/35] Moved DistributionGenerator to providers.py --- datafaker/base.py | 365 +------------------------------ datafaker/generators/base.py | 4 +- datafaker/generators/mimesis.py | 4 +- datafaker/providers.py | 372 +++++++++++++++++++++++++++++++
datafaker/templates/df.py.j2 | 5 +- 5 files changed, 380 insertions(+), 370 deletions(-) diff --git a/datafaker/base.py b/datafaker/base.py index 9d30422..495ff62 100644 --- a/datafaker/base.py +++ b/datafaker/base.py @@ -1,17 +1,14 @@ """Base table generator classes.""" -import functools import gzip -import math import os import random from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping +from collections.abc import Callable from dataclasses import dataclass from io import TextIOWrapper from pathlib import Path -from typing import Any, Generator +from typing import Any -import numpy as np import yaml from sqlalchemy import Connection, insert from sqlalchemy.exc import SQLAlchemyError @@ -19,370 +16,12 @@ from datafaker.utils import ( MAKE_VOCAB_PROGRESS_REPORT_EVERY, - T, logger, stream_yaml, table_row_count, ) -class InappropriateGeneratorException(Exception): - """Exception thrown if a generator is requested that is not appropriate.""" - - -@functools.cache -def zipf_weights(size: int) -> list[float]: - """Get the weights of a Zipf distribution of a given size.""" - total = sum(map(lambda n: 1 / n, range(1, size + 1))) - return [1 / (n * total) for n in range(1, size + 1)] - - -def merge_with_constants( - xs: list[T], constants_at: dict[int, T] -) -> Generator[T, None, None]: - """ - Merge a list of items with other items that must be placed at certain indices. - - :param constants_at: A map of indices to objects that must be placed at - those indices. - :param xs: Items that fill in the gaps left by ``constants_at``. - :return: ``xs`` with ``constants_at`` inserted at the appropriate - points. If there are not enough elements in ``xs`` to fill in the gaps - in ``constants_at``, the elements of ``constants_at`` after the gap - are dropped. - """ - outi = 0 - xi = 0 - constant_count = len(constants_at) - while constant_count != 0: - if outi in constants_at: - yield constants_at[outi] - constant_count -= 1 - else: - if xi == len(xs): - return - yield xs[xi] - xi += 1 - outi += 1 - yield from xs[xi:] - - -class NothingToGenerateException(Exception): - """Exception thrown when no value can be generated.""" - - def __init__(self, message: str): - """Initialise the exception with a human-readable message.""" - super().__init__(message) - - -class DistributionGenerator: - """An object that can produce values from various distributions.""" - - root3 = math.sqrt(3) - - def __init__(self) -> None: - """Initialise the DistributionGenerator.""" - self.np_gen = np.random.default_rng() - - def uniform(self, low: float, high: float) -> float: - """ - Choose a value according to a uniform distribution. - - :param low: The lowest value that can be chosen. - :param high: The highest value that can be chosen. - :return: The output value. - """ - return random.uniform(float(low), float(high)) - - def uniform_ms(self, mean: float, sd: float) -> float: - """ - Choose a value according to a uniform distribution. - - :param mean: The mean of the output values. - :param sd: The standard deviation of the output values. - :return: The output value. - """ - m = float(mean) - h = self.root3 * float(sd) - return random.uniform(m - h, m + h) - - def normal(self, mean: float, sd: float) -> float: - """ - Choose a value according to a Gaussian (normal) distribution. - - :param mean: The mean of the output values. - :param sd: The standard deviation of the output values. - :return: The output value. 
- """ - return random.normalvariate(float(mean), float(sd)) - - def lognormal(self, logmean: float, logsd: float) -> float: - """ - Choose a value according to a lognormal distribution. - - :param logmean: The mean of the logs of the output values. - :param logsd: The standard deviation of the logs of the output values. - :return: The output value. - """ - return random.lognormvariate(float(logmean), float(logsd)) - - def choice_direct(self, a: list[T]) -> T: - """ - Choose a value with equal probability. - - :param a: The list of values to output. - :return: The chosen value. - """ - return random.choice(a) - - def choice(self, a: list[Mapping[str, T]]) -> T | None: - """ - Choose a value with equal probability. - - :param a: The list of values to output. Each element is a mapping with - a key ``value`` and the key is the value to return. - :return: The chosen value. - """ - return self.choice_direct(a).get("value", None) - - def zipf_choice_direct(self, a: list[T], n: int | None = None) -> T: - """ - Choose a value according to the Zipf distribution. - - The nth value (starting from 1) is chosen with a frequency - 1/n times as frequently as the first value is chosen. - - :param a: The list of values to output, most frequent first. - :return: The chosen value. - """ - if n is None: - n = len(a) - return random.choices(a, weights=zipf_weights(n))[0] - - def zipf_choice(self, a: list[Mapping[str, T]], n: int | None = None) -> T | None: - """ - Choose a value according to the Zipf distribution. - - The nth value (starting from 1) is chosen with a frequency - 1/n times as frequently as the first value is chosen. - - :param a: The list of rows to choose between, most frequent first. - Each element is a mapping with a key ``value`` and the key is the - value to return. - :return: The chosen value. - """ - c = self.zipf_choice_direct(a, n) - return c.get("value", None) - - def weighted_choice(self, a: list[dict[str, Any]]) -> Any: - """ - Choice weighted by the count in the original dataset. - - :param a: a list of dicts, each with a ``value`` key - holding the value to be returned and a ``count`` key holding the - number of that value found in the original dataset - :return: The chosen ``value``. - """ - vs = [] - counts = [] - for vc in a: - count = vc.get("count", 0) - if count: - counts.append(count) - vs.append(vc.get("value", None)) - c = random.choices(vs, weights=counts)[0] - return c - - def constant(self, value: T) -> T: - """Return the same value always.""" - return value - - def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray: - """ - Return an array of values chosen from the given covariates. - - :param cov: Keys are ``rank``: The number of values to output; - ``mN``: The mean of variable ``N`` (where ``N`` is between 0 and - one less than ``rank``). ``cN_M`` (where 0 < ``N`` <= ``M`` < ``rank``): - the covariance between the ``N``th and the ``M``th variables. - :return: A numpy array of results. 
- """ - rank = int(cov["rank"]) - if rank == 0: - return np.empty(shape=(0,)) - mean = [float(cov[f"m{i}"]) for i in range(rank)] - covs = [ - [ - float(cov[f"c{i}_{j}"] if i <= j else cov[f"c{j}_{i}"]) - for i in range(rank) - ] - for j in range(rank) - ] - return self.np_gen.multivariate_normal(mean, covs) - - def _select_group(self, alts: list[dict[str, Any]]) -> Any: - """Choose one of the ``alts`` weighted by their ``"count"`` elements.""" - total = 0 - for alt in alts: - if alt["count"] < 0: - logger.warning( - "Alternative count is %d, but should not be negative", alt["count"] - ) - else: - total += alt["count"] - if total == 0: - raise NothingToGenerateException("No counts in any alternative") - choice = random.randrange(total) - for alt in alts: - choice -= alt["count"] - if choice < 0: - return alt - raise NothingToGenerateException( - "Internal error: ran out of choices in _select_group" - ) - - def _find_constants(self, result: dict[str, Any]) -> dict[int, Any]: - """ - Find all keys ``kN``, returning a dictionary of ``N: kNN``. - - This can be passed into ``merge_with_constants`` as the - ``constants_at`` argument. - """ - out: dict[int, Any] = {} - for k, v in result.items(): - if k.startswith("k") and k[1:].isnumeric(): - out[int(k[1:])] = v - return out - - PERMITTED_SUBGENS = { - "multivariate_lognormal", - "multivariate_normal", - "grouped_multivariate_lognormal", - "grouped_multivariate_normal", - "constant", - "weighted_choice", - "with_constants_at", - } - - def multivariate_normal(self, cov: dict[str, Any]) -> list[float]: - """ - Produce a list of values pulled from a multivariate distribution. - - :param cov: A dict with various keys: ``rank`` is the number of - output values, ``m0``, ``m1``, ... are the means of the - distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ... - are the covariates, ``cN_M`` is the covariate of the ``N``th and - ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. - :return: list of ``rank`` floating point values - """ - out: list[float] = self.multivariate_normal_np(cov).tolist() - return out - - def multivariate_lognormal(self, cov: dict[str, Any]) -> list[float]: - """ - Produce a list of values pulled from a multivariate distribution. - - :param cov: A dict with various keys: ``rank`` is the number of - output values, ``m0``, ``m1``, ... are the means of the - distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ... - are the covariates, ``cN_M`` is the covariate of the ``N``th and - ``M``th varaibles, with 0 <= ``N`` <= ``M`` < ``rank``. These - are all the means and covariants of the logs of the data. 
- :return: list of ``rank`` floating point values - """ - out: list[Any] = np.exp(self.multivariate_normal_np(cov)).tolist() - return out - - def grouped_multivariate_normal(self, covs: list[dict[str, Any]]) -> list[Any]: - """Produce a list of values pulled from a set of multivariate distributions.""" - cov = self._select_group(covs) - logger.debug("Multivariate normal group selected: %s", cov) - constants = self._find_constants(cov) - nums = self.multivariate_normal(cov) - return list(merge_with_constants(nums, constants)) - - def grouped_multivariate_lognormal(self, covs: list[dict[str, Any]]) -> list[Any]: - """Produce a list of values pulled from a set of multivariate distributions.""" - cov = self._select_group(covs) - logger.debug("Multivariate lognormal group selected: %s", cov) - constants = self._find_constants(cov) - nums = np.exp(self.multivariate_normal_np(cov)).tolist() - return list(merge_with_constants(nums, constants)) - - def _check_generator_name(self, name: str) -> None: - if name not in self.PERMITTED_SUBGENS: - raise InappropriateGeneratorException( - f"{name} is not a permitted generator" - ) - - def alternatives( - self, - alternative_configs: list[dict[str, Any]], - counts: list[dict[str, int]] | None, - ) -> Any: - """ - Pick between other generators. - - :param alternative_configs: List of alternative generators. - Each alternative has the following keys: "count" -- a weight for - how often to use this alternative; "name" -- which generator - for this partition, for example "composite"; "params" -- the - parameters for this alternative. - :param counts: A list of weights for each alternative. If None, the - "count" value of each alternative is used. Each count is a dict - with a "count" key. - :return: list of values - """ - if counts is not None: - while True: - count = self._select_group(counts) - alt = alternative_configs[count["index"]] - name = alt["name"] - self._check_generator_name(name) - try: - return getattr(self, name)(**alt["params"]) - except NothingToGenerateException: - # Prevent this alternative from being chosen again - count["count"] = 0 - alt = self._select_group(alternative_configs) - name = alt["name"] - self._check_generator_name(name) - return getattr(self, name)(**alt["params"]) - - def with_constants_at( - self, constants_at: dict[int, T], subgen: str, params: dict[str, T] - ) -> list[T]: - """ - Insert constants into the results of a different generator. - - :param constants_at: A dictionary of positions and objects to insert - into the return list at those positions. - :param subgen: The name of the function to call to get the results - that will have the constants inserted into. - :param params: Keyword arguments to the ``subgen`` function. - :return: A list of results from calling ``subgen(**params)`` - with ``constants_at`` inserted in at the appropriate indices. - """ - if subgen not in self.PERMITTED_SUBGENS: - logger.error( - "subgenerator %s is not a valid name. 
Valid names are %s.", - subgen, - self.PERMITTED_SUBGENS, - ) - subout = getattr(self, subgen)(**params) - logger.debug("Merging constants %s", constants_at) - return list(merge_with_constants(subout, constants_at)) - - def truncated_string( - self, subgen_fn: Callable[..., list[T]], params: dict, length: int - ) -> list[T]: - """Call ``subgen_fn(**params)`` and truncate the results to ``length``.""" - result = subgen_fn(**params) - if result is None: - return None - return result[:length] - - class TableGenerator(ABC): """Abstract base class for table generator classes.""" diff --git a/datafaker/generators/base.py b/datafaker/generators/base.py index f2a1459..aba91b6 100644 --- a/datafaker/generators/base.py +++ b/datafaker/generators/base.py @@ -12,13 +12,13 @@ from sqlalchemy.types import Integer, Numeric, String, TypeEngine from typing_extensions import Self -from datafaker.base import DistributionGenerator +from datafaker.providers import DistributionProvider from datafaker.utils import logger NumericType = Union[int, float] -dist_gen = DistributionGenerator() +dist_gen = DistributionProvider() generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) diff --git a/datafaker/generators/mimesis.py b/datafaker/generators/mimesis.py index b300335..a0fa426 100644 --- a/datafaker/generators/mimesis.py +++ b/datafaker/generators/mimesis.py @@ -10,12 +10,12 @@ from datafaker.generators.base import ( Buckets, - DistributionGenerator, Generator, GeneratorError, GeneratorFactory, get_column_type, ) +from datafaker.providers import DistributionProvider NumericType = Union[int, float] @@ -23,7 +23,7 @@ # choice distribution to be infeasible? MAXIMUM_CHOICES = 500 -dist_gen = DistributionGenerator() +dist_gen = DistributionProvider() generic = mimesis.Generic(locale=mimesis.locales.Locale.EN_GB) diff --git a/datafaker/providers.py b/datafaker/providers.py index 75006c7..39a9d9a 100644 --- a/datafaker/providers.py +++ b/datafaker/providers.py @@ -1,13 +1,19 @@ """This module contains Mimesis Provider sub-classes.""" import datetime as dt +import functools +import math import random -from typing import Any, Optional, Union, cast +from collections.abc import Mapping +from typing import Any, Callable, Generator, Optional, Union, cast +import numpy as np from mimesis import Datetime, Text from mimesis.providers.base import BaseDataProvider, BaseProvider from sqlalchemy import Column, Connection from sqlalchemy.sql import func, functions, select +from datafaker.utils import T, logger + class ColumnValueProvider(BaseProvider): """A Mimesis provider of random values from the source database.""" @@ -215,3 +221,367 @@ class Meta: def null() -> None: """Return `None`.""" return None + + +class InappropriateGeneratorException(Exception): + """Exception thrown if a generator is requested that is not appropriate.""" + + +class NothingToGenerateException(Exception): + """Exception thrown when no value can be generated.""" + + def __init__(self, message: str): + """Initialise the exception with a human-readable message.""" + super().__init__(message) + + +@functools.cache +def zipf_weights(size: int) -> list[float]: + """Get the weights of a Zipf distribution of a given size.""" + total = sum(map(lambda n: 1 / n, range(1, size + 1))) + return [1 / (n * total) for n in range(1, size + 1)] + + +def merge_with_constants( + xs: list[T], constants_at: dict[int, T] +) -> Generator[T, None, None]: + """ + Merge a list of items with other items that must be placed at certain indices. 
+ + :param constants_at: A map of indices to objects that must be placed at + those indices. + :param xs: Items that fill in the gaps left by ``constants_at``. + :return: ``xs`` with ``constants_at`` inserted at the appropriate + points. If there are not enough elements in ``xs`` to fill in the gaps + in ``constants_at``, the elements of ``constants_at`` after the gap + are dropped. + """ + outi = 0 + xi = 0 + constant_count = len(constants_at) + while constant_count != 0: + if outi in constants_at: + yield constants_at[outi] + constant_count -= 1 + else: + if xi == len(xs): + return + yield xs[xi] + xi += 1 + outi += 1 + yield from xs[xi:] + + +class DistributionProvider(BaseProvider): + """A Mimesis provider for various distributions.""" + + class Meta: + """Meta-class for the distribution provider.""" + + name = "distribution_provider" + + root3 = math.sqrt(3) + + def __init__(self, *, seed: int | None = None, **kwargs: Any) -> None: + """Initialise a DistributionProvider.""" + super().__init__(seed=seed, **kwargs) + np_seed = seed if isinstance(seed, int) else None + self.np_gen = np.random.default_rng(seed=np_seed) + + def uniform(self, low: float, high: float) -> float: + """ + Choose a value according to a uniform distribution. + + :param low: The lowest value that can be chosen. + :param high: The highest value that can be chosen. + :return: The output value. + """ + return random.uniform(float(low), float(high)) + + def uniform_ms(self, mean: float, sd: float) -> float: + """ + Choose a value according to a uniform distribution. + + :param mean: The mean of the output values. + :param sd: The standard deviation of the output values. + :return: The output value. + """ + m = float(mean) + h = self.root3 * float(sd) + return random.uniform(m - h, m + h) + + def normal(self, mean: float, sd: float) -> float: + """ + Choose a value according to a Gaussian (normal) distribution. + + :param mean: The mean of the output values. + :param sd: The standard deviation of the output values. + :return: The output value. + """ + return random.normalvariate(float(mean), float(sd)) + + def lognormal(self, logmean: float, logsd: float) -> float: + """ + Choose a value according to a lognormal distribution. + + :param logmean: The mean of the logs of the output values. + :param logsd: The standard deviation of the logs of the output values. + :return: The output value. + """ + return random.lognormvariate(float(logmean), float(logsd)) + + def choice_direct(self, a: list[T]) -> T: + """ + Choose a value with equal probability. + + :param a: The list of values to output. + :return: The chosen value. + """ + return random.choice(a) + + def choice(self, a: list[Mapping[str, T]]) -> T | None: + """ + Choose a value with equal probability. + + :param a: The list of values to output. Each element is a mapping with + a key ``value``; the value stored under that key is returned. + :return: The chosen value. + """ + return self.choice_direct(a).get("value", None) + + def zipf_choice_direct(self, a: list[T], n: int | None = None) -> T: + """ + Choose a value according to the Zipf distribution. + + The nth value (starting from 1) is chosen with a frequency + 1/n times as frequently as the first value is chosen. + + :param a: The list of values to output, most frequent first. + :return: The chosen value. 
+ """ + if n is None: + n = len(a) + return random.choices(a, weights=zipf_weights(n))[0] + + def zipf_choice(self, a: list[Mapping[str, T]], n: int | None = None) -> T | None: + """ + Choose a value according to the Zipf distribution. + + The nth value (starting from 1) is chosen with a frequency + 1/n times as frequently as the first value is chosen. + + :param a: The list of rows to choose between, most frequent first. + Each element is a mapping with a key ``value`` and the key is the + value to return. + :return: The chosen value. + """ + c = self.zipf_choice_direct(a, n) + return c.get("value", None) + + def weighted_choice(self, a: list[dict[str, Any]]) -> Any: + """ + Choice weighted by the count in the original dataset. + + :param a: a list of dicts, each with a ``value`` key + holding the value to be returned and a ``count`` key holding the + number of that value found in the original dataset + :return: The chosen ``value``. + """ + vs = [] + counts = [] + for vc in a: + count = vc.get("count", 0) + if count: + counts.append(count) + vs.append(vc.get("value", None)) + c = random.choices(vs, weights=counts)[0] + return c + + def constant(self, value: T) -> T: + """Return the same value always.""" + return value + + def multivariate_normal_np(self, cov: dict[str, Any]) -> np.typing.NDArray: + """ + Return an array of values chosen from the given covariates. + + :param cov: Keys are ``rank``: The number of values to output; + ``mN``: The mean of variable ``N`` (where ``N`` is between 0 and + one less than ``rank``). ``cN_M`` (where 0 < ``N`` <= ``M`` < ``rank``): + the covariance between the ``N``th and the ``M``th variables. + :return: A numpy array of results. + """ + rank = int(cov["rank"]) + if rank == 0: + return np.empty(shape=(0,)) + mean = [float(cov[f"m{i}"]) for i in range(rank)] + covs = [ + [ + float(cov[f"c{i}_{j}"] if i <= j else cov[f"c{j}_{i}"]) + for i in range(rank) + ] + for j in range(rank) + ] + return self.np_gen.multivariate_normal(mean, covs) + + def _select_group(self, alts: list[dict[str, Any]]) -> Any: + """Choose one of the ``alts`` weighted by their ``"count"`` elements.""" + total = 0 + for alt in alts: + if alt["count"] < 0: + logger.warning( + "Alternative count is %d, but should not be negative", alt["count"] + ) + else: + total += alt["count"] + if total == 0: + raise NothingToGenerateException("No counts in any alternative") + choice = random.randrange(total) + for alt in alts: + choice -= alt["count"] + if choice < 0: + return alt + raise NothingToGenerateException( + "Internal error: ran out of choices in _select_group" + ) + + def _find_constants(self, result: dict[str, Any]) -> dict[int, Any]: + """ + Find all keys ``kN``, returning a dictionary of ``N: kNN``. + + This can be passed into ``merge_with_constants`` as the + ``constants_at`` argument. + """ + out: dict[int, Any] = {} + for k, v in result.items(): + if k.startswith("k") and k[1:].isnumeric(): + out[int(k[1:])] = v + return out + + PERMITTED_SUBGENS = { + "multivariate_lognormal", + "multivariate_normal", + "grouped_multivariate_lognormal", + "grouped_multivariate_normal", + "constant", + "weighted_choice", + "with_constants_at", + } + + def multivariate_normal(self, cov: dict[str, Any]) -> list[float]: + """ + Produce a list of values pulled from a multivariate distribution. + + :param cov: A dict with various keys: ``rank`` is the number of + output values, ``m0``, ``m1``, ... are the means of the + distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ... 
+ are the covariances, ``cN_M`` is the covariance of the ``N``th and + ``M``th variables, with 0 <= ``N`` <= ``M`` < ``rank``. + :return: list of ``rank`` floating point values + """ + out: list[float] = self.multivariate_normal_np(cov).tolist() + return out + + def multivariate_lognormal(self, cov: dict[str, Any]) -> list[float]: + """ + Produce a list of values pulled from a multivariate distribution. + + :param cov: A dict with various keys: ``rank`` is the number of + output values, ``m0``, ``m1``, ... are the means of the + distributions (``rank`` of them). ``c0_0``, ``c0_1``, ``c1_1``, ... + are the covariances, ``cN_M`` is the covariance of the ``N``th and + ``M``th variables, with 0 <= ``N`` <= ``M`` < ``rank``. These + are all the means and covariances of the logs of the data. + :return: list of ``rank`` floating point values + """ + out: list[Any] = np.exp(self.multivariate_normal_np(cov)).tolist() + return out + + def grouped_multivariate_normal(self, covs: list[dict[str, Any]]) -> list[Any]: + """Produce a list of values pulled from a set of multivariate distributions.""" + cov = self._select_group(covs) + logger.debug("Multivariate normal group selected: %s", cov) + constants = self._find_constants(cov) + nums = self.multivariate_normal(cov) + return list(merge_with_constants(nums, constants)) + + def grouped_multivariate_lognormal(self, covs: list[dict[str, Any]]) -> list[Any]: + """Produce a list of values pulled from a set of multivariate distributions.""" + cov = self._select_group(covs) + logger.debug("Multivariate lognormal group selected: %s", cov) + constants = self._find_constants(cov) + nums = np.exp(self.multivariate_normal_np(cov)).tolist() + return list(merge_with_constants(nums, constants)) + + def _check_generator_name(self, name: str) -> None: + if name not in self.PERMITTED_SUBGENS: + raise InappropriateGeneratorException( + f"{name} is not a permitted generator" + ) + + def alternatives( + self, + alternative_configs: list[dict[str, Any]], + counts: list[dict[str, int]] | None, + ) -> Any: + """ + Pick between other generators. + + :param alternative_configs: List of alternative generators. + Each alternative has the following keys: "count" -- a weight for + how often to use this alternative; "name" -- which generator + for this partition, for example "composite"; "params" -- the + parameters for this alternative. + :param counts: A list of weights for each alternative. If None, the + "count" value of each alternative is used. Each count is a dict + with a "count" key. + :return: list of values + """ + if counts is not None: + while True: + count = self._select_group(counts) + alt = alternative_configs[count["index"]] + name = alt["name"] + self._check_generator_name(name) + try: + return getattr(self, name)(**alt["params"]) + except NothingToGenerateException: + # Prevent this alternative from being chosen again + count["count"] = 0 + alt = self._select_group(alternative_configs) + name = alt["name"] + self._check_generator_name(name) + return getattr(self, name)(**alt["params"]) + + def with_constants_at( + self, constants_at: dict[int, T], subgen: str, params: dict[str, T] + ) -> list[T]: + """ + Insert constants into the results of a different generator. + + :param constants_at: A dictionary of positions and objects to insert + into the return list at those positions. + :param subgen: The name of the function to call to get the results + into which the constants will be inserted. + :param params: Keyword arguments to the ``subgen`` function. 
+ :return: A list of results from calling ``subgen(**params)`` + with ``constants_at`` inserted at the appropriate indices. + """ + if subgen not in self.PERMITTED_SUBGENS: + logger.error( + "subgenerator %s is not a valid name. Valid names are %s.", + subgen, + self.PERMITTED_SUBGENS, + ) + subout = getattr(self, subgen)(**params) + logger.debug("Merging constants %s", constants_at) + return list(merge_with_constants(subout, constants_at)) + + def truncated_string( + self, subgen_fn: Callable[..., list[T]], params: dict, length: int + ) -> list[T]: + """Call ``subgen_fn(**params)`` and truncate the results to ``length``.""" + result = subgen_fn(**params) + if result is None: + return None + return result[:length] diff --git a/datafaker/templates/df.py.j2 b/datafaker/templates/df.py.j2 index 3a2da41..0e2616e 100644 --- a/datafaker/templates/df.py.j2 +++ b/datafaker/templates/df.py.j2 @@ -3,12 +3,13 @@ from mimesis import Generic, Numeric, Person from mimesis.locales import Locale import sqlalchemy import sys -from datafaker.base import FileUploader, TableGenerator, DistributionGenerator, ColumnPresence +from datafaker.base import FileUploader, TableGenerator, ColumnPresence +from datafaker.providers import DistributionProvider generic = Generic(locale=Locale.EN_GB) numeric = Numeric() person = Person() -dist_gen = DistributionGenerator() +dist_gen = DistributionProvider() column_presence = ColumnPresence() sys.path.append("") From 3d9e7f0098acf343d12b41d7201c02d8fa3b6bb8 Mon Sep 17 00:00:00 2001 From: Tim Band Date: Wed, 5 Nov 2025 12:46:57 +0000 Subject: [PATCH 35/35] Answering a PR comment --- datafaker/generators/__init__.py | 2 ++ docs/source/docker.rst | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/datafaker/generators/__init__.py b/datafaker/generators/__init__.py index 9ed619b..650c8ba 100644 --- a/datafaker/generators/__init__.py +++ b/datafaker/generators/__init__.py @@ -28,6 +28,8 @@ ) +# Using a cache instead of just initializing an object avoids +# spending startup time when it isn't needed. @lru_cache(1) def everything_factory() -> GeneratorFactory: """Get a factory that encapsulates all the other factories.""" diff --git a/docs/source/docker.rst b/docs/source/docker.rst index 8954fc6..62fcfe0 100644 --- a/docs/source/docker.rst +++ b/docs/source/docker.rst @@ -32,7 +32,7 @@ computer. Running the image in this way will give you a command prompt from which datafaker can be called. Tab completion can be used. For example, if -you type ``sq<TAB> mat<TAB>`` you will see +you type ``dataf<TAB> mat<TAB>`` you will see ``datafaker make-tables``; although you might have to wait a second or two after some of the ``<TAB>`` key presses for the completed text to appear. Tab completion can also be used for command options such
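As a usage note for patch 34, a minimal sketch of driving the relocated provider directly; the seed and the value lists are illustrative, but the call signatures match the diff above:

    from datafaker.providers import DistributionProvider

    dist_gen = DistributionProvider(seed=42)

    # Zipf choice: with values listed most frequent first, the nth value
    # is chosen 1/n times as often as the first.
    colour = dist_gen.zipf_choice(
        [{"value": "red"}, {"value": "green"}, {"value": "blue"}]
    )

    # Weighted choice reproduces the counts observed in the source data.
    species = dist_gen.weighted_choice(
        [{"value": "cat", "count": 7}, {"value": "dog", "count": 3}]
    )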