Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 897ce8f

Browse files
authored
Merge pull request #308 from datafold/nov22_sqeleton_refactor
Nov22 sqeleton refactor
2 parents 365e08e + 40fe8cb commit 897ce8f

37 files changed

+228
-223
lines changed

data_diff/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from .tracking import disable_tracking
44
from .databases import connect
5-
from .sqeleton.databases import DbKey, DbTime, DbPath
5+
from .sqeleton.abcs import DbKey, DbTime, DbPath
66
from .diff_tables import Algorithm
77
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
88
from .joindiff_tables import JoinDiffer

data_diff/__main__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
1515
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
1616
from .table_segment import TableSegment
17-
from .sqeleton.databases import create_schema
17+
from .sqeleton.schema import create_schema
1818
from .databases import connect
1919
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
2020
from .config import apply_config_from_file
@@ -54,10 +54,10 @@ def diff_schemas(table1, table2, schema1, schema2, columns):
5454
diffs = []
5555

5656
if c not in schema1:
57-
cols = ', '.join(schema1)
57+
cols = ", ".join(schema1)
5858
raise ValueError(f"Column '{c}' not found in table 1, named '{table1}'. Columns: {cols}")
5959
if c not in schema2:
60-
cols = ', '.join(schema1)
60+
cols = ", ".join(schema1)
6161
raise ValueError(f"Column '{c}' not found in table 2, named '{table2}'. Columns: {cols}")
6262

6363
col1 = schema1[c]
@@ -216,7 +216,6 @@ def main(conf, run, **kw):
216216
raise
217217

218218

219-
220219
def _main(
221220
database1,
222221
table1,

data_diff/databases/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from data_diff.sqeleton.databases import AbstractMixin_MD5, AbstractMixin_NormalizeValue
1+
from data_diff.sqeleton.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
22

33

44
class DatadiffDialect(AbstractMixin_MD5, AbstractMixin_NormalizeValue):

data_diff/diff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .thread_utils import ThreadedYielder
1919
from .table_segment import TableSegment
2020
from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
21-
from .sqeleton.databases import IKey
21+
from .sqeleton.abcs import IKey
2222

2323
logger = getLogger(__name__)
2424

data_diff/hashdiff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from .utils import safezip
1313
from .thread_utils import ThreadedYielder
14-
from .sqeleton.databases import ColType_UUID, NumericType, PrecisionType, StringType
14+
from .sqeleton.abcs.database_types import ColType_UUID, NumericType, PrecisionType, StringType
1515
from .table_segment import TableSegment
1616

1717
from .diff_tables import TableDiffer

data_diff/info_tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def update_from_children(self, child_infos):
3131
self.is_diff = any(c.is_diff for c in child_infos)
3232

3333
self.rowcounts = {
34-
1: sum(c.rowcounts[1] for c in child_infos),
35-
2: sum(c.rowcounts[2] for c in child_infos),
34+
1: sum(c.rowcounts[1] for c in child_infos if c.rowcounts),
35+
2: sum(c.rowcounts[2] for c in child_infos if c.rowcounts),
3636
}
3737

3838

data_diff/joindiff_tables.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010

1111
from runtype import dataclass
1212

13-
from .sqeleton.databases import Database, DbPath, NumericType, MySQL, BigQuery, Presto, Oracle, Snowflake
13+
from .sqeleton.databases import Database, MySQL, BigQuery, Presto, Oracle, Snowflake
14+
from .sqeleton.abcs.database_types import DbPath, NumericType
1415
from .sqeleton.queries import table, sum_, min_, max_, avg
1516
from .sqeleton.queries.api import and_, if_, or_, outerjoin, leftjoin, rightjoin, this, ITable
16-
from .sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath
17+
from .sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code
1718
from .sqeleton.queries.compiler import Compiler
1819
from .sqeleton.queries.extras import NormalizeAsString
1920

@@ -332,7 +333,7 @@ def exclusive_rows(expr):
332333
c = Compiler(db)
333334
name = c.new_unique_table_name("temp_table")
334335
exclusive_rows = table(name, schema=expr.source_table.schema)
335-
yield create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit))
336+
yield Code(create_temp_table(c, exclusive_rows, expr.limit(self.table_write_limit)))
336337

337338
count = yield exclusive_rows.count()
338339
self.stats["exclusive_count"] = self.stats.get("exclusive_count", 0) + count[0][0]

data_diff/sqeleton/abcs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .database_types import AbstractDatabase, AbstractDialect, DbKey, DbPath, DbTime, IKey
2+
from .compiler import AbstractCompiler, Compilable

data_diff/sqeleton/abcs/compiler.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Any, Dict
2+
from abc import ABC, abstractmethod
3+
4+
5+
class AbstractCompiler(ABC):
6+
@abstractmethod
7+
def compile(self, elem: Any, params: Dict[str, Any] = None) -> str:
8+
...
9+
10+
11+
class Compilable(ABC):
12+
@abstractmethod
13+
def compile(self, c: AbstractCompiler) -> str:
14+
...

