Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit f37f84c

Browse files
committed
Fix for snowflake, mysql, presto, bigquery
1 parent 81325b8 commit f37f84c

File tree

9 files changed: 54 additions and 17 deletions (+54 / -17 lines changed).

data_diff/databases/base.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -231,6 +231,15 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
231231
def parse_table_name(self, name: str) -> DbPath:
232232
return parse_table_name(name)
233233

234+
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
235+
if offset:
236+
raise NotImplementedError("No support for OFFSET in query")
237+
238+
return f"LIMIT {limit}"
239+
240+
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
241+
return f"TRIM({value})"
242+
234243

235244
class ThreadedDatabase(Database):
236245
"""Access the database through singleton threads.
@@ -267,14 +276,6 @@ def create_connection(self):
267276
def close(self):
268277
self._queue.shutdown()
269278

270-
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
271-
if offset:
272-
raise NotImplementedError("No support for OFFSET in query")
273-
274-
return f"LIMIT {limit}"
275-
276-
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
277-
return f"TRIM({value})"
278279

279280

280281
CHECKSUM_HEXDIGITS = 15 # Must be 15 or lower

data_diff/databases/bigquery.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,8 @@ class BigQuery(Database):
2222
"BIGNUMERIC": Decimal,
2323
"FLOAT64": Float,
2424
"FLOAT32": Float,
25+
# Text
26+
"STRING": Text,
2527
}
2628
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation
2729

data_diff/databases/database_types.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import decimal
12
from abc import ABC, abstractmethod
23
from typing import Sequence, Optional, Tuple, Union, Dict, Any
34
from datetime import datetime
@@ -54,7 +55,12 @@ class Float(FractionalType):
5455

5556

5657
class Decimal(FractionalType):
57-
pass
58+
@property
59+
def python_type(self) -> type:
60+
if self.precision == 0:
61+
return int
62+
return decimal.Decimal
63+
5864

5965

6066
class StringType(ColType):

data_diff/databases/mysql.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ class MySQL(ThreadedDatabase):
2222
"int": Integer,
2323
# Text
2424
"varchar": Text,
25+
"char": Text,
2526
}
2627
ROUNDS_ON_PREC_LOSS = True
2728

data_diff/databases/oracle.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
7070

7171
def _parse_type(
7272
self,
73+
table_name: DbPath,
7374
col_name: str,
7475
type_repr: str,
7576
datetime_precision: int = None,

data_diff/databases/presto.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@ class Presto(Database):
2929
"integer": Integer,
3030
"real": Float,
3131
"double": Float,
32+
# Text
33+
"varchar": Text,
3234
}
3335
ROUNDS_ON_PREC_LOSS = True
3436

@@ -108,4 +110,17 @@ def _parse_type(
108110
prec, scale = map(int, m.groups())
109111
return n_cls(scale)
110112

113+
string_regexps = {
114+
r"varchar\((\d+)\)": Text,
115+
r"char\((\d+)\)": Text
116+
}
117+
for regexp, n_cls in string_regexps.items():
118+
m = re.match(regexp + "$", type_repr)
119+
if m:
120+
return n_cls()
121+
111122
return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)
123+
124+
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
125+
# Trim doesn't work on CHAR type
126+
return f"TRIM(CAST({value} AS VARCHAR))"

data_diff/databases/snowflake.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,8 @@ class Snowflake(Database):
2020
# Numbers
2121
"NUMBER": Decimal,
2222
"FLOAT": Float,
23+
# Text
24+
"TEXT": Text,
2325
}
2426
ROUNDS_ON_PREC_LOSS = False
2527

data_diff/diff_tables.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import time
66
from operator import attrgetter, methodcaller
77
from collections import defaultdict
8-
from typing import List, Tuple, Iterator, Optional
8+
from typing import List, Tuple, Iterator, Optional, Type
99
import logging
1010
from concurrent.futures import ThreadPoolExecutor
1111

@@ -18,6 +18,7 @@
1818
ArithUUID,
1919
NumericType,
2020
PrecisionType,
21+
StringType,
2122
UnknownColType,
2223
Schema,
2324
Schema_CaseInsensitive,
@@ -295,7 +296,8 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
295296
mins, maxs = zip(*key_ranges)
296297

297298
key_type = table1._schema["id"]
298-
assert type(key_type) == type(table2._schema["id"])
299+
key_type2 = table2._schema["id"]
300+
assert key_type.python_type is key_type2.python_type
299301

300302
# We add 1 because our ranges are exclusive of the end (like in Python)
301303
min_key = min(map(key_type.python_type, mins))
@@ -324,7 +326,7 @@ def _validate_and_adjust_columns(self, table1, table2):
324326
col2 = table2._schema[c]
325327
if isinstance(col1, PrecisionType):
326328
if not isinstance(col2, PrecisionType):
327-
raise TypeError(f"Incompatible types for column {c}: {col1} <-> {col2}")
329+
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
328330

329331
lowest = min(col1, col2, key=attrgetter("precision"))
330332

@@ -336,7 +338,7 @@ def _validate_and_adjust_columns(self, table1, table2):
336338

337339
elif isinstance(col1, NumericType):
338340
if not isinstance(col2, NumericType):
339-
raise TypeError(f"Incompatible types for column {c}: {col1} <-> {col2}")
341+
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
340342

341343
lowest = min(col1, col2, key=attrgetter("precision"))
342344

@@ -346,6 +348,10 @@ def _validate_and_adjust_columns(self, table1, table2):
346348
table1._schema[c] = col1.replace(precision=lowest.precision)
347349
table2._schema[c] = col2.replace(precision=lowest.precision)
348350

351+
elif isinstance(col1, StringType):
352+
if not isinstance(col2, StringType):
353+
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
354+
349355
for t in [table1, table2]:
350356
for c in t._relevant_columns:
351357
ctype = t._schema[c]

tests/test_database_types.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -245,7 +245,7 @@ def __iter__(self):
245245
"bignumeric",
246246
],
247247
"uuid": [
248-
"$uuid"
248+
"STRING",
249249
],
250250
},
251251
db.Snowflake: {
@@ -272,7 +272,8 @@ def __iter__(self):
272272
"numeric",
273273
],
274274
"uuid": [
275-
"$uuid"
275+
"varchar",
276+
"varchar(100)",
276277
],
277278

278279
},
@@ -331,7 +332,8 @@ def __iter__(self):
331332
"decimal(30,6)",
332333
],
333334
"uuid": [
334-
"$uuid"
335+
"varchar",
336+
"char(100)",
335337
],
336338

337339
},
@@ -435,7 +437,8 @@ def _drop_table_if_exists(conn, table):
435437
conn.query(f"DROP TABLE {table}", None)
436438
else:
437439
conn.query(f"DROP TABLE IF EXISTS {table}", None)
438-
conn.query("COMMIT", None)
440+
if not isinstance(conn, db.BigQuery):
441+
conn.query("COMMIT", None)
439442

440443

441444
class TestDiffCrossDatabaseTables(unittest.TestCase):

0 commit comments

Comments
 (0)