Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 07ebc04

Browse files
committed
Now samples all TEXT columns at once; Added utils.py
1 parent cd9c695 commit 07ebc04

File tree

4 files changed

+76
-55
lines changed

4 files changed

+76
-55
lines changed

data_diff/databases/base.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import threading
99
from abc import abstractmethod
1010

11+
from data_diff.utils import is_uuid, safezip
1112
from .database_types import (
1213
ColType_UUID,
1314
AbstractDatabase,
@@ -29,14 +30,6 @@ def parse_table_name(t):
2930
return tuple(t.split("."))
3031

3132

32-
def is_uuid(u):
33-
try:
34-
UUID(u)
35-
except ValueError:
36-
return False
37-
return True
38-
39-
4033
def import_helper(package: str = None, text=""):
4134
def dec(f):
4235
@wraps(f)
@@ -172,18 +165,7 @@ def _parse_type(
172165
)
173166

174167
elif issubclass(cls, Text):
175-
samples = self.query(Select([col_name], TableName(table_path), limit=16), List[str])
176-
uuid_samples = list(filter(is_uuid, samples))
177-
178-
if uuid_samples:
179-
if len(uuid_samples) != len(samples):
180-
logger.warning(
181-
f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support."
182-
)
183-
else:
184-
return ColType_UUID()
185-
186-
return Text()
168+
return cls()
187169

188170
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
189171

@@ -204,8 +186,31 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str
204186
accept = {i.lower() for i in filter_columns}
205187
rows = [r for r in rows if r[0].lower() in accept]
206188

189+
col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in rows}
190+
191+
self._refine_coltypes(path, col_dict)
192+
207193
# Return a dict of form {name: type} after normalization
208-
return {row[0]: self._parse_type(path, *row) for row in rows}
194+
return col_dict
195+
196+
def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
197+
"Refine the types in the column dict, by querying the database for a sample of their values"
198+
199+
text_columns = [k for k, v in col_dict.items() if isinstance(v, Text)]
200+
201+
samples_by_row = self.query(Select(text_columns, TableName(table_path), limit=16), list)
202+
samples_by_col = list(zip(*samples_by_row))
203+
for col_name, samples in safezip(text_columns, samples_by_col):
204+
uuid_samples = list(filter(is_uuid, samples))
205+
206+
if uuid_samples:
207+
if len(uuid_samples) != len(samples):
208+
logger.warning(
209+
f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support."
210+
)
211+
else:
212+
assert col_name in col_dict
213+
col_dict[col_name] = ColType_UUID()
209214

210215
# @lru_cache()
211216
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:

data_diff/databases/database_types.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,10 @@
1-
from uuid import UUID
21
from abc import ABC, abstractmethod
32
from typing import Sequence, Optional, Tuple, Union, Dict, Any
43
from datetime import datetime
54

65
from runtype import dataclass
76

8-
9-
class ArithUUID(UUID):
10-
"A UUID that supports basic arithmetic (add, sub)"
11-
12-
def __add__(self, other: Union[UUID, int]):
13-
if isinstance(other, int):
14-
return type(self)(int=self.int + other)
15-
return NotImplemented
16-
17-
def __sub__(self, other: Union[UUID, int]):
18-
if isinstance(other, int):
19-
return type(self)(int=self.int - other)
20-
elif isinstance(other, UUID):
21-
return self.int - other.int
22-
return NotImplemented
23-
24-
def __int__(self):
25-
return self.int
7+
from data_diff.utils import ArithUUID
268

279

2810
DbPath = Tuple[str, ...]
@@ -31,6 +13,7 @@ def __int__(self):
3113

3214

3315
class ColType:
    """Base class for column types."""

    # Whether this column type has compatibility handling. Column types that
    # set this to False trigger a "has no compatibility handling" warning
    # when the diff validates its relevant columns.
    supported = True
3518

3619

@@ -87,8 +70,9 @@ class ColType_UUID(StringType, IKey):
8770
python_type = ArithUUID
8871

8972

73+
@dataclass
9074
class Text(StringType):
91-
pass
75+
supported = False
9276

9377

9478
@dataclass
@@ -104,6 +88,8 @@ def __post_init__(self):
10488
class UnknownColType(ColType):
10589
text: str
10690

91+
supported = False
92+
10793

10894
class AbstractDatabase(ABC):
10995
@abstractmethod

