Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 0fe259d

Browse files
authored
Merge pull request #106 from datafold/presto-and-fakers
2 parents 8661a14 + 5a173fd commit 0fe259d

File tree

4 files changed

+201
-43
lines changed

4 files changed

+201
-43
lines changed

data_diff/databases/presto.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@ def to_string(self, s: str):
5050

5151
def _query(self, sql_code: str) -> list:
    """Run ``sql_code`` through the standard SQL cursor interface.

    Returns every row for a SELECT statement, the first result row for an
    INSERT / CREATE / TRUNCATE / DROP statement, and None for any other
    statement. Leading whitespace is ignored when deciding which kind of
    statement this is, so an indented or newline-prefixed query is fetched
    the same way as a bare one.
    """
    cursor = self._conn.cursor()
    cursor.execute(sql_code)

    # Classify on the left-stripped text; the statement itself was already
    # sent to the cursor unchanged.
    statement = sql_code.lstrip()
    if statement.lower().startswith("select"):
        return cursor.fetchall()
    # Required for the query to actually run 🤯
    # NOTE(review): presumably the driver only executes once a result is
    # fetched - confirm against the Presto client documentation.
    if re.match(r"(insert|create|truncate|drop)", statement, re.IGNORECASE):
        return cursor.fetchone()
    return None
5460

5561
def close(self):
5662
self._conn.close()
@@ -88,7 +94,7 @@ def _parse_type(
8894
datetime_precision = int(m.group(1))
8995
return t_cls(
9096
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
91-
rounds=False,
97+
rounds=self.ROUNDS_ON_PREC_LOSS,
9298
)
9399

94100
number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}

dev/presto-conf/standalone/catalog/postgresql.properties

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ connector.name=postgresql
22
connection-url=jdbc:postgresql://postgres:5432/postgres
33
connection-user=postgres
44
connection-password=Password1
5+
allow-drop-table=true

tests/common.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import hashlib
2+
import os
23

34
from data_diff import databases as db
45
import logging
56

6-
logging.basicConfig(level=logging.INFO)
7-
87
TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql"
98
TEST_POSTGRESQL_CONN_STRING: str = None
109
TEST_SNOWFLAKE_CONN_STRING: str = None
@@ -13,6 +12,16 @@
1312
TEST_ORACLE_CONN_STRING: str = None
1413
TEST_PRESTO_CONN_STRING: str = None
1514

15+
DEFAULT_N_SAMPLES = 50
16+
N_SAMPLES = int(os.environ.get("N_SAMPLES", DEFAULT_N_SAMPLES))
17+
18+
level = logging.ERROR
19+
if os.environ.get("LOG_LEVEL", False):
20+
level = getattr(logging, os.environ["LOG_LEVEL"].upper())
21+
22+
logging.basicConfig(level=level)
23+
logging.getLogger("diff_tables").setLevel(level)
24+
logging.getLogger("database").setLevel(level)
1625

1726
try:
1827
from .local_settings import *

tests/test_database_types.py

Lines changed: 181 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,120 @@
11
from contextlib import suppress
22
import unittest
33
import time
4-
import logging
4+
import re
5+
import math
6+
import datetime
57
from decimal import Decimal
6-
78
from parameterized import parameterized
89

910
from data_diff import databases as db
1011
from data_diff.diff_tables import TableDiffer, TableSegment
11-
from .common import CONN_STRINGS
12-
12+
from .common import CONN_STRINGS, N_SAMPLES
1313

14-
logging.getLogger("diff_tables").setLevel(logging.ERROR)
15-
logging.getLogger("database").setLevel(logging.WARN)
1614

17-
CONNS = {k: db.connect_to_uri(v) for k, v in CONN_STRINGS.items()}
15+
CONNS = {k: db.connect_to_uri(v, 1) for k, v in CONN_STRINGS.items()}
1816

1917
CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
2018

