1313from data_diff import databases as db
1414from data_diff .sqeleton .utils import ArithAlphanumeric , numberToAlphanum
1515
16- from .common import str_to_checksum , test_each_database_in_list , TestPerDatabase , get_conn , random_table_suffix
16+ from .common import str_to_checksum , test_each_database_in_list , DiffTestCase , get_conn , random_table_suffix
1717
1818
1919TEST_DATABASES = {
@@ -47,12 +47,13 @@ def test_split_space(self):
4747
4848
4949@test_each_database
50- class TestDates (TestPerDatabase ):
50+ class TestDates (DiffTestCase ):
51+ src_schema = {"id" : int , "datetime" : datetime , "text_comment" : str }
52+
5153 def setUp (self ):
5254 super ().setUp ()
5355
54- src_table = table (self .table_src_path , schema = {"id" : int , "datetime" : datetime , "text_comment" : str })
55- self .connection .query (src_table .create ())
56+ src_table = self .src_table
5657 self .now = now = arrow .get ()
5758
5859 rows = [
@@ -143,21 +144,13 @@ def test_offset(self):
143144
144145
145146@test_each_database
146- class TestDiffTables (TestPerDatabase ):
147+ class TestDiffTables (DiffTestCase ):
148+ src_schema = {"id" : int , "userid" : int , "movieid" : int , "rating" : float , "timestamp" : datetime }
149+ dst_schema = {"id" : int , "userid" : int , "movieid" : int , "rating" : float , "timestamp" : datetime }
150+
147151 def setUp (self ):
148152 super ().setUp ()
149153
150- self .src_table = table (
151- self .table_src_path ,
152- schema = {"id" : int , "userid" : int , "movieid" : int , "rating" : float , "timestamp" : datetime },
153- )
154- self .dst_table = table (
155- self .table_dst_path ,
156- schema = {"id" : int , "userid" : int , "movieid" : int , "rating" : float , "timestamp" : datetime },
157- )
158-
159- self .connection .query ([self .src_table .create (), self .dst_table .create (), commit ])
160-
161154 self .table = _table_segment (self .connection , self .table_src_path , "id" , "timestamp" , case_sensitive = False )
162155 self .table2 = _table_segment (self .connection , self .table_dst_path , "id" , "timestamp" , case_sensitive = False )
163156
@@ -326,14 +319,11 @@ def test_diff_sorted_by_key(self):
326319
327320
328321@test_each_database
329- class TestDiffTables2 (TestPerDatabase ):
330- def test_diff_column_names (self ):
331-
332- self .src_table = table (self .table_src_path , schema = {"id" : int , "rating" : float , "timestamp" : datetime })
333- self .dst_table = table (self .table_dst_path , schema = {"id2" : int , "rating2" : float , "timestamp2" : datetime })
334-
335- self .connection .query ([self .src_table .create (), self .dst_table .create (), commit ])
322+ class TestDiffTables2 (DiffTestCase ):
323+ src_schema = {"id" : int , "rating" : float , "timestamp" : datetime }
324+ dst_schema = {"id2" : int , "rating2" : float , "timestamp2" : datetime }
336325
326+ def test_diff_column_names (self ):
337327 time = "2022-01-01 00:00:00"
338328 time2 = "2021-01-01 00:00:00"
339329
@@ -374,17 +364,18 @@ def test_diff_column_names(self):
374364
375365
376366@test_each_database
377- class TestUUIDs (TestPerDatabase ):
367+ class TestUUIDs (DiffTestCase ):
368+ src_schema = {"id" : str , "text_comment" : str }
369+
378370 def setUp (self ):
379371 super ().setUp ()
380372
381- self . src_table = src_table = table ( self .table_src_path , schema = { "id" : str , "text_comment" : str })
373+ src_table = self .src_table
382374
383375 self .new_uuid = uuid .uuid1 (32132131 )
384376
385377 self .connection .query (
386378 [
387- src_table .create (),
388379 src_table .insert_rows ((uuid .uuid1 (i ), str (i )) for i in range (100 )),
389380 table (self .table_dst_path ).create (src_table ),
390381 src_table .insert_row (self .new_uuid , "This one is different" ),
@@ -416,11 +407,13 @@ def test_where_sampling(self):
416407
417408
418409@test_each_database_in_list (TEST_DATABASES - {db .MySQL })
419- class TestAlphanumericKeys (TestPerDatabase ):
410+ class TestAlphanumericKeys (DiffTestCase ):
411+ src_schema = {"id" : str , "text_comment" : str }
412+
420413 def setUp (self ):
421414 super ().setUp ()
422415
423- self . src_table = src_table = table ( self .table_src_path , schema = { "id" : str , "text_comment" : str })
416+ src_table = self .src_table
424417 self .new_alphanum = "aBcDeFgHiz"
425418
426419 values = []
@@ -433,7 +426,6 @@ def setUp(self):
433426 values .append ((str (a ), str (i )))
434427
435428 queries = [
436- src_table .create (),
437429 src_table .insert_rows (values ),
438430 table (self .table_dst_path ).create (src_table ),
439431 src_table .insert_row (self .new_alphanum , "This one is different" ),
@@ -461,11 +453,13 @@ def test_alphanum_keys(self):
461453
462454
463455@test_each_database_in_list (TEST_DATABASES - {db .MySQL })
464- class TestVaryingAlphanumericKeys (TestPerDatabase ):
456+ class TestVaryingAlphanumericKeys (DiffTestCase ):
457+ src_schema = {"id" : str , "text_comment" : str }
458+
465459 def setUp (self ):
466460 super ().setUp ()
467461
468- self . src_table = src_table = table ( self .table_src_path , schema = { "id" : str , "text_comment" : str })
462+ src_table = self .src_table
469463
470464 values = []
471465 for i in range (0 , 10000 , 1000 ):
@@ -479,7 +473,6 @@ def setUp(self):
479473 self .new_alphanum = "aBcDeFgHiJ"
480474
481475 queries = [
482- src_table .create (),
483476 src_table .insert_rows (values ),
484477 table (self .table_dst_path ).create (src_table ),
485478 src_table .insert_row (self .new_alphanum , "This one is different" ),
@@ -517,7 +510,7 @@ def test_varying_alphanum_keys(self):
517510
518511
519512@test_each_database
520- class TestTableSegment (TestPerDatabase ):
513+ class TestTableSegment (DiffTestCase ):
521514 def setUp (self ) -> None :
522515 super ().setUp ()
523516 self .table = _table_segment (self .connection , self .table_src_path , "id" , "timestamp" , case_sensitive = False )
@@ -550,11 +543,13 @@ def test_case_awareness(self):
550543
551544
552545@test_each_database
553- class TestTableUUID (TestPerDatabase ):
546+ class TestTableUUID (DiffTestCase ):
547+ src_schema = {"id" : str , "text_comment" : str }
548+
554549 def setUp (self ):
555550 super ().setUp ()
556551
557- src_table = table ( self .table_src_path , schema = { "id" : str , "text_comment" : str })
552+ src_table = self .src_table
558553
559554 values = []
560555 for i in range (10 ):
@@ -565,7 +560,6 @@ def setUp(self):
565560
566561 self .connection .query (
567562 [
568- src_table .create (),
569563 src_table .insert_rows (values ),
570564 table (self .table_dst_path ).create (src_table ),
571565 src_table .insert_row (self .null_uuid , None ),
@@ -583,16 +577,17 @@ def test_uuid_column_with_nulls(self):
583577
584578
585579@test_each_database
586- class TestTableNullRowChecksum (TestPerDatabase ):
580+ class TestTableNullRowChecksum (DiffTestCase ):
581+ src_schema = {"id" : str , "text_comment" : str }
582+
587583 def setUp (self ):
588584 super ().setUp ()
589585
590- src_table = table ( self .table_src_path , schema = { "id" : str , "text_comment" : str })
586+ src_table = self .src_table
591587
592588 self .null_uuid = uuid .uuid1 (1 )
593589 self .connection .query (
594590 [
595- src_table .create (),
596591 src_table .insert_row (uuid .uuid1 (1 ), "1" ),
597592 table (self .table_dst_path ).create (src_table ),
598593 src_table .insert_row (self .null_uuid , None ), # Add a row where a column has NULL value
@@ -630,13 +625,13 @@ def test_uuid_columns_with_nulls(self):
630625
631626
632627@test_each_database
633- class TestConcatMultipleColumnWithNulls (TestPerDatabase ):
628+ class TestConcatMultipleColumnWithNulls (DiffTestCase ):
629+ src_schema = {"id" : str , "c1" : str , "c2" : str }
630+ dst_schema = {"id" : str , "c1" : str , "c2" : str }
631+
634632 def setUp (self ):
635633 super ().setUp ()
636634
637- src_table = table (self .table_src_path , schema = {"id" : str , "c1" : str , "c2" : str })
638- dst_table = table (self .table_dst_path , schema = {"id" : str , "c1" : str , "c2" : str })
639-
640635 src_values = []
641636 dst_values = []
642637
@@ -654,10 +649,8 @@ def setUp(self):
654649
655650 self .connection .query (
656651 [
657- src_table .create (),
658- dst_table .create (),
659- src_table .insert_rows (src_values ),
660- dst_table .insert_rows (dst_values ),
652+ self .src_table .insert_rows (src_values ),
653+ self .dst_table .insert_rows (dst_values ),
661654 commit ,
662655 ]
663656 )
@@ -698,13 +691,13 @@ def test_tables_are_different(self):
698691
699692
700693@test_each_database
701- class TestTableTableEmpty (TestPerDatabase ):
694+ class TestTableTableEmpty (DiffTestCase ):
695+ src_schema = {"id" : str , "text_comment" : str }
696+ dst_schema = {"id" : str , "text_comment" : str }
697+
702698 def setUp (self ):
703699 super ().setUp ()
704700
705- self .src_table = table (self .table_src_path , schema = {"id" : str , "text_comment" : str })
706- self .dst_table = table (self .table_dst_path , schema = {"id" : str , "text_comment" : str })
707-
708701 self .null_uuid = uuid .uuid1 (1 )
709702
710703 self .diffs = [(uuid .uuid1 (i ), str (i )) for i in range (100 )]
@@ -714,49 +707,34 @@ def setUp(self):
714707
715708 def test_right_table_empty (self ):
716709 self .connection .query (
717- [self .src_table .create (), self . dst_table . create (), self . src_table . insert_rows (self .diffs ), commit ]
710+ [self .src_table .insert_rows (self .diffs ), commit ]
718711 )
719712
720713 differ = HashDiffer (bisection_factor = 2 )
721714 self .assertRaises (ValueError , list , differ .diff_tables (self .a , self .b ))
722715
723716 def test_left_table_empty (self ):
724717 self .connection .query (
725- [self .src_table . create (), self . dst_table . create (), self . dst_table .insert_rows (self .diffs ), commit ]
718+ [self .dst_table .insert_rows (self .diffs ), commit ]
726719 )
727720
728721 differ = HashDiffer (bisection_factor = 2 )
729722 self .assertRaises (ValueError , list , differ .diff_tables (self .a , self .b ))
730723
731724
732- class TestInfoTree (unittest .TestCase ):
733- def test_info_tree_root (self ):
734- try :
735- self .db = get_conn (db .DuckDB )
736- except KeyError : # ddb not defined
737- self .db = get_conn (db .MySQL )
738-
739- table_suffix = random_table_suffix ()
740- self .table_src_name = f"src{ table_suffix } "
741- self .table_dst_name = f"dst{ table_suffix } "
742-
743- schema = dict (
744- id = int ,
745- )
746- self .table1 = table (self .table_src_name , schema = schema )
747- self .table2 = table (self .table_dst_name , schema = schema )
725+ class TestInfoTree (DiffTestCase ):
726+ db_cls = db .MySQL
727+ src_schema = dst_schema = dict (id = int )
748728
749- queries = [
750- self .table1 .create (),
751- self .table2 .create (),
752- self .table1 .insert_rows ([i ] for i in range (1000 )),
753- self .table2 .insert_rows ([i ] for i in range (2000 )),
754- ]
755- for q in queries :
756- self .db .query (q )
757-
758- ts1 = TableSegment (self .db , self .table1 .path , ("id" ,))
759- ts2 = TableSegment (self .db , self .table2 .path , ("id" ,))
729+ def test_info_tree_root (self ):
730+ db = self .connection
731+ db .query ([
732+ self .src_table .insert_rows ([i ] for i in range (1000 )),
733+ self .dst_table .insert_rows ([i ] for i in range (2000 )),
734+ ])
735+
736+ ts1 = TableSegment (db , self .src_table .path , ("id" ,))
737+ ts2 = TableSegment (db , self .dst_table .path , ("id" ,))
760738
761739 for differ in (HashDiffer (bisection_threshold = 64 ), JoinDiffer (True )):
762740 diff_res = differ .diff_tables (ts1 , ts2 )
0 commit comments