11from contextlib import suppress
22import unittest
33import time
4- import logging
4+ import re
5+ import math
6+ import datetime
57from decimal import Decimal
6-
78from parameterized import parameterized
89
910from data_diff import databases as db
1011from data_diff .diff_tables import TableDiffer , TableSegment
11- from .common import CONN_STRINGS
12-
12+ from .common import CONN_STRINGS , N_SAMPLES
1313
14- logging .getLogger ("diff_tables" ).setLevel (logging .ERROR )
15- logging .getLogger ("database" ).setLevel (logging .WARN )
1614
17- CONNS = {k : db .connect_to_uri (v ) for k , v in CONN_STRINGS .items ()}
15+ CONNS = {k : db .connect_to_uri (v , 1 ) for k , v in CONN_STRINGS .items ()}
1816
1917CONNS [db .MySQL ].query ("SET @@session.time_zone='+00:00'" , None )
2018
21- TYPE_SAMPLES = {
22- "int" : [127 , - 3 , - 9 , 37 , 15 , 127 ],
23- "datetime_no_timezone" : [
24- "2020-01-01 15:10:10" ,
25- "2020-02-01 9:9:9" ,
26- "2022-03-01 15:10:01.139" ,
27- "2022-04-01 15:10:02.020409" ,
28- "2022-05-01 15:10:03.003030" ,
29- "2022-06-01 15:10:05.009900" ,
30- ],
31- "float" : [
19+
class PaginatedTable:
    """Iterate over the ``(id, col)`` rows of a table in ascending ``id`` order.

    We can't query all the rows at once for large tables; it would occupy too
    much memory. Rows are therefore fetched in batches of at most
    ``RECORDS_PER_BATCH``, each batch starting after the last ``id`` already
    seen (``WHERE id > <last_id>``).

    ``iter(table)`` returns a fresh, independent cursor every time, so the
    same ``PaginatedTable`` can be iterated more than once.
    """

    # Maximum number of rows requested by a single query.
    RECORDS_PER_BATCH = 1000000

    def __init__(self, table, conn):
        """Store the table name and connection and reset the cursor state.

        Args:
            table: Table name, interpolated verbatim into the SELECT statement.
            conn: Connection object exposing ``query(sql, result_type)``.
        """
        self.table = table
        self.conn = conn
        # Cursor state lives on every instance so that calling next() on an
        # object that was not produced by __iter__ does not fail with
        # AttributeError.
        self.last_id = 0  # highest id returned by the previous batch
        self.values = []  # rows of the current batch
        self.value_index = 0  # position of the next row inside ``values``

    def __iter__(self):
        """Return a new cursor positioned before the first row."""
        return PaginatedTable(self.table, self.conn)

    def _batch_query(self):
        """Build the SELECT statement that fetches the next batch of rows."""
        if isinstance(self.conn, db.Oracle):
            # Oracle gets the OFFSET/FETCH form instead of LIMIT.
            return f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC OFFSET 0 ROWS FETCH NEXT {self.RECORDS_PER_BATCH} ROWS ONLY"
        return f"SELECT id, col FROM {self.table} WHERE id > {self.last_id} ORDER BY id ASC LIMIT {self.RECORDS_PER_BATCH} "

    def __next__(self):
        """Return the next row, fetching another batch when the current one is used up.

        Returns:
            One row of the query result; its first element is the row ``id``.

        Raises:
            StopIteration: When the query for the next batch returns no rows.
        """
        if self.value_index == len(self.values):  # end of current batch
            self.values = self.conn.query(self._batch_query(), list)
            # Reset the position before the emptiness check: otherwise a stale
            # index would make a further next() after exhaustion index into
            # the empty batch (IndexError) instead of raising StopIteration.
            self.value_index = 0
            if not self.values:  # we must be done!
                raise StopIteration
            self.last_id = self.values[-1][0]

        row = self.values[self.value_index]
        self.value_index += 1
        return row
51+
52+
class DateTimeFaker:
    """Deterministic source of ``max`` sample ``datetime`` values.

    The hand-picked ``MANUAL_FAKES`` come first; after that every value is
    the previous generated one plus 3 seconds and 571 microseconds, starting
    from 2000-01-01 00:00:00.

    ``iter(faker)`` returns a fresh iterator each time, so one instance can be
    iterated repeatedly and always produces the same sequence.
    """

    MANUAL_FAKES = [
        datetime.datetime.fromisoformat("2020-01-01 15:10:10"),
        datetime.datetime.fromisoformat("2020-02-01 09:09:09"),
        datetime.datetime.fromisoformat("2022-03-01 15:10:01.139"),
        datetime.datetime.fromisoformat("2022-04-01 15:10:02.020409"),
        datetime.datetime.fromisoformat("2022-05-01 15:10:03.003030"),
        datetime.datetime.fromisoformat("2022-06-01 15:10:05.009900"),
    ]

    def __init__(self, max):
        """Create a faker that yields exactly ``max`` values.

        Args:
            max: Total number of samples to produce (also the ``len()``).
                The name shadows the builtin but is kept for compatibility.
        """
        self.max = max
        # Iteration state is set here so that next() also works on an
        # instance that was not obtained through iter().
        self.prev = datetime.datetime(2000, 1, 1, 0, 0, 0, 0)
        self.i = 0

    def __iter__(self):
        """Return a new iterator positioned at the first sample."""
        return DateTimeFaker(self.max)

    def __len__(self):
        """Return the number of samples this faker produces."""
        return self.max

    def __next__(self) -> datetime.datetime:
        """Return the next sample.

        Raises:
            StopIteration: After ``max`` values have been returned.
        """
        # Check the bound first so that the number of yielded values agrees
        # with __len__ even when ``max`` is smaller than len(MANUAL_FAKES).
        if self.i >= self.max:
            raise StopIteration
        if self.i < len(self.MANUAL_FAKES):
            fake = self.MANUAL_FAKES[self.i]
        else:
            self.prev = self.prev + datetime.timedelta(seconds=3, microseconds=571)
            fake = self.prev
        self.i += 1
        return fake
86+
87+
class IntFaker:
    """Deterministic source of ``max`` sample integers.

    The hand-picked ``MANUAL_FAKES`` come first; after that the values count
    upwards one by one, starting at -127.

    ``iter(faker)`` returns a fresh iterator each time, so one instance can be
    iterated repeatedly and always produces the same sequence.
    """

    MANUAL_FAKES = [127, -3, -9, 37, 15, 127]

    def __init__(self, max):
        """Create a faker that yields exactly ``max`` values.

        Args:
            max: Total number of samples to produce (also the ``len()``).
                The name shadows the builtin but is kept for compatibility.
        """
        self.max = max
        # Iteration state is set here so that next() also works on an
        # instance that was not obtained through iter().
        self.prev = -128
        self.i = 0

    def __iter__(self):
        """Return a new iterator positioned at the first sample."""
        return IntFaker(self.max)

    def __len__(self):
        """Return the number of samples this faker produces."""
        return self.max

    def __next__(self) -> int:
        """Return the next sample.

        Raises:
            StopIteration: After ``max`` values have been returned.
        """
        # Check the bound first so that the number of yielded values agrees
        # with __len__ even when ``max`` is smaller than len(MANUAL_FAKES).
        if self.i >= self.max:
            raise StopIteration
        if self.i < len(self.MANUAL_FAKES):
            fake = self.MANUAL_FAKES[self.i]
        else:
            self.prev += 1
            fake = self.prev
        self.i += 1
        return fake
114+
115+
116+ class FloatFaker :
117+ MANUAL_FAKES = [
32118 0.0 ,
33119 0.1 ,
34120 0.00188 ,
45131 1 / 1094893892389 ,
46132 1 / 10948938923893289 ,
47133 3.141592653589793 ,
48- ],
134+ ]
135+
def __init__(self, max):
    # `max` is the value reported by __len__ and the bound that __next__
    # compares the running index against. The name shadows the builtin.
    self.max = max
138+
139+ def __iter__ (self ):
140+ iter = FloatFaker (self .max )
141+ iter .prev = - 10.0001
142+ iter .i = 0
143+ return iter
144+
def __len__(self):
    # The length is the configured `max`, not the size of MANUAL_FAKES.
    return self.max
147+
148+ def __next__ (self ) -> float :
149+ if self .i < len (self .MANUAL_FAKES ):
150+ fake = self .MANUAL_FAKES [self .i ]
151+ self .i += 1
152+ return fake
153+ elif self .i < self .max :
154+ self .prev += 0.00571
155+ self .i += 1
156+ return self .prev
157+ else :
158+ raise StopIteration
159+
160+
# Sample-value generators keyed by type category. Each faker is constructed
# with N_SAMPLES as its `max`.
TYPE_SAMPLES = {
    "int": IntFaker(N_SAMPLES),
    "datetime_no_timezone": DateTimeFaker(N_SAMPLES),
    "float": FloatFaker(N_SAMPLES),
}
50166
51167DATABASE_TYPES = {
52168 db .PostgreSQL : {
53169 # https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT
54170 "int" : [
55171 # "smallint", # 2 bytes
56- # "int", # 4 bytes
172+ "int" , # 4 bytes
57173 # "bigint", # 8 bytes
58174 ],
59175 # https://www.postgresql.org/docs/current/datatype-datetime.html
76192 # "tinyint", # 1 byte
77193 # "smallint", # 2 bytes
78194 # "mediumint", # 3 bytes
79- # "int", # 4 bytes
195+ "int" , # 4 bytes
80196 # "bigint", # 8 bytes
81197 ],
82198 # https://dev.mysql.com/doc/refman/8.0/en/datetime.html
96212 ],
97213 },
98214 db .BigQuery : {
215+ "int" : ["int" ],
99216 "datetime_no_timezone" : [
100217 "timestamp" ,
101218 # "datetime",
110227 # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint
111228 "int" : [
112229 # all 38 digits with 0 precision, don't need to test all
113- # "int",
230+ "int" ,
114231 # "integer",
115232 # "bigint",
116233 # "smallint",
132249 },
133250 db .Redshift : {
134251 "int" : [
135- # "int",
252+ "int" ,
136253 ],
137254 "datetime_no_timezone" : [
138255 "TIMESTAMP" ,
146263 },
147264 db .Oracle : {
148265 "int" : [
149- # "int",
266+ "int" ,
150267 ],
151268 "datetime_no_timezone" : [
152269 "timestamp with local time zone" ,
163280 # "tinyint", # 1 byte
164281 # "smallint", # 2 bytes
165282 # "mediumint", # 3 bytes
166- # "int", # 4 bytes
283+ "int" , # 4 bytes
167284 # "bigint", # 8 bytes
168285 ],
169286 "datetime_no_timezone" : [
170- "timestamp(6)" ,
171- "timestamp(3)" ,
172- "timestamp(0)" ,
173287 "timestamp" ,
174- "datetime(6) " ,
288+ "timestamp with time zone " ,
175289 ],
176290 "float" : [
177291 "real" ,
203317 )
204318 )
205319
320+
def sanitize(name):
    """Turn a database or column type name into a short, safe identifier.

    The name is lower-cased, parentheses are dropped, long type phrases are
    abbreviated, and the result is passed through
    ``parameterized.to_safe_name``.
    """
    # Try to shorten long fields, due to length limitations in some DBs.
    # Applied in this order.
    abbreviations = (
        (r"without time zone", "n_tz"),
        (r"with time zone", "y_tz"),
        (r"with local time zone", "y_tz"),
        (r"timestamp", "ts"),
    )
    cleaned = re.sub(r"[\(\)]", "", name.lower())  # timestamp(9) -> timestamp9
    for long_form, short_form in abbreviations:
        cleaned = cleaned.replace(long_form, short_form)
    return parameterized.to_safe_name(cleaned)
330+
331+
def number_to_human(n):
    """Format a number compactly with a thousands suffix.

    The value is scaled by the largest power of 1000 that fits (capped at
    billions), rounded to zero decimals, and suffixed with "", "k", "m" or
    "b". For example 5000 becomes "5k".
    """
    suffixes = ("", "k", "m", "b")
    value = float(n)

    # Number of whole groups of three decimal digits in the magnitude.
    if value == 0:
        groups = int(math.floor(0))
    else:
        groups = int(math.floor(math.log10(abs(value)) / 3))

    # Clamp to the suffixes that exist; fractions below 1 stay unscaled.
    index = min(len(suffixes) - 1, groups)
    if index < 0:
        index = 0

    scaled = value / 10 ** (3 * index)
    return "{:.0f}{}".format(scaled, suffixes[index])
341+
342+
206343# Pass --verbose to test run to get a nice output.
def expand_params(testcase_func, param_num, param):
    """Build a readable test name for one parameter set.

    NOTE(review): the signature looks like a ``parameterized`` name function;
    the decorator that uses it is not visible here, so confirm at the call site.

    Args:
        testcase_func: The test function being parameterized.
        param_num: Index of the parameter set (unused).
        param: Object whose ``args`` hold ``(source_db, target_db,
            source_type, target_type, type_category)``.

    Returns:
        ``<func>_<src db>_<src type>_to_<dst db>_<dst type>_<sample count>``,
        with every database/type part passed through ``sanitize``.
    """
    source_db, target_db, source_type, target_type, _type_category = param.args
    # The parts are joined with underscores only: a space in the name would
    # leak into the table names that are derived from the test method name.
    return (
        f"{testcase_func.__name__}"
        f"_{sanitize(source_db.__name__)}"
        f"_{sanitize(source_type)}"
        f"_to_{sanitize(target_db.__name__)}"
        f"_{sanitize(target_type)}"
        f"_{number_to_human(N_SAMPLES)}"
    )
218357
219358
220359def _insert_to_table (conn , table , values ):
@@ -232,8 +371,10 @@ def _insert_to_table(conn, table, values):
232371 else :
233372 insertion_query += " VALUES "
234373 for j , sample in values :
235- if isinstance (sample , (float , Decimal )):
374+ if isinstance (sample , (float , Decimal , int )):
236375 value = str (sample )
376+ elif isinstance (sample , datetime .datetime ) and isinstance (conn , db .Presto ):
377+ value = f"timestamp '{ sample } '"
237378 else :
238379 value = f"'{ sample } '"
239380 insertion_query += f"({ j } , { value } ),"
@@ -253,6 +394,7 @@ def _drop_table_if_exists(conn, table):
253394 conn .query (f"DROP TABLE { table } " , None )
254395 else :
255396 conn .query (f"DROP TABLE IF EXISTS { table } " , None )
397+ conn .query ("COMMIT" , None )
256398
257399
258400class TestDiffCrossDatabaseTables (unittest .TestCase ):
@@ -266,9 +408,9 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
266408 self .connections = [self .src_conn , self .dst_conn ]
267409 sample_values = TYPE_SAMPLES [type_category ]
268410
269- # Limit in MySQL is 64
270- src_table_name = f"src_{ self ._testMethodName [: 60 ]} "
271- dst_table_name = f"dst_{ self ._testMethodName [: 60 ]} "
411+ # Limit in MySQL is 64, Presto seems to be 63
412+ src_table_name = f"src_{ self ._testMethodName [11 : ]} "
413+ dst_table_name = f"dst_{ self ._testMethodName [11 : ]} "
272414
273415 src_table_path = src_conn .parse_table_name (src_table_name )
274416 dst_table_path = dst_conn .parse_table_name (dst_table_name )
@@ -279,7 +421,7 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego
279421 src_conn .query (f"CREATE TABLE { src_table } (id int, col { source_type } )" , None )
280422 _insert_to_table (src_conn , src_table , enumerate (sample_values , 1 ))
281423
282- values_in_source = src_conn . query ( f"SELECT id, col FROM { src_table } " , list )
424+ values_in_source = PaginatedTable ( src_table , src_conn )
283425
284426 _drop_table_if_exists (dst_conn , dst_table )
285427 dst_conn .query (f"CREATE TABLE { dst_table } (id int, col { target_type } )" , None )
0 commit comments