Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit a7d1625

Browse files
authored
Merge pull request #75 from datafold/normalize_types_refactor
Refactor Normalize-types into normalize_timestamp() normalize_number()
2 parents 0232343 + 469f142 commit a7d1625

File tree

1 file changed

+131
-121
lines changed

1 file changed

+131
-121
lines changed

data_diff/database.py

Lines changed: 131 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ def __post_init__(self):
131131
class UnknownColType(ColType):
132132
text: str
133133

134+
def __post_init__(self):
    """Warn that this column's type has no normalization support.

    Logs the raw type text stored in ``self.text`` so the user knows that
    values of this column are compared without any compatibility handling.

    NOTE(review): presumably invoked as the dataclass ``__post_init__`` hook;
    the class decorator is outside the visible diff -- confirm.
    """
    # ``Logger.warn`` is a deprecated alias; ``Logger.warning`` is the supported name.
    logger.warning(
        f"Column of type '{self.text}' has no compatibility handling. "
        "If encoding/formatting differs between databases, it may result in false positives."
    )
137+
134138

135139
class AbstractDatabase(ABC):
136140
@abstractmethod
@@ -173,16 +177,24 @@ def close(self):
173177
"Close connection(s) to the database instance. Querying will stop functioning."
174178
...
175179

180+
176181
@abstractmethod
177-
def normalize_value_by_type(value: str, coltype: ColType) -> str:
178-
"""Creates an SQL expression, that converts 'value' to a normalized representation.
182+
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Creates an SQL expression, that converts 'value' to a normalized timestamp.

    The returned expression must accept any SQL datetime/timestamp, and return a string.

    Date format: "YYYY-MM-DD HH:mm:SS.FFFFFF"

    Precision of dates should be rounded up/down according to coltype.rounds

    :param value: SQL expression text to normalize; it is interpolated into
        the returned expression.
    :param coltype: Description of the column type. The concrete
        implementations in this file read its ``precision`` and ``rounds``
        attributes.
    :return: SQL expression text that evaluates to the normalized string.
    """
    ...
181192

182-
- Dates are expected in the format:
183-
"YYYY-MM-DD HH:mm:SS.FFFFFF"
193+
@abstractmethod
194+
def normalize_number(self, value: str, coltype: ColType) -> str:
195+
"""Creates an SQL expression, that converts 'value' to a normalized number.
184196
185-
Rounded up/down according to coltype.rounds
197+
The returned expression must accept any SQL int/numeric/float, and return a string.
186198
187199
- Floats/Decimals are expected in the format
188200
"I.P"
@@ -191,14 +203,31 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str:
191203
and must be at least one digit (0).
192204
P is the fractional digits, the amount of which is specified with
193205
coltype.precision. Trailing zeroes may be necessary.
206+
If P is 0, the dot is omitted.
194207
195208
Note: This precision is different than the one used by databases. For decimals,
196-
it's the same as "numeric_scale", and for floats, who use binary precision,
197-
it can be calculated as log10(2**p)
209+
it's the same as ``numeric_scale``, and for floats, who use binary precision,
210+
it can be calculated as ``log10(2**numeric_precision)``.
211+
"""
212+
...
213+
214+
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
    """Build an SQL expression that renders 'value' in its normalized string form.

    The expression produced must accept any SQL value, and evaluate to a string.

    Unless a subclass overrides it, the work is delegated according to the
    class of ``coltype``:

        TemporalType   ->  normalize_timestamp()
        NumericType    ->  normalize_number()
        anything else  ->  to_string()
    """
    # Order matters: the first matching type wins, exactly like an if/elif chain.
    dispatch = (
        (TemporalType, self.normalize_timestamp),
        (NumericType, self.normalize_number),
    )
    for type_family, normalizer in dispatch:
        if isinstance(coltype, type_family):
            return normalizer(value, coltype)

    # No dedicated handling for this type: plain string conversion.
    return self.to_string(f"{value}")
