Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 5812354

Browse files
committed
UUID field support, initial (implemented only for MySQL)
1 parent 94ad17a commit 5812354

File tree

9 files changed

+123
-19
lines changed

9 files changed

+123
-19
lines changed

data_diff/databases/base.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
from uuid import UUID
12
import math
23
import sys
34
import logging
4-
from typing import Dict, Tuple, Optional, Sequence, Type
5+
from typing import Dict, Tuple, Optional, Sequence, Type, List
56
from functools import lru_cache, wraps
67
from concurrent.futures import ThreadPoolExecutor
78
import threading
89
from abc import abstractmethod
910

1011
from .database_types import (
12+
ColType_UUID,
1113
AbstractDatabase,
1214
ColType,
1315
Integer,
@@ -16,8 +18,9 @@
1618
PrecisionType,
1719
TemporalType,
1820
UnknownColType,
21+
Text,
1922
)
20-
from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select
23+
from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select, TableName
2124

2225
logger = logging.getLogger("database")
2326

@@ -26,6 +29,14 @@ def parse_table_name(t):
2629
return tuple(t.split("."))
2730

2831

32+
def is_uuid(u):
    """Return True if `u` is a UUID instance or a string that parses as a UUID.

    Values that are not strings (e.g. None coming from a NULL column, ints,
    bytes) are reported as non-UUIDs instead of raising, so that sampling a
    column that contains such values does not crash type detection.
    """
    if isinstance(u, UUID):
        return True
    if not isinstance(u, str):
        # UUID() raises TypeError/AttributeError (not ValueError) for these.
        return False
    try:
        UUID(u)
    except ValueError:
        return False
    return True
38+
39+
2940
def import_helper(package: str = None, text=""):
3041
def dec(f):
3142
@wraps(f)
@@ -102,7 +113,7 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
102113
assert len(res) == 1, (sql_code, res)
103114
return res[0]
104115
elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1:
105-
if res_type.__args__ == (int,):
116+
if res_type.__args__ == (int,) or res_type.__args__ == (str,):
106117
return [_one(row) for row in res]
107118
elif res_type.__args__ == (Tuple,):
108119
return [tuple(row) for row in res]
@@ -123,6 +134,7 @@ def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
123134

124135
def _parse_type(
125136
self,
137+
table_path: DbPath,
126138
col_name: str,
127139
type_repr: str,
128140
datetime_precision: int = None,
@@ -147,7 +159,7 @@ def _parse_type(
147159
elif issubclass(cls, Decimal):
148160
if numeric_scale is None:
149161
raise ValueError(
150-
f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}."
162+
f"{self.name}: Unexpected numeric_scale is NULL, for column {'.'.join(table_path)}.{col_name} of type {type_repr}."
151163
)
152164
return cls(precision=numeric_scale)
153165

@@ -159,6 +171,20 @@ def _parse_type(
159171
)
160172
)
161173

174+
elif issubclass(cls, Text):
175+
samples = self.query(Select([col_name], TableName(table_path), limit=16), List[str])
176+
uuid_samples = list(filter(is_uuid, samples))
177+
178+
if uuid_samples:
179+
if len(uuid_samples) != len(samples):
180+
logger.warning(
181+
f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support."
182+
)
183+
else:
184+
return ColType_UUID()
185+
186+
return Text()
187+
162188
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
163189

164190
def select_table_schema(self, path: DbPath) -> str:
@@ -179,7 +205,7 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str
179205
rows = [r for r in rows if r[0].lower() in accept]
180206

181207
# Return a dict of form {name: type} after normalization
182-
return {row[0]: self._parse_type(*row) for row in rows}
208+
return {row[0]: self._parse_type(path, *row) for row in rows}
183209

184210
# @lru_cache()
185211
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:
@@ -233,6 +259,12 @@ def create_connection(self):
233259
def close(self):
234260
self._queue.shutdown()
235261

262+
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    """Return the SQL fragment that restricts a query to `limit` rows.

    Raises NotImplementedError if a non-zero `offset` is requested, since
    OFFSET is not supported. Returns an empty string when `limit` is None,
    rather than emitting the invalid fragment "LIMIT None".
    """
    if offset:
        raise NotImplementedError("No support for OFFSET in query")

    if limit is None:
        return ""

    return f"LIMIT {limit}"
267+
236268

237269
CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower
238270
MD5_HEXDIGITS = 32