data_diff/diff_tables.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from runtype import dataclass
1313

1414
from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max, Value
15+
from .utils import safezip, split_space
1516
from .databases.base import Database
1617
from .databases.database_types import (
1718
ArithUUID,
@@ -31,17 +32,6 @@
3132
DEFAULT_BISECTION_FACTOR = 32
3233

3334

34-
def safezip(*args):
35-
"zip but makes sure all sequences are the same length"
36-
assert len(set(map(len, args))) == 1
37-
return zip(*args)
38-
39-
40-
def split_space(start, end, count):
41-
size = end - start
42-
return list(range(start, end, (size + 1) // (count + 1)))[1 : count + 1]
43-
44-
4535
@dataclass(frozen=False)
4636
class TableSegment:
4737
"""Signifies a segment of rows (and selected columns) within a table
@@ -359,9 +349,9 @@ def _validate_and_adjust_columns(self, table1, table2):
359349
for t in [table1, table2]:
360350
for c in t._relevant_columns:
361351
ctype = t._schema[c]
362-
if isinstance(ctype, UnknownColType):
363-
logger.warn(
364-
f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. "
352+
if not ctype.supported:
353+
logger.warning(
354+
f"[{t.database.name}] Column '{c}' of type '{ctype}' has no compatibility handling. "
365355
"If encoding/formatting differs between databases, it may result in false positives."
366356
)
367357

data_diff/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Sequence, Optional, Tuple, Union, Dict, Any
2+
from uuid import UUID
3+
4+
5+
def safezip(*args):
    """Zip the given sequences, making sure they all have the same length.

    Args:
        *args: Sized sequences (each must support ``len()``).

    Returns:
        The ``zip`` iterator over ``args``.

    Raises:
        ValueError: If the sequences do not all share one length (this also
            covers being called with no arguments at all).
    """
    lengths = [len(seq) for seq in args]
    # An explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let zip() silently truncate to the shortest.
    if len(set(lengths)) != 1:
        raise ValueError(f"Mismatching lengths in arguments to safezip: {lengths}")
    return zip(*args)
9+
10+
11+
def split_space(start, end, count):
    """Return up to `count` evenly spaced integer checkpoints strictly between `start` and `end`.

    Args:
        start: First integer of the space (never included in the result).
        end: End of the space (exclusive).
        count: Number of checkpoints requested.

    Returns:
        A list of at most `count` integers. Fewer are returned when the space
        is too small to hold `count` distinct interior points.
    """
    size = end - start
    # When the space is smaller than the requested number of checkpoints the
    # integer division yields 0, which range() rejects; fall back to a step
    # of 1 so the caller simply gets as many points as exist.
    step = (size + 1) // (count + 1) or 1
    return list(range(start, end, step))[1 : count + 1]
14+
15+
16+
class ArithUUID(UUID):
    """A UUID that supports basic arithmetic (add, sub).

    Arithmetic is performed on the UUID's 128-bit integer value:

    * ``uuid + int`` and ``int + uuid`` return a new UUID of the same class.
    * ``uuid - int`` returns a new UUID of the same class.
    * ``uuid - uuid`` returns the integer distance between the two.
    """

    def __add__(self, other: int):
        """Return a new UUID offset forward by the integer `other`."""
        if isinstance(other, int):
            return type(self)(int=self.int + other)
        return NotImplemented

    # Addition with an int is commutative, so support `int + uuid` as well.
    __radd__ = __add__

    def __sub__(self, other: Union[UUID, int]):
        """Return a UUID offset backward by an int, or the int distance to another UUID."""
        if isinstance(other, int):
            return type(self)(int=self.int - other)
        elif isinstance(other, UUID):
            return self.int - other.int
        return NotImplemented

    def __int__(self):
        """Return the UUID's integer value."""
        return self.int
33+
34+
35+
def is_uuid(u):
    """Return True if `u` is a string that parses as a UUID, False otherwise.

    Non-string values (for example ``None`` coming from a NULL cell) are
    reported as "not a UUID" instead of raising, since ``UUID(u)`` fails on
    them with ``TypeError``/``AttributeError`` rather than ``ValueError``.
    """
    if not isinstance(u, str):
        return False
    try:
        UUID(u)
    except ValueError:
        return False
    return True

0 commit comments

Comments
 (0)