202231

203232

204233
class Database(AbstractDatabase):
@@ -209,8 +238,8 @@ class Database(AbstractDatabase):
209238
Instantiated using :meth:`~data_diff.connect_to_uri`
210239
"""
211240

212-
DATETIME_TYPES = NotImplemented
213-
default_schema = NotImplemented
241+
DATETIME_TYPES = {}
242+
default_schema = None
214243

215244
def query(self, sql_ast: SqlOrStr, res_type: type):
216245
"Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"
@@ -306,13 +335,15 @@ def query_table_schema(self, path: DbPath) -> Dict[str, ColType]:
306335

307336
def _normalize_table_path(self, path: DbPath) -> DbPath:
308337
if len(path) == 1:
309-
return self.default_schema, path[0]
310-
elif len(path) == 2:
311-
return path
338+
if self.default_schema:
339+
return self.default_schema, path[0]
340+
elif len(path) != 2:
341+
raise ValueError(
342+
f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
343+
)
344+
345+
return path
312346

313-
raise ValueError(
314-
f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
315-
)
316347

317348
def parse_table_name(self, name: str) -> DbPath:
318349
return parse_table_name(name)
@@ -408,27 +439,16 @@ def md5_to_int(self, s: str) -> str:
408439
def to_string(self, s: str):
409440
return f"{s}::varchar"
410441

411-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
412-
if isinstance(coltype, TemporalType):
413-
# if coltype.precision == 0:
414-
# return f"to_char({value}::timestamp(0), 'YYYY-mm-dd HH24:MI:SS')"
415-
# if coltype.precision == 3:
416-
# return f"to_char({value}, 'YYYY-mm-dd HH24:MI:SS.US')"
417-
# elif coltype.precision == 6:
418-
# return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"
419-
# else:
420-
# # Postgres/Redshift doesn't support arbitrary precision
421-
# raise TypeError(f"Bad precision for {type(self).__name__}: {coltype})")
422-
if coltype.rounds:
423-
return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"
424-
else:
425-
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
426-
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
427442

428-
elif isinstance(coltype, NumericType):
429-
value = f"{value}::decimal(38, {coltype.precision})"
443+
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Build the SQL expression that renders a timestamp as normalized text.

    When ``coltype.rounds`` is set, 'value' is cast to
    ``timestamp(coltype.precision)`` and formatted with ``to_char``.
    Otherwise it is formatted at ``timestamp(6)``, cut with ``LEFT`` to
    ``TIMESTAMP_PRECISION_POS + coltype.precision`` characters and padded
    with '0' up to ``TIMESTAMP_PRECISION_POS + 6`` characters.
    """
    if not coltype.rounds:
        micros_text = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
        keep_len = TIMESTAMP_PRECISION_POS + coltype.precision
        full_len = TIMESTAMP_PRECISION_POS + 6
        return f"RPAD(LEFT({micros_text}, {keep_len}), {full_len}, '0')"

    return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"
449+
450+
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Build the SQL expression that renders a number as normalized text.

    'value' is cast to ``decimal(38, coltype.precision)`` and the result is
    passed through ``self.to_string``.
    """
    as_decimal = f"{value}::decimal(38, {coltype.precision})"
    return self.to_string(as_decimal)
432452

433453

434454
class Presto(Database):
@@ -468,25 +488,19 @@ def _query(self, sql_code: str) -> list:
468488
def close(self):
469489
self._conn.close()
470490

471-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
472-
if isinstance(coltype, TemporalType):
473-
if coltype.rounds:
474-
if coltype.precision > 3:
475-
pass
476-
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
477-
else:
478-
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
479-
# datetime = f"date_format(cast({value} as timestamp(6), '%Y-%m-%d %H:%i:%S.%f'))"
480-
# datetime = self.to_string(f"cast({value} as datetime(6))")
481-
482-
return (
483-
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
484-
)
491+
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Build the Presto expression that renders a timestamp as normalized text.

    'value' is cast to ``timestamp(6)`` and formatted with ``date_format``.
    The text is then passed through two ``RPAD`` calls: first to
    ``TIMESTAMP_PRECISION_POS + coltype.precision`` characters using '.',
    then to ``TIMESTAMP_PRECISION_POS + 6`` characters using '0'.
    """
    # TODO: ``coltype.rounds`` is not honoured yet -- rounding and
    # non-rounding columns produce exactly the same expression.
    formatted = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"

    cut_len = TIMESTAMP_PRECISION_POS + coltype.precision
    full_len = TIMESTAMP_PRECISION_POS + 6
    return f"RPAD(RPAD({formatted}, {cut_len}, '.'), {full_len}, '0')"