data_diff/databases/database_types.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
1+
from uuid import UUID
12
from abc import ABC, abstractmethod
2-
from typing import Sequence, Optional, Tuple, Union, Dict
3+
from typing import Sequence, Optional, Tuple, Union, Dict, Any
34
from datetime import datetime
45

56
from runtype import dataclass
67

8+
9+
class ArithUUID(UUID):
    "UUID subclass that can be shifted by integers and subtracted from other UUIDs"

    def __add__(self, other: Union[UUID, int]):
        # Only integer offsets are supported for addition.
        if not isinstance(other, int):
            return NotImplemented
        shifted = self.int + other
        return type(self)(int=shifted)

    def __sub__(self, other: Union[UUID, int]):
        # UUID - UUID yields the integer distance between the two values;
        # UUID - int yields a new UUID shifted downwards.
        if isinstance(other, UUID):
            return self.int - other.int
        if isinstance(other, int):
            shifted = self.int - other
            return type(self)(int=shifted)
        return NotImplemented

    def __int__(self):
        return self.int
26+
27+
728
DbPath = Tuple[str, ...]
8-
DbKey = Union[int, str, bytes]
29+
DbKey = Union[int, str, bytes, ArithUUID]
930
DbTime = datetime
1031

1132

@@ -57,12 +78,24 @@ class StringType(ColType):
5778
pass
5879

5980

60-
class UUID(StringType):
81+
class IKey(ABC):
82+
"Interface for ColType, for using a column as a key in data-diff"
83+
python_type: type
84+
85+
86+
class ColType_UUID(StringType, IKey):
87+
python_type = ArithUUID
88+
89+
90+
class Text(StringType):
6191
pass
6292

6393

6494
@dataclass
65-
class Integer(NumericType):
95+
class Integer(NumericType, IKey):
96+
precision: int = 0
97+
python_type: type = int
98+
6699
def __post_init__(self):
67100
assert self.precision == 0
68101

@@ -88,6 +121,10 @@ def md5_to_int(self, s: str) -> str:
88121
"Provide SQL for computing md5 and returning an int"
89122
...
90123

124+
@abstractmethod
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    "Provide SQL fragment for limit and offset inside a select"
    ...
127+
91128
@abstractmethod
92129
def _query(self, sql_code: str) -> list:
93130
"Send query to database and return result"

data_diff/databases/mysql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ class MySQL(ThreadedDatabase):
2020
"float": Float,
2121
"decimal": Decimal,
2222
"int": Integer,
23+
# Text
24+
"varchar": Text,
2325
}
2426
ROUNDS_ON_PREC_LOSS = True
2527

data_diff/databases/oracle.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,9 @@ def _parse_type(
9090
)
9191

9292
return super()._parse_type(type_repr, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)
93+
94+
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    """Return the Oracle row-limiting clause that restricts a query to `limit` rows.

    Raises NotImplementedError if a non-zero `offset` is requested, since
    OFFSET is not supported. Returns an empty string when `limit` is None,
    rather than emitting the invalid fragment "FETCH NEXT None ROWS ONLY".
    """
    if offset:
        raise NotImplementedError("No support for OFFSET in query")

    if limit is None:
        return ""

    return f"FETCH NEXT {limit} ROWS ONLY"

data_diff/databases/postgresql.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ class PostgreSQL(ThreadedDatabase):
2525
"integer": Integer,
2626
"numeric": Decimal,
2727
"bigint": Integer,
28+
# Text
29+
"varchar": Text,
30+
"text": Text,
2831
}
2932
ROUNDS_ON_PREC_LOSS = True
3033

data_diff/databases/presto.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,12 @@ def select_table_schema(self, path: DbPath) -> str:
8181
)
8282