21-
TYPE_SAMPLES = {
22-
"int": [127, -3, -9, 37, 15, 127],
23-
"datetime_no_timezone": [
24-
"2020-01-01 15:10:10",
25-
"2020-02-01 9:9:9",
26-
"2022-03-01 15:10:01.139",
27-
"2022-04-01 15:10:02.020409",
28-
"2022-05-01 15:10:03.003030",
29-
"2022-06-01 15:10:05.009900",
30-
],
31-
"float": [
19+
20+
class PaginatedTable:
    """Iterate over the ``(id, col)`` rows of a table in ascending id order.

    Rows are fetched in batches of ``RECORDS_PER_BATCH``: each batch selects
    the rows whose id is greater than the last id already returned. Paging
    starts after id 0, so rows with ``id <= 0`` are never returned.

    ``iter(table)`` hands back a fresh, independent cursor each time, so the
    same table can be walked more than once. A directly constructed instance
    is itself a valid iterator positioned at the start.
    """

    # We can't query all the rows at once for large tables. It'll occupy too
    # much memory.
    RECORDS_PER_BATCH = 1000000

    def __init__(self, table, conn):
        self.table = table
        self.conn = conn
        # Cursor state. Initialised here (not only in __iter__) so that
        # calling next() on a freshly built instance works instead of
        # failing with AttributeError.
        self.last_id = 0  # highest id returned by the previous batch
        self.values = []  # rows of the current batch
        self.value_index = 0  # position of the next row inside `values`

    def __iter__(self):
        # A new instance starts with pristine cursor state.
        return PaginatedTable(self.table, self.conn)

    def __next__(self):
        if self.value_index == len(self.values):  # end of current batch
            if isinstance(self.conn, db.Oracle):
                # Oracle has no LIMIT clause; use the OFFSET/FETCH form.
                query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY"
            else:
                query = f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH}"

            self.values = self.conn.query(query, list)
            if not self.values:  # we must be done!
                raise StopIteration
            self.last_id = self.values[-1][0]
            self.value_index = 0

        row = self.values[self.value_index]
        self.value_index += 1
        return row
51+
52+
53+
class DateTimeFaker:
    """Re-iterable source of exactly ``max`` sample datetimes.

    The first values are the hand-picked ``MANUAL_FAKES``; once those are
    used up, each further value is the previous generated one plus
    3 seconds and 571 microseconds, counting from 2000-01-01 00:00:00.
    The number of values produced always equals ``len(faker)``, including
    when ``max`` is smaller than the number of manual fakes.
    """

    MANUAL_FAKES = [
        datetime.datetime.fromisoformat("2020-01-01 15:10:10"),
        datetime.datetime.fromisoformat("2020-02-01 09:09:09"),
        datetime.datetime.fromisoformat("2022-03-01 15:10:01.139"),
        datetime.datetime.fromisoformat("2022-04-01 15:10:02.020409"),
        datetime.datetime.fromisoformat("2022-05-01 15:10:03.003030"),
        datetime.datetime.fromisoformat("2022-06-01 15:10:05.009900"),
    ]

    def __init__(self, max):
        self.max = max
        # Iteration state: base for the generated values, and how many
        # values have been handed out so far.
        self.prev = datetime.datetime(2000, 1, 1, 0, 0, 0, 0)
        self.i = 0

    def __iter__(self):
        # Fresh instance so every pass starts from the first sample.
        return DateTimeFaker(self.max)

    def __len__(self):
        return self.max

    def __next__(self) -> datetime.datetime:
        # Check the limit first so that a `max` below len(MANUAL_FAKES)
        # is honoured and iteration agrees with __len__.
        if self.i >= self.max:
            raise StopIteration
        if self.i < len(self.MANUAL_FAKES):
            fake = self.MANUAL_FAKES[self.i]
        else:
            self.prev = self.prev + datetime.timedelta(seconds=3, microseconds=571)
            fake = self.prev
        self.i += 1
        return fake
86+
87+
88+
class IntFaker:
    """Re-iterable source of exactly ``max`` sample integers.

    The first values are the hand-picked ``MANUAL_FAKES``; once those are
    used up, the values count upwards one at a time starting from -127.
    The number of values produced always equals ``len(faker)``, including
    when ``max`` is smaller than the number of manual fakes.
    """

    MANUAL_FAKES = [127, -3, -9, 37, 15, 127]

    def __init__(self, max):
        self.max = max
        # Iteration state: last generated value, and how many values have
        # been handed out so far.
        self.prev = -128
        self.i = 0

    def __iter__(self):
        # Fresh instance so every pass starts from the first sample.
        return IntFaker(self.max)

    def __len__(self):
        return self.max

    def __next__(self) -> int:
        # Check the limit first so that a `max` below len(MANUAL_FAKES)
        # is honoured and iteration agrees with __len__.
        if self.i >= self.max:
            raise StopIteration
        if self.i < len(self.MANUAL_FAKES):
            fake = self.MANUAL_FAKES[self.i]
        else:
            self.prev += 1
            fake = self.prev
        self.i += 1
        return fake
114+
115+
116+
class FloatFaker:
117+
MANUAL_FAKES = [
32118
0.0,
33119
0.1,
34120
0.00188,
@@ -45,15 +131,45 @@
45131
1 / 1094893892389,
46132
1 / 10948938923893289,
47133
3.141592653589793,
48-
],
134+
]
135+
136+
def __init__(self, max):
137+
self.max = max
138+
139+
def __iter__(self):
140+
iter = FloatFaker(self.max)
141+
iter.prev = -10.0001
142+
iter.i = 0
143+
return iter
144+
145+
def __len__(self):
146+
return self.max
147+
148+
def __next__(self) -> float:
149+
if self.i < len(self.MANUAL_FAKES):
150+
fake = self.MANUAL_FAKES[self.i]
151+
self.i += 1
152+
return fake
153+
elif self.i < self.max:
154+
self.prev += 0.00571
155+
self.i += 1
156+
return self.prev
157+
else:
158+
raise StopIteration
159+
160+
161+
TYPE_SAMPLES = {
162+
"int": IntFaker(N_SAMPLES),
163+
"datetime_no_timezone": DateTimeFaker(N_SAMPLES),
164+
"float": FloatFaker(N_SAMPLES),
49165
}
50166