488501

489-
return self.to_string(value)
502+
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Build the Presto expression that renders a number as normalized text.

    'value' is cast to ``decimal(38,coltype.precision)`` and the result is
    passed through ``self.to_string``.
    """
    as_decimal = f"cast({value} as decimal(38,{coltype.precision}))"
    return self.to_string(as_decimal)
490504

491505
def select_table_schema(self, path: DbPath) -> str:
492506
schema, table = self._normalize_table_path(path)
@@ -575,18 +589,16 @@ def md5_to_int(self, s: str) -> str:
575589
def to_string(self, s: str):
576590
return f"cast({s} as char)"
577591

578-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
579-
if isinstance(coltype, TemporalType):
580-
if coltype.rounds:
581-
return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))")
582-
else:
583-
s = self.to_string(f"cast({value} as datetime(6))")
584-
return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
592+
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Build the SQL expression that renders a datetime as normalized text.

    When ``coltype.rounds`` is set, 'value' is cast to
    ``datetime(coltype.precision)``, then to ``datetime(6)``, and converted
    with ``self.to_string``. Otherwise it is cast to ``datetime(6)``,
    converted with ``self.to_string``, and passed through two ``RPAD``
    calls: first to ``TIMESTAMP_PRECISION_POS + coltype.precision``
    characters using '.', then to ``TIMESTAMP_PRECISION_POS + 6`` characters
    using '0'.
    """
    if not coltype.rounds:
        as_text = self.to_string(f"cast({value} as datetime(6))")
        cut_len = TIMESTAMP_PRECISION_POS + coltype.precision
        full_len = TIMESTAMP_PRECISION_POS + 6
        return f"RPAD(RPAD({as_text}, {cut_len}, '.'), {full_len}, '0')"

    rounded = f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))"
    return self.to_string(rounded)
598+
599+
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Build the SQL expression that renders a number as normalized text.

    'value' is cast to ``decimal(38, coltype.precision)`` and the result is
    passed through ``self.to_string``.
    """
    as_decimal = f"cast({value} as decimal(38, {coltype.precision}))"
    return self.to_string(as_decimal)
588601

589-
return self.to_string(f"{value}")
590602

591603

592604
class Oracle(ThreadedDatabase):
@@ -631,16 +643,15 @@ def select_table_schema(self, path: DbPath) -> str:
631643
f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'"
632644
)
633645

634-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
635-
if isinstance(coltype, TemporalType):
636-
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"
637-
elif isinstance(coltype, NumericType):
638-
# FM999.9990
639-
format_str = "FM" + "9" * (38 - coltype.precision)
640-
if coltype.precision:
641-
format_str += "0." + "9" * (coltype.precision - 1) + "0"
642-
return f"to_char({value}, '{format_str}')"
643-
return self.to_string(f"{value}")
646+
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Build the Oracle expression that renders a timestamp as normalized text.

    'value' is cast to ``timestamp(coltype.precision)`` and formatted with
    ``to_char`` using the mask 'YYYY-MM-DD HH24:MI:SS.FF6'.
    """
    as_timestamp = f"cast({value} as timestamp({coltype.precision}))"
    return f"to_char({as_timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')"