8383
def _parse_type(
84-
self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None
84+
self,
85+
table_path: DbPath,
86+
col_name: str,
87+
type_repr: str,
88+
datetime_precision: int = None,
89+
numeric_precision: int = None,
8590
) -> ColType:
8691
timestamp_regexps = {
8792
r"timestamp\((\d)\)": Timestamp,
@@ -103,4 +108,4 @@ def _parse_type(
103108
prec, scale = map(int, m.groups())
104109
return n_cls(scale)
105110

106-
return super()._parse_type(type_repr)
111+
return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)

data_diff/diff_tables.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111

1212
from runtype import dataclass
1313

14-
from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max
14+
from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max, Value
1515
from .databases.base import Database
1616
from .databases.database_types import (
17+
ArithUUID,
1718
NumericType,
1819
PrecisionType,
1920
UnknownColType,
@@ -121,9 +122,9 @@ def with_schema(self) -> "TableSegment":
121122

122123
def _make_key_range(self):
123124
if self.min_key is not None:
124-
yield Compare("<=", str(self.min_key), self._key_column)
125+
yield Compare("<=", Value(self.min_key), self._key_column)
125126
if self.max_key is not None:
126-
yield Compare("<", self._key_column, str(self.max_key))
127+
yield Compare("<", self._key_column, Value(self.max_key))
127128

128129
def _make_update_range(self):
129130
if self.min_update is not None:
@@ -152,6 +153,11 @@ def get_values(self) -> list:
152153
def choose_checkpoints(self, count: int) -> List[DbKey]:
    """Suggests a bunch of evenly-spaced checkpoints to split by (not including start, end)

    When the key bounds are ArithUUID values, the split is computed over their
    integer representation and the checkpoints are converted back to ArithUUID.
    Otherwise the bounds are passed to split_space() as they are.
    """
    assert self.is_bounded
    if isinstance(self.min_key, ArithUUID):
        # Check both bounds before reading `.int`, so a type mismatch fails on
        # this assertion rather than on an attribute access.
        assert isinstance(self.max_key, ArithUUID)
        checkpoints = split_space(self.min_key.int, self.max_key.int, count)
        return [ArithUUID(int=i) for i in checkpoints]

    return split_space(self.min_key, self.max_key, count)
156162

157163
def segment_by_checkpoints(self, checkpoints: List[DbKey]) -> List["TableSegment"]:
@@ -297,9 +303,12 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
297303
key_ranges = self._threaded_call("query_key_range", [table1, table2])
298304
mins, maxs = zip(*key_ranges)
299305

306+
key_type = table1._schema["id"]
307+
assert type(key_type) == type(table2._schema["id"])
308+
300309
# We add 1 because our ranges are exclusive of the end (like in Python)
301-
min_key = min(map(int, mins))
302-
max_key = max(map(int, maxs)) + 1
310+
min_key = min(map(key_type.python_type, mins))
311+
max_key = max(map(key_type.python_type, maxs)) + 1
303312

304313
table1 = table1.new(min_key=min_key, max_key=max_key)
305314
table2 = table2.new(min_key=min_key, max_key=max_key)

data_diff/sql.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from runtype import dataclass
88

9-
from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime
9+
from .databases.database_types import AbstractDatabase, DbPath, DbKey, DbTime, ArithUUID
1010

1111

1212
class Sql:
@@ -65,6 +65,8 @@ def compile(self, c: Compiler):
6565
return "b'%s'" % self.value.decode()
6666
elif isinstance(self.value, str):
6767
return "'%s'" % self.value
68+
elif isinstance(self.value, ArithUUID):
69+
return "'%s'" % self.value
6870
return str(self.value)
6971

7072

@@ -75,6 +77,7 @@ class Select(Sql):
7577
where: Sequence[SqlOrStr] = None
7678
order_by: Sequence[SqlOrStr] = None
7779
group_by: Sequence[SqlOrStr] = None
80+
limit: int = None
7881

7982
def compile(self, parent_c: Compiler):
8083
c = parent_c.replace(in_select=True)
@@ -93,6 +96,9 @@ def compile(self, parent_c: Compiler):
9396
if self.order_by:
9497
select += " ORDER BY " + ", ".join(map(c.compile, self.order_by))
9598

99+
if self.limit is not None:
100+
select += " " + c.database.offset_limit(0, self.limit)
101+
96102
if parent_c.in_select:
97103
select = "(%s)" % select
98104
return select

tests/test_diff_tables.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,11 @@ def setUp(self):
259259
"COMMIT",
260260
]
261261

262-
queries.append(f"INSERT INTO a VALUES ('unexpected', '<-- this bad value should not break us')")
262+
self.new_uuid = uuid.uuid1(32132131)
263+
queries.append(f"INSERT INTO a VALUES ('{self.new_uuid}', 'This one is different')")
264+
265+
# TODO test unexpected values?
266+
# queries.append(f"INSERT INTO a VALUES ('unexpected', '<-- this bad value should not break us')")
263267

264268
for query in queries:
265269
self.connection.query(query, None)
@@ -270,7 +274,7 @@ def setUp(self):
270274
def test_string_keys(self):
271275
differ = TableDiffer()
272276
diff = list(differ.diff_tables(self.a, self.b))
273-
breakpoint()
277+
self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))])
274278

275279

276280
class TestTableSegment(TestWithConnection):

0 commit comments

Comments
 (0)