Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 19b46bf

Browse files
committed
Tests: Refactor TestPerDatabase -> DiffTestCase; auto-create tables
1 parent 1bae429 commit 19b46bf

File tree

4 files changed: 84 additions and 114 deletions

tests/common.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -129,8 +129,11 @@ def str_to_checksum(str: str):
129129
return int(md5[half_pos:], 16)
130130

131131

132-
class TestPerDatabase(unittest.TestCase):
132+
class DiffTestCase(unittest.TestCase):
133+
"Sets up two tables for diffing (doesn't create them)"
133134
db_cls = None
135+
src_schema = None
136+
dst_schema = None
134137

135138
def setUp(self):
136139
assert self.db_cls, self.db_cls
@@ -144,12 +147,16 @@ def setUp(self):
144147
self.table_src_path = self.connection.parse_table_name(self.table_src_name)
145148
self.table_dst_path = self.connection.parse_table_name(self.table_dst_name)
146149

147-
self.table_src = ".".join(map(self.connection.dialect.quote, self.table_src_path))
148-
self.table_dst = ".".join(map(self.connection.dialect.quote, self.table_dst_path))
149-
150150
drop_table(self.connection, self.table_src_path)
151151
drop_table(self.connection, self.table_dst_path)
152152

153+
if self.src_schema:
154+
self.src_table = table(self.table_src_path, schema=self.src_schema)
155+
self.connection.query( self.src_table.create() )
156+
if self.dst_schema:
157+
self.dst_table = table(self.table_dst_path, schema=self.dst_schema)
158+
self.connection.query( self.dst_table.create() )
159+
153160
return super().setUp()
154161

155162
def tearDown(self):

tests/sqeleton/test_database.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import unittest
33

44
from ..common import str_to_checksum, TEST_MYSQL_CONN_STRING
5-
from ..common import str_to_checksum, test_each_database_in_list, TestPerDatabase, get_conn, random_table_suffix
5+
from ..common import str_to_checksum, test_each_database_in_list, DiffTestCase, get_conn, random_table_suffix
66
# from data_diff.sqeleton import databases as db
77
# from data_diff.sqeleton import connect
88

@@ -52,10 +52,11 @@ def test_bad_uris(self):
5252

5353

5454
@test_each_database
55-
class TestSchema(TestPerDatabase):
55+
class TestSchema(unittest.TestCase):
56+
5657
def test_table_list(self):
57-
name = self.table_src_name
58-
db = self.connection
58+
name = 'tbl_' + random_table_suffix()
59+
db = get_conn(self.db_cls)
5960
tbl = table(db.parse_table_name(name), schema={'id': int})
6061
q = db.dialect.list_tables(db.default_schema, name)
6162
assert not db.query(q)

tests/test_diff_tables.py

Lines changed: 58 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from data_diff import databases as db
1414
from data_diff.sqeleton.utils import ArithAlphanumeric, numberToAlphanum
1515

16-
from .common import str_to_checksum, test_each_database_in_list, TestPerDatabase, get_conn, random_table_suffix
16+
from .common import str_to_checksum, test_each_database_in_list, DiffTestCase, get_conn, random_table_suffix
1717

1818