51167
DATABASE_TYPES = {
52168
db.PostgreSQL: {
53169
# https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT
54170
"int": [
55171
# "smallint", # 2 bytes
56-
# "int", # 4 bytes
172+
"int", # 4 bytes
57173
# "bigint", # 8 bytes
58174
],
59175
# https://www.postgresql.org/docs/current/datatype-datetime.html
@@ -76,7 +192,7 @@
76192
# "tinyint", # 1 byte
77193
# "smallint", # 2 bytes
78194
# "mediumint", # 3 bytes
79-
# "int", # 4 bytes
195+
"int", # 4 bytes
80196
# "bigint", # 8 bytes
81197
],
82198
# https://dev.mysql.com/doc/refman/8.0/en/datetime.html
@@ -96,6 +212,7 @@
96212
],
97213
},
98214
db.BigQuery: {
215+
"int": ["int"],
99216
"datetime_no_timezone": [
100217
"timestamp",
101218
# "datetime",
@@ -110,7 +227,7 @@
110227
# https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint
111228
"int": [
112229
# all 38 digits with 0 precision, don't need to test all
113-
# "int",
230+
"int",
114231
# "integer",
115232
# "bigint",
116233
# "smallint",
@@ -132,7 +249,7 @@
132249
},
133250
db.Redshift: {
134251
"int": [
135-
# "int",
252+
"int",
136253
],
137254
"datetime_no_timezone": [
138255
"TIMESTAMP",
@@ -146,7 +263,7 @@
146263
},
147264
db.Oracle: {
148265
"int": [
149-
# "int",
266+
"int",
150267
],
151268
"datetime_no_timezone": [
152269
"timestamp with local time zone",
@@ -163,15 +280,12 @@
163280
# "tinyint", # 1 byte
164281
# "smallint", # 2 bytes
165282
# "mediumint", # 3 bytes
166-
# "int", # 4 bytes
283+
"int", # 4 bytes
167284
# "bigint", # 8 bytes
168285
],
169286
"datetime_no_timezone": [
170-
"timestamp(6)",
171-
"timestamp(3)",
172-
"timestamp(0)",
173287
"timestamp",
174-
"datetime(6)",
288+
"timestamp with time zone",
175289
],
176290
"float": [
177291
"real",
@@ -203,18 +317,43 @@
203317
)
204318
)
205319

320+
321+
def sanitize(name):
    """Turn a database or column-type name into a short, safe identifier.

    The name is lower-cased, parentheses are dropped, a few long phrases
    are abbreviated, and the result is passed through
    ``parameterized.to_safe_name``.
    """
    cleaned = re.sub(r"[\(\)]", "", name.lower())  # timestamp(9) -> timestamp9

    # Try to shorten long fields, due to length limitations in some DBs.
    # Applied in this order, one after the other.
    abbreviations = (
        ("without time zone", "n_tz"),
        ("with time zone", "y_tz"),
        ("with local time zone", "y_tz"),
        ("timestamp", "ts"),
    )
    for phrase, short in abbreviations:
        cleaned = cleaned.replace(phrase, short)

    return parameterized.to_safe_name(cleaned)
330+
331+
332+
def number_to_human(n):
    """Format a number as a rounded integer with a k / m / b suffix.

    The value is scaled by the largest power of 1000 (up to 10**9) that
    fits and printed with no decimals, e.g. 50 -> "50", 1000 -> "1k",
    1000000 -> "1m". Values of 10**12 and above keep the "b" suffix.
    A value that rounds up to 1000 of its unit moves to the next suffix,
    so 999999 gives "1m" rather than "1000k".
    """
    millnames = ["", "k", "m", "b"]
    n = float(n)
    magnitude = 0 if n == 0 else math.log10(abs(n)) / 3
    millidx = max(0, min(len(millnames) - 1, int(math.floor(magnitude))))
    scaled = n / 10 ** (3 * millidx)

    # The suffix was chosen before rounding; if rounding pushes the scaled
    # value to 1000, step up one suffix (when a larger one exists).
    if abs(round(scaled)) >= 1000 and millidx < len(millnames) - 1:
        millidx += 1
        scaled = n / 10 ** (3 * millidx)

    return "{:.0f}{}".format(scaled, millnames[millidx])
341+
342+
206343
# Pass --verbose to test run to get a nice output.
207344
def expand_params(testcase_func, param_num, param):
    """Build the generated test name for one parameter set.

    ``param.args`` must hold exactly five items: source database class,
    target database class, source column type, target column type and the
    type category (unused here). The name combines the test function name,
    the sanitized database and type names, and the sample count in short
    human-readable form.
    """
    source_db, target_db, source_type, target_type, _type_category = param.args
    name_parts = (
        testcase_func.__name__,
        sanitize(source_db.__name__),
        sanitize(source_type),
        sanitize(target_db.__name__),
        sanitize(target_type),
        number_to_human(N_SAMPLES),
    )
    return "%s_%s_%s_to_%s_%s_%s" % name_parts
218357

219358

220359
def _insert_to_table(conn, table, values):
@@ -232,8 +371,10 @@ def _insert_to_table(conn, table, values):
232371
else:
233372
insertion_query += " VALUES "
234373
for j, sample in values:
235-
if isinstance(sample, (float, Decimal)):
374+
if isinstance(sample, (float, Decimal, int)):
236375
value = str(sample)
376+
elif isinstance(sample, datetime.datetime) and isinstance(conn, db.Presto):
377+
value = f"timestamp '{sample}'"
237378
else:
238379
value = f"'{sample}'"
239380
insertion_query += f"({j}, {value}),"
@@ -253,6 +394,7 @@ def _drop_table_if_exists(conn, table):
253394
conn.query(f"DROP TABLE {table}", None)
254395
else:
255396
conn.query(f"DROP TABLE IF EXISTS {table}", None)
397+
conn.query("COMMIT", None)
256398

257399

258400
class TestDiffCrossDatabaseTables(unittest.TestCase):
@@ -266,9 +408,9 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
266408
self.connections = [self.src_conn, self.dst_conn]
267409
sample_values = TYPE_SAMPLES[type_category]
268410

269-
# Limit in MySQL is 64
270-
src_table_name = f"src_{self._testMethodName[:60]}"
271-
dst_table_name = f"dst_{self._testMethodName[:60]}"
411+
# Limit in MySQL is 64, Presto seems to be 63
412+
src_table_name = f"src_{self._testMethodName[11:]}"
413+
dst_table_name = f"dst_{self._testMethodName[11:]}"
272414

273415
src_table_path = src_conn.parse_table_name(src_table_name)
274416
dst_table_path = dst_conn.parse_table_name(dst_table_name)
@@ -279,7 +421,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
279421
src_conn.query(f"CREATE TABLE {src_table}(id int, col {source_type})", None)
280422
_insert_to_table(src_conn, src_table, enumerate(sample_values, 1))
281423

282-
values_in_source = src_conn.query(f"SELECT id, col FROM {src_table}", list)
424+
values_in_source = PaginatedTable(src_table, src_conn)
283425

284426
_drop_table_if_exists(dst_conn, dst_table)
285427
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None)

0 commit comments

Comments
 (0)