648+
649+
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Build the Oracle expression that renders a number as normalized text.

    The ``to_char`` mask starts with "FM" followed by
    ``38 - coltype.precision`` nines. When the precision is non-zero it
    continues with "0.", ``coltype.precision - 1`` nines and a final "0"
    (for example FM999.9990).
    """
    scale = coltype.precision

    mask_parts = ["FM", "9" * (38 - scale)]
    if scale:
        mask_parts.extend(("0.", "9" * (scale - 1), "0"))
    mask = "".join(mask_parts)

    return f"to_char({value}, '{mask}')"
644655

645656
def _parse_type(
646657
self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None
@@ -691,27 +702,33 @@ class Redshift(Postgres):
691702
def md5_to_int(self, s: str) -> str:
692703
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"
693704

694-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
695-
if isinstance(coltype, TemporalType):
696-
if coltype.rounds:
697-
timestamp = f"{value}::timestamp(6)"
698-
# Get seconds since epoch. Redshift doesn't support milli- or micro-seconds.
699-
secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)"
700-
# Get the milliseconds from timestamp.
701-
ms = f"extract(ms from {timestamp})"
702-
# Get the microseconds from timestamp, without the milliseconds!
703-
us = f"extract(us from {timestamp})"
704-
# epoch = Total time since epoch in microseconds.
705-
epoch = f"{secs}*1000000 + {ms}*1000 + {us}"
706-
timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
707-
else:
708-
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
709-
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
705+
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Build the Redshift expression that renders a timestamp as normalized text.

    Without rounding, 'value' is cast to ``timestamp(6)`` and formatted with
    ``to_char``. With rounding (``coltype.rounds``), a microsecond count is
    assembled from the epoch seconds, the ``ms`` part and the ``us`` part of
    the timestamp, rounded with ``round(..., -6 + coltype.precision)``,
    turned back into a timestamp by adding it to ``timestamp 'epoch'`` as
    '0.000001 seconds' intervals, and then formatted.

    In both cases the text is cut with ``LEFT`` to
    ``TIMESTAMP_PRECISION_POS + coltype.precision`` characters and padded
    with '0' up to ``TIMESTAMP_PRECISION_POS + 6`` characters.
    """
    if not coltype.rounds:
        formatted = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
    else:
        ts6 = f"{value}::timestamp(6)"
        # Seconds since epoch. Redshift doesn't support milli- or micro-seconds here.
        # NOTE: this fragment deliberately leaves "round(" open; it is closed
        # by the ", -6+precision)" part of the to_char expression below.
        whole_secs = f"timestamp 'epoch' + round(extract(epoch from {ts6})::decimal(38)"
        # The milliseconds of the timestamp.
        millis = f"extract(ms from {ts6})"
        # The microseconds of the timestamp, without the milliseconds!
        micros = f"extract(us from {ts6})"
        # Total time since epoch, in microseconds.
        total_micros = f"{whole_secs}*1000000 + {millis}*1000 + {micros}"
        formatted = f"to_char({total_micros}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"

    return f"RPAD(LEFT({formatted}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
