55from abc import ABC , abstractmethod
66from runtype import dataclass
77import logging
8- from typing import Tuple , Optional , List
8+ from typing import Sequence , Tuple , Optional , List
99from concurrent .futures import ThreadPoolExecutor
1010import threading
1111from typing import Dict
@@ -159,10 +159,6 @@ def __post_init__(self):
159159class UnknownColType (ColType ):
160160 text : str
161161
162- def __post_init__ (self ):
163- logger .warn (f"Column of type '{ self .text } ' has no compatibility handling. "
164- "If encoding/formatting differs between databases, it may result in false positives." )
165-
166162
167163class AbstractDatabase (ABC ):
168164 @abstractmethod
@@ -191,7 +187,7 @@ def select_table_schema(self, path: DbPath) -> str:
191187 ...
192188
@abstractmethod
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
    """Query the table for its schema for table in 'path', and return {column: type}

    'filter_columns' is optional and defaults to None; how it narrows the
    result is left to the implementation (presumably only the listed
    columns are returned -- confirm against the concrete subclasses).
    """
    ...
197193
@@ -205,7 +201,6 @@ def close(self):
205201 "Close connection(s) to the database instance. Querying will stop functioning."
206202 ...
207203
208-
209204 @abstractmethod
210205 def normalize_timestamp (self , value : str , coltype : ColType ) -> str :
211206 """Creates an SQL expression, that converts 'value' to a normalized timestamp.
@@ -269,6 +264,10 @@ class Database(AbstractDatabase):
269264 DATETIME_TYPES = {}
270265 default_schema = None
271266
@property
def name(self) -> str:
    """Return the class name of this instance, used as a prefix in error messages."""
    return type(self).__name__
270+
272271 def query (self , sql_ast : SqlOrStr , res_type : type ):
273272 "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"
274273
@@ -310,7 +309,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
310309 return math .floor (math .log (2 ** p , 10 ))
311310
312311 def _parse_type (
313- self , type_repr : str , datetime_precision : int = None , numeric_precision : int = None , numeric_scale : int = None
312+ self ,
313+ col_name : str ,
314+ type_repr : str ,
315+ datetime_precision : int = None ,
316+ numeric_precision : int = None ,
317+ numeric_scale : int = None ,
314318 ) -> ColType :
315319 """ """
316320
@@ -329,6 +333,8 @@ def _parse_type(
329333 return cls (precision = 0 )
330334
331335 elif issubclass (cls , Decimal ):
336+ if numeric_scale is None :
337+ raise ValueError (f"{ self .name } : Unexpected numeric_scale is NULL, for column { col_name } of type { type_repr } ." )
332338 return cls (precision = numeric_scale )
333339
334340 assert issubclass (cls , Float )
@@ -349,13 +355,17 @@ def select_table_schema(self, path: DbPath) -> str:
349355 f"WHERE table_name = '{ table } ' AND table_schema = '{ schema } '"
350356 )
351357
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
    """Query the schema of the table at 'path', and return {column: type}.

    Runs the SQL produced by select_table_schema(path) and parses every
    returned row with _parse_type. The first element of each row is used
    as the column name, and the whole row is forwarded to _parse_type.

    If 'filter_columns' is given, only rows whose column name matches one
    of its entries (compared case-insensitively) are kept.

    Raises RuntimeError if the schema query returns no rows.
    """
    rows = self.query(self.select_table_schema(path), list)
    if not rows:
        raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

    if filter_columns is not None:
        # Lower-case both sides so the match ignores case.
        accept = {i.lower() for i in filter_columns}
        rows = [r for r in rows if r[0].lower() in accept]

    # Return a dict of form {name: type} after normalization
    return {row[0]: self._parse_type(*row) for row in rows}
359369
360370 # @lru_cache()
361371 # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:
@@ -366,9 +376,7 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
366376 if self .default_schema :
367377 return self .default_schema , path [0 ]
368378 elif len (path ) != 2 :
369- raise ValueError (
370- f"{ self .__class__ .__name__ } : Bad table path for { self } : '{ '.' .join (path )} '. Expected form: schema.table"
371- )
379+ raise ValueError (f"{ self .name } : Bad table path for { self } : '{ '.' .join (path )} '. Expected form: schema.table" )
372380
373381 return path
374382
@@ -440,6 +448,7 @@ class PostgreSQL(ThreadedDatabase):
440448 "decimal" : Decimal ,
441449 "integer" : Integer ,
442450 "numeric" : Decimal ,
451+ "bigint" : Integer ,
443452 }
444453 ROUNDS_ON_PREC_LOSS = True
445454
@@ -472,13 +481,14 @@ def md5_to_int(self, s: str) -> str:
def to_string(self, s: str) -> str:
    """Return an SQL expression that casts the expression 's' to varchar."""
    return f"{s}::varchar"
474483
475-
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
    """Return an SQL expression that formats 'value' as 'YYYY-mm-dd HH24:MI:SS.US' text.

    When coltype.rounds is true, the value is cast to
    timestamp(coltype.precision) before formatting. Otherwise it is cast
    to timestamp(6), cut to the leftmost
    TIMESTAMP_PRECISION_POS + coltype.precision characters, and padded
    on the right with '0' up to TIMESTAMP_PRECISION_POS + 6 characters.
    """
    if coltype.rounds:
        return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"

    timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
    return (
        f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
    )
482492
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Return an SQL expression casting 'value' to decimal(38, coltype.precision), passed through to_string."""
    return self.to_string(f"{value}::decimal(38, {coltype.precision})")