data_diff/sqeleton/databases/database_types.py renamed to data_diff/sqeleton/abcs/database_types.py

Lines changed: 1 addition & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
1-
import logging
21
import decimal
32
from abc import ABC, abstractmethod
43
from typing import Sequence, Optional, Tuple, Union, Dict, List
54
from datetime import datetime
65

76
from runtype import dataclass
87

9-
from ..utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict, ArithAlphanumeric, ArithUUID
8+
from ..utils import ArithAlphanumeric, ArithUUID
109

1110

1211
DbPath = Tuple[str, ...]
1312
DbKey = Union[int, str, bytes, ArithUUID, ArithAlphanumeric]
1413
DbTime = datetime
1514

16-
logger = logging.getLogger("databases")
17-
1815

1916
class ColType:
2017
supported = True
@@ -214,94 +211,6 @@ def parse_type(
214211
"Parse type info as returned by the database"
215212

216213

217-
class AbstractMixin_NormalizeValue(ABC):
218-
@abstractmethod
219-
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
220-
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
221-
222-
The returned expression must accept any SQL datetime/timestamp, and return a string.
223-
224-
Date format: ``YYYY-MM-DD HH:mm:SS.FFFFFF``
225-
226-
Precision of dates should be rounded up/down according to coltype.rounds
227-
"""
228-
...
229-
230-
@abstractmethod
231-
def normalize_number(self, value: str, coltype: FractionalType) -> str:
232-
"""Creates an SQL expression, that converts 'value' to a normalized number.
233-
234-
The returned expression must accept any SQL int/numeric/float, and return a string.
235-
236-
Floats/Decimals are expected in the format
237-
"I.P"
238-
239-
Where I is the integer part of the number (as many digits as necessary),
240-
and must be at least one digit (0).
241-
P is the fractional digits, the amount of which is specified with
242-
coltype.precision. Trailing zeroes may be necessary.
243-
If P is 0, the dot is omitted.
244-
245-
Note: We use 'precision' differently than most databases. For decimals,
246-
it's the same as ``numeric_scale``, and for floats, who use binary precision,
247-
it can be calculated as ``log10(2**numeric_precision)``.
248-
"""
249-
...
250-
251-
@abstractmethod
252-
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
253-
"""Creates an SQL expression, that converts 'value' to a normalized uuid.
254-
255-
i.e. just makes sure there is no trailing whitespace.
256-
"""
257-
...
258-
259-
def normalize_boolean(self, value: str, coltype: Boolean) -> str:
260-
"""Creates an SQL expression, that converts 'value' to either '0' or '1'."""
261-
return self.to_string(value)
262-
263-
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
264-
"""Creates an SQL expression, that strips uuids of artifacts like whitespace."""
265-
if isinstance(coltype, String_UUID):
266-
return f"TRIM({value})"
267-
return self.to_string(value)
268-
269-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
270-
"""Creates an SQL expression, that converts 'value' to a normalized representation.
271-
272-
The returned expression must accept any SQL value, and return a string.
273-
274-
The default implementation dispatches to a method according to `coltype`:
275-
276-
::
277-
278-
TemporalType -> normalize_timestamp()
279-
FractionalType -> normalize_number()
280-
*else* -> to_string()
281-
282-
(`Integer` falls in the *else* category)
283-
284-
"""
285-
if isinstance(coltype, TemporalType):
286-
return self.normalize_timestamp(value, coltype)
287-
elif isinstance(coltype, FractionalType):
288-
return self.normalize_number(value, coltype)
289-
elif isinstance(coltype, ColType_UUID):
290-
return self.normalize_uuid(value, coltype)
291-
elif isinstance(coltype, Boolean):
292-
return self.normalize_boolean(value, coltype)
293-
return self.to_string(value)
294-
295-
296-
class AbstractMixin_MD5(ABC):
297-
"""Dialect-dependent query expressions, that are specific to data-diff"""
298-
299-
@abstractmethod
300-
def md5_as_int(self, s: str) -> str:
301-
"Provide SQL for computing md5 and returning an int"
302-
...
303-
304-
305214
class AbstractDatabase:
306215
@property
307216
@abstractmethod
@@ -374,18 +283,3 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
374283
@abstractmethod
375284
def is_autocommit(self) -> bool:
376285
"Return whether the database autocommits changes. When false, COMMIT statements are skipped."
377-
378-
379-
Schema = CaseAwareMapping
380-
381-
382-
def create_schema(db: AbstractDatabase, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
383-
logger.debug(f"[{db.name}] Schema = {schema}")
384-
385-
if case_sensitive:
386-
return CaseSensitiveDict(schema)
387-
388-
if len({k.lower() for k in schema}) < len(schema):
389-
logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
390-
logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).")
391-
return CaseInsensitiveDict(schema)

0 commit comments

Comments
 (0)