720+
721+
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Build the Redshift expression that renders a number as normalized text.

    'value' is cast to ``decimal(38,coltype.precision)`` and the result is
    passed through ``self.to_string``.
    """
    as_decimal = f"{value}::decimal(38,{coltype.precision})"
    return self.to_string(as_decimal)
710723

711-
elif isinstance(coltype, NumericType):
712-
value = f"{value}::decimal(38,{coltype.precision})"
713724

714-
return self.to_string(f"{value}")
725+
def select_table_schema(self, path: DbPath) -> str:
    """Build the query that lists the columns of the table at ``path``.

    ``path`` is resolved with ``self._normalize_table_path`` into a schema
    and a table name. Both are lower-cased before being embedded in the
    ``information_schema.columns`` lookup.
    """
    schema, table = self._normalize_table_path(path)

    selected = "column_name, data_type, datetime_precision, numeric_precision, numeric_scale"
    where = f"table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'"
    return f"SELECT {selected} FROM information_schema.columns WHERE {where}"
715732

716733

717734
class MsSQL(ThreadedDatabase):
@@ -803,27 +820,23 @@ def select_table_schema(self, path: DbPath) -> str:
803820
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
804821
)
805822

806-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
807-
if isinstance(coltype, TemporalType):
808-
if coltype.rounds:
809-
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
810-
return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})"
811-
else:
812-
if coltype.precision == 0:
813-
return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})"
814-
elif coltype.precision == 6:
815-
return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
823+
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Build the FORMAT_TIMESTAMP expression that renders a timestamp as normalized text.

    When ``coltype.rounds`` is set, 'value' is converted to microseconds
    since epoch, rounded to ``coltype.precision`` fractional digits of a
    second, converted back with ``timestamp_micros`` and formatted with six
    fractional digits.

    Otherwise the fraction is cut rather than rounded:

    - precision 0: whole seconds are formatted and a literal ".000000" is appended
    - precision 6: formatted as-is with six fractional digits
    - anything else: the six-digit text is cut with ``LEFT`` to
      ``TIMESTAMP_PRECISION_POS + coltype.precision`` characters and padded
      with '0' up to ``TIMESTAMP_PRECISION_POS + 6`` characters
    """
    if coltype.rounds:
        timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
        return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})"

    if coltype.precision == 0:
        # The format literal must be closed before the value argument; the
        # closing quote was missing, which produced an unterminated SQL string.
        return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000', {value})"
    elif coltype.precision == 6:
        return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"

    timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
    return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
825835

826-
return self.to_string(f"{value}")
836+
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Build the SQL expression that renders a number as normalized text.

    ``Integer`` columns are converted with ``self.to_string``. Every other
    numeric column is rendered with ``format('%.<precision>f', value)``.
    """
    if not isinstance(coltype, Integer):
        return f"format('%.{coltype.precision}f', {value})"

    return self.to_string(value)
827840

828841
def parse_table_name(self, name: str) -> DbPath:
829842
path = parse_table_name(name)
@@ -897,19 +910,16 @@ def select_table_schema(self, path: DbPath) -> str:
897910
schema, table = self._normalize_table_path(path)
898911
return super().select_table_schema((schema, table))
899912

900-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
901-
if isinstance(coltype, TemporalType):
902-
if coltype.rounds:
903-
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))"
904-
else:
905-
timestamp = f"cast({value} as timestamp({coltype.precision}))"
906-
907-
return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')"
913+
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Build the SQL expression that renders a timestamp as normalized text.

    When ``coltype.rounds`` is set, the nanoseconds since epoch of 'value'
    (cast to ``timestamp(9)``) are divided down to seconds, rounded to
    ``coltype.precision`` digits and converted back with ``to_timestamp``.
    Otherwise 'value' is cast to ``timestamp(coltype.precision)``.

    Either way the result is formatted with ``to_char`` using the mask
    'YYYY-MM-DD HH24:MI:SS.FF6'.
    """
    digits = coltype.precision
    adjusted = (
        f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {digits}))"
        if coltype.rounds
        else f"cast({value} as timestamp({digits}))"
    )
    return f"to_char({adjusted}, 'YYYY-MM-DD HH24:MI:SS.FF6')"
911920

912-
return self.to_string(f"{value}")
921+
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Build the SQL expression that renders a number as normalized text.

    'value' is cast to ``decimal(38, coltype.precision)`` and the result is
    passed through ``self.to_string``.
    """
    as_decimal = f"cast({value} as decimal(38, {coltype.precision}))"
    return self.to_string(as_decimal)
913923

914924

915925
@dataclass

0 commit comments

Comments
 (0)