@@ -528,9 +538,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
528538 else :
529539 s = f"date_format(cast({ value } as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
530540
531- return (
532- f"RPAD(RPAD({ s } , { TIMESTAMP_PRECISION_POS + coltype .precision } , '.'), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
533- )
541+ return f"RPAD(RPAD({ s } , { TIMESTAMP_PRECISION_POS + coltype .precision } , '.'), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
534542
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Return an SQL expression casting 'value' to decimal(38,coltype.precision), passed through to_string."""
    return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")
@@ -543,7 +551,9 @@ def select_table_schema(self, path: DbPath) -> str:
543551 f"WHERE table_name = '{ table } ' AND table_schema = '{ schema } '"
544552 )
545553
546- def _parse_type (self , type_repr : str , datetime_precision : int = None , numeric_precision : int = None ) -> ColType :
554+ def _parse_type (
555+ self , col_name : str , type_repr : str , datetime_precision : int = None , numeric_precision : int = None
556+ ) -> ColType :
547557 regexps = {
548558 r"timestamp\((\d)\)" : Timestamp ,
549559 r"timestamp\((\d)\) with time zone" : TimestampTZ ,
@@ -633,7 +643,6 @@ def normalize_number(self, value: str, coltype: ColType) -> str:
633643 return self .to_string (f"cast({ value } as decimal(38, { coltype .precision } ))" )
634644
635645
636-
637646class Oracle (ThreadedDatabase ):
638647 ROUNDS_ON_PREC_LOSS = True
639648
@@ -687,7 +696,12 @@ def normalize_number(self, value: str, coltype: ColType) -> str:
687696 return f"to_char({ value } , '{ format_str } ')"
688697
689698 def _parse_type (
690- self , type_repr : str , datetime_precision : int = None , numeric_precision : int = None , numeric_scale : int = None
699+ self ,
700+ col_name : str ,
701+ type_repr : str ,
702+ datetime_precision : int = None ,
703+ numeric_precision : int = None ,
704+ numeric_scale : int = None ,
691705 ) -> ColType :
692706 """ """
693707 regexps = {
@@ -746,15 +760,18 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
746760 us = f"extract(us from { timestamp } )"
747761 # epoch = Total time since epoch in microseconds.
748762 epoch = f"{ secs } *1000000 + { ms } *1000 + { us } "
749- timestamp6 = f"to_char({ epoch } , -6+{ coltype .precision } ) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
763+ timestamp6 = (
764+ f"to_char({ epoch } , -6+{ coltype .precision } ) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')"
765+ )
750766 else :
751767 timestamp6 = f"to_char({ value } ::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
752- return f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
768+ return (
769+ f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
770+ )
753771
def normalize_number(self, value: str, coltype: ColType) -> str:
    """Return an SQL expression casting 'value' to decimal(38,coltype.precision), passed through to_string."""
    return self.to_string(f"{value}::decimal(38,{coltype.precision})")
756774
757-
758775 def select_table_schema (self , path : DbPath ) -> str :
759776 schema , table = self ._normalize_table_path (path )
760777
@@ -864,7 +881,9 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
864881 return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', { value } )"
865882
866883 timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', { value } )"
867- return f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
884+ return (
885+ f"RPAD(LEFT({ timestamp6 } , { TIMESTAMP_PRECISION_POS + coltype .precision } ), { TIMESTAMP_PRECISION_POS + 6 } , '0')"
886+ )
868887
869888 def normalize_number (self , value : str , coltype : ColType ) -> str :
870889 if isinstance (coltype , Integer ):
0 commit comments