Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 860e9b6

Browse files
authored
Merge pull request #114 from datafold/fix_presto_tests
Fixes for presto tests, all passing now; + small cleanup
2 parents 0fe259d + 17252ac commit 860e9b6

File tree

4 files changed

+20
-19
lines changed

4 files changed

+20
-19
lines changed

data_diff/databases/database_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ class NumericType(ColType):
4040
# 'precision' signifies how many fractional digits (after the dot) we want to compare
4141
precision: int
4242

43+
4344
class FractionalType(NumericType):
4445
pass
4546

47+
4648
class Float(FractionalType):
4749
pass
4850

data_diff/diff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _update_column(self):
9696

9797
def _quote_column(self, c):
9898
if self._schema:
99-
c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive.
99+
c = self._schema.get_key(c) # Get the actual name. Might be case-insensitive.
100100
return self.database.quote(c)
101101

102102
def with_schema(self) -> "TableSegment":

tests/test_api.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,8 @@
88

99

1010
class TestApi(unittest.TestCase):
11-
@classmethod
12-
def setUpClass(cls):
13-
# Avoid leaking connections that require waiting for the GC, which can
14-
# cause deadlocks for table-level modifications.
15-
cls.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
16-
1711
def setUp(self) -> None:
18-
# self.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
12+
self.preql = preql.Preql(TEST_MYSQL_CONN_STRING)
1913
self.preql(
2014
r"""
2115
table test_api {

tests/test_database_types.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
import re
55
import math
6-
import datetime
6+
from datetime import datetime, timedelta
77
from decimal import Decimal
88
from parameterized import parameterized
99

@@ -52,33 +52,33 @@ def __next__(self) -> str:
5252

5353
class DateTimeFaker:
5454
MANUAL_FAKES = [
55-
datetime.datetime.fromisoformat("2020-01-01 15:10:10"),
56-
datetime.datetime.fromisoformat("2020-02-01 09:09:09"),
57-
datetime.datetime.fromisoformat("2022-03-01 15:10:01.139"),
58-
datetime.datetime.fromisoformat("2022-04-01 15:10:02.020409"),
59-
datetime.datetime.fromisoformat("2022-05-01 15:10:03.003030"),
60-
datetime.datetime.fromisoformat("2022-06-01 15:10:05.009900"),
55+
datetime.fromisoformat("2020-01-01 15:10:10"),
56+
datetime.fromisoformat("2020-02-01 09:09:09"),
57+
datetime.fromisoformat("2022-03-01 15:10:01.139"),
58+
datetime.fromisoformat("2022-04-01 15:10:02.020409"),
59+
datetime.fromisoformat("2022-05-01 15:10:03.003030"),
60+
datetime.fromisoformat("2022-06-01 15:10:05.009900"),
6161
]
6262

6363
def __init__(self, max):
6464
self.max = max
6565

6666
def __iter__(self):
6767
iter = DateTimeFaker(self.max)
68-
iter.prev = datetime.datetime(2000, 1, 1, 0, 0, 0, 0)
68+
iter.prev = datetime(2000, 1, 1, 0, 0, 0, 0)
6969
iter.i = 0
7070
return iter
7171

7272
def __len__(self):
7373
return self.max
7474

75-
def __next__(self) -> datetime.datetime:
75+
def __next__(self) -> datetime:
7676
if self.i < len(self.MANUAL_FAKES):
7777
fake = self.MANUAL_FAKES[self.i]
7878
self.i += 1
7979
return fake
8080
elif self.i < self.max:
81-
self.prev = self.prev + datetime.timedelta(seconds=3, microseconds=571)
81+
self.prev = self.prev + timedelta(seconds=3, microseconds=571)
8282
self.i += 1
8383
return self.prev
8484
else:
@@ -373,7 +373,7 @@ def _insert_to_table(conn, table, values):
373373
for j, sample in values:
374374
if isinstance(sample, (float, Decimal, int)):
375375
value = str(sample)
376-
elif isinstance(sample, datetime.datetime) and isinstance(conn, db.Presto):
376+
elif isinstance(sample, datetime) and isinstance(conn, db.Presto):
377377
value = f"timestamp '{sample}'"
378378
else:
379379
value = f"'{sample}'"
@@ -422,6 +422,11 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
422422
_insert_to_table(src_conn, src_table, enumerate(sample_values, 1))
423423

424424
values_in_source = PaginatedTable(src_table, src_conn)
425+
if source_db is db.Presto:
426+
if source_type.startswith("decimal"):
427+
values_in_source = [(a, Decimal(b)) for a, b in values_in_source]
428+
elif source_type.startswith("timestamp"):
429+
values_in_source = [(a, datetime.fromisoformat(b.rstrip(" UTC"))) for a, b in values_in_source]
425430

426431
_drop_table_if_exists(dst_conn, dst_table)
427432
dst_conn.query(f"CREATE TABLE {dst_table}(id int, col {target_type})", None)

0 commit comments

Comments
 (0)