1919
TEST_DATABASES = {
@@ -47,12 +47,13 @@ def test_split_space(self):
4747

4848

4949
@test_each_database
50-
class TestDates(TestPerDatabase):
50+
class TestDates(DiffTestCase):
51+
src_schema = {"id": int, "datetime": datetime, "text_comment": str}
52+
5153
def setUp(self):
5254
super().setUp()
5355

54-
src_table = table(self.table_src_path, schema={"id": int, "datetime": datetime, "text_comment": str})
55-
self.connection.query(src_table.create())
56+
src_table = self.src_table
5657
self.now = now = arrow.get()
5758

5859
rows = [
@@ -143,21 +144,13 @@ def test_offset(self):
143144

144145

145146
@test_each_database
146-
class TestDiffTables(TestPerDatabase):
147+
class TestDiffTables(DiffTestCase):
148+
src_schema = {"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}
149+
dst_schema = {"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime}
150+
147151
def setUp(self):
148152
super().setUp()
149153

150-
self.src_table = table(
151-
self.table_src_path,
152-
schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime},
153-
)
154-
self.dst_table = table(
155-
self.table_dst_path,
156-
schema={"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime},
157-
)
158-
159-
self.connection.query([self.src_table.create(), self.dst_table.create(), commit])
160-
161154
self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False)
162155
self.table2 = _table_segment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False)
163156

@@ -326,14 +319,11 @@ def test_diff_sorted_by_key(self):
326319

327320

328321
@test_each_database
329-
class TestDiffTables2(TestPerDatabase):
330-
def test_diff_column_names(self):
331-
332-
self.src_table = table(self.table_src_path, schema={"id": int, "rating": float, "timestamp": datetime})
333-
self.dst_table = table(self.table_dst_path, schema={"id2": int, "rating2": float, "timestamp2": datetime})
334-
335-
self.connection.query([self.src_table.create(), self.dst_table.create(), commit])
322+
class TestDiffTables2(DiffTestCase):
323+
src_schema = {"id": int, "rating": float, "timestamp": datetime}
324+
dst_schema = {"id2": int, "rating2": float, "timestamp2": datetime}
336325

326+
def test_diff_column_names(self):
337327
time = "2022-01-01 00:00:00"
338328
time2 = "2021-01-01 00:00:00"
339329

@@ -374,17 +364,18 @@ def test_diff_column_names(self):
374364

375365

376366
@test_each_database
377-
class TestUUIDs(TestPerDatabase):
367+
class TestUUIDs(DiffTestCase):
368+
src_schema = {"id": str, "text_comment": str}
369+
378370
def setUp(self):
379371
super().setUp()
380372

381-
self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
373+
src_table = self.src_table
382374

383375
self.new_uuid = uuid.uuid1(32132131)
384376

385377
self.connection.query(
386378
[
387-
src_table.create(),
388379
src_table.insert_rows((uuid.uuid1(i), str(i)) for i in range(100)),
389380
table(self.table_dst_path).create(src_table),
390381
src_table.insert_row(self.new_uuid, "This one is different"),
@@ -416,11 +407,13 @@ def test_where_sampling(self):
416407

417408

418409
@test_each_database_in_list(TEST_DATABASES - {db.MySQL})
419-
class TestAlphanumericKeys(TestPerDatabase):
410+
class TestAlphanumericKeys(DiffTestCase):
411+
src_schema = {"id": str, "text_comment": str}
412+
420413
def setUp(self):
421414
super().setUp()
422415

423-
self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
416+
src_table = self.src_table
424417
self.new_alphanum = "aBcDeFgHiz"
425418

426419
values = []
@@ -433,7 +426,6 @@ def setUp(self):
433426
values.append((str(a), str(i)))
434427

435428
queries = [
436-
src_table.create(),
437429
src_table.insert_rows(values),
438430
table(self.table_dst_path).create(src_table),
439431
src_table.insert_row(self.new_alphanum, "This one is different"),
@@ -461,11 +453,13 @@ def test_alphanum_keys(self):
461453

462454

463455
@test_each_database_in_list(TEST_DATABASES - {db.MySQL})
464-
class TestVaryingAlphanumericKeys(TestPerDatabase):
456+
class TestVaryingAlphanumericKeys(DiffTestCase):
457+
src_schema = {"id": str, "text_comment": str}
458+
465459
def setUp(self):
466460
super().setUp()
467461

468-
self.src_table = src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
462+
src_table = self.src_table
469463

470464
values = []
471465
for i in range(0, 10000, 1000):
@@ -479,7 +473,6 @@ def setUp(self):
479473
self.new_alphanum = "aBcDeFgHiJ"
480474

481475
queries = [
482-
src_table.create(),
483476
src_table.insert_rows(values),
484477
table(self.table_dst_path).create(src_table),
485478
src_table.insert_row(self.new_alphanum, "This one is different"),
@@ -517,7 +510,7 @@ def test_varying_alphanum_keys(self):
517510

518511

519512
@test_each_database
520-
class TestTableSegment(TestPerDatabase):
513+
class TestTableSegment(DiffTestCase):
521514
def setUp(self) -> None:
522515
super().setUp()
523516
self.table = _table_segment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False)
@@ -550,11 +543,13 @@ def test_case_awareness(self):
550543

551544

552545
@test_each_database
553-
class TestTableUUID(TestPerDatabase):
546+
class TestTableUUID(DiffTestCase):
547+
src_schema = {"id": str, "text_comment": str}
548+
554549
def setUp(self):
555550
super().setUp()
556551

557-
src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
552+
src_table = self.src_table
558553

559554
values = []
560555
for i in range(10):
@@ -565,7 +560,6 @@ def setUp(self):
565560

566561
self.connection.query(
567562
[
568-
src_table.create(),
569563
src_table.insert_rows(values),
570564
table(self.table_dst_path).create(src_table),
571565
src_table.insert_row(self.null_uuid, None),
@@ -583,16 +577,17 @@ def test_uuid_column_with_nulls(self):
583577

584578

585579
@test_each_database
586-
class TestTableNullRowChecksum(TestPerDatabase):
580+
class TestTableNullRowChecksum(DiffTestCase):
581+
src_schema = {"id": str, "text_comment": str}
582+
587583
def setUp(self):
588584
super().setUp()
589585

590-
src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
586+
src_table = self.src_table
591587

592588
self.null_uuid = uuid.uuid1(1)
593589
self.connection.query(
594590
[
595-
src_table.create(),
596591
src_table.insert_row(uuid.uuid1(1), "1"),
597592
table(self.table_dst_path).create(src_table),
598593
src_table.insert_row(self.null_uuid, None), # Add a row where a column has NULL value
@@ -630,13 +625,13 @@ def test_uuid_columns_with_nulls(self):
630625

631626

632627
@test_each_database
633-
class TestConcatMultipleColumnWithNulls(TestPerDatabase):
628+
class TestConcatMultipleColumnWithNulls(DiffTestCase):
629+
src_schema = {"id": str, "c1": str, "c2": str}
630+
dst_schema = {"id": str, "c1": str, "c2": str}
631+
634632
def setUp(self):
635633
super().setUp()
636634

637-
src_table = table(self.table_src_path, schema={"id": str, "c1": str, "c2": str})
638-
dst_table = table(self.table_dst_path, schema={"id": str, "c1": str, "c2": str})
639-
640635
src_values = []
641636
dst_values = []
642637

@@ -654,10 +649,8 @@ def setUp(self):
654649

655650
self.connection.query(
656651
[
657-
src_table.create(),
658-
dst_table.create(),
659-
src_table.insert_rows(src_values),
660-
dst_table.insert_rows(dst_values),
652+
self.src_table.insert_rows(src_values),
653+
self.dst_table.insert_rows(dst_values),
661654
commit,
662655
]
663656
)
@@ -698,13 +691,13 @@ def test_tables_are_different(self):
698691

699692

700693
@test_each_database
701-
class TestTableTableEmpty(TestPerDatabase):
694+
class TestTableTableEmpty(DiffTestCase):
695+
src_schema = {"id": str, "text_comment": str}
696+
dst_schema = {"id": str, "text_comment": str}
697+
702698
def setUp(self):
703699
super().setUp()
704700

705-
self.src_table = table(self.table_src_path, schema={"id": str, "text_comment": str})
706-
self.dst_table = table(self.table_dst_path, schema={"id": str, "text_comment": str})
707-
708701
self.null_uuid = uuid.uuid1(1)
709702

710703
self.diffs = [(uuid.uuid1(i), str(i)) for i in range(100)]
@@ -714,49 +707,34 @@ def setUp(self):
714707

715708
def test_right_table_empty(self):
716709
self.connection.query(
717-
[self.src_table.create(), self.dst_table.create(), self.src_table.insert_rows(self.diffs), commit]
710+
[self.src_table.insert_rows(self.diffs), commit]
718711
)
719712

720713
differ = HashDiffer(bisection_factor=2)
721714
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
722715

723716
def test_left_table_empty(self):
724717
self.connection.query(
725-
[self.src_table.create(), self.dst_table.create(), self.dst_table.insert_rows(self.diffs), commit]
718+
[self.dst_table.insert_rows(self.diffs), commit]
726719
)
727720

728721
differ = HashDiffer(bisection_factor=2)
729722
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
730723

731724

732-
class TestInfoTree(unittest.TestCase):
733-
def test_info_tree_root(self):
734-
try:
735-
self.db = get_conn(db.DuckDB)
736-
except KeyError: # ddb not defined
737-
self.db = get_conn(db.MySQL)
738-
739-
table_suffix = random_table_suffix()
740-
self.table_src_name = f"src{table_suffix}"
741-
self.table_dst_name = f"dst{table_suffix}"
742-
743-
schema = dict(
744-
id=int,
745-
)
746-
self.table1 = table(self.table_src_name, schema=schema)
747-
self.table2 = table(self.table_dst_name, schema=schema)
725+
class TestInfoTree(DiffTestCase):
726+
db_cls = db.MySQL
727+
src_schema = dst_schema = dict(id=int)
748728

749-
queries = [
750-
self.table1.create(),
751-
self.table2.create(),
752-
self.table1.insert_rows([i] for i in range(1000)),
753-
self.table2.insert_rows([i] for i in range(2000)),
754-
]
755-
for q in queries:
756-
self.db.query(q)
757-
758-
ts1 = TableSegment(self.db, self.table1.path, ("id",))
759-
ts2 = TableSegment(self.db, self.table2.path, ("id",))
729+
def test_info_tree_root(self):
730+
db = self.connection
731+
db.query([
732+
self.src_table.insert_rows([i] for i in range(1000)),
733+
self.dst_table.insert_rows([i] for i in range(2000)),
734+
])
735+
736+
ts1 = TableSegment(db, self.src_table.path, ("id",))
737+
ts2 = TableSegment(db, self.dst_table.path, ("id",))
760738

761739
for differ in (HashDiffer(bisection_threshold=64), JoinDiffer(True)):
762740
diff_res = differ.diff_tables(ts1, ts2)

0 commit comments

Comments
 (0)