55from abc import ABC , abstractmethod
66from runtype import dataclass
77import logging
8- from typing import Tuple , Optional , List
8+ from typing import Sequence , Tuple , Optional , List
99from concurrent .futures import ThreadPoolExecutor
1010import threading
1111from typing import Dict
@@ -131,10 +131,6 @@ def __post_init__(self):
131131class UnknownColType (ColType ):
132132 text : str
133133
134- def __post_init__ (self ):
135- logger .warn (f"Column of type '{ self .text } ' has no compatibility handling. "
136- "If encoding/formatting differs between databases, it may result in false positives." )
137-
138134
139135class AbstractDatabase (ABC ):
140136 @abstractmethod
@@ -163,7 +159,7 @@ def select_table_schema(self, path: DbPath) -> str:
163159 ...
164160
165161 @abstractmethod
166- def query_table_schema (self , path : DbPath ) -> Dict [str , ColType ]:
162+ def query_table_schema (self , path : DbPath , filter_columns : Optional [ Sequence [ str ]] = None ) -> Dict [str , ColType ]:
167163 "Query the table for its schema for table in 'path', and return {column: type}"
168164 ...
169165
@@ -241,6 +237,10 @@ class Database(AbstractDatabase):
241237 DATETIME_TYPES = {}
242238 default_schema = None
243239
240+ @property
241+ def name (self ):
242+ return type (self ).__name__
243+
244244 def query (self , sql_ast : SqlOrStr , res_type : type ):
245245 "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'"
246246
@@ -321,12 +321,16 @@ def select_table_schema(self, path: DbPath) -> str:
321321 f"WHERE table_name = '{ table } ' AND table_schema = '{ schema } '"
322322 )
323323
324- def query_table_schema (self , path : DbPath ) -> Dict [str , ColType ]:
324+ def query_table_schema (self , path : DbPath , filter_columns : Optional [ Sequence [ str ]] = None ) -> Dict [str , ColType ]:
325325 rows = self .query (self .select_table_schema (path ), list )
326326 if not rows :
327- raise RuntimeError (f"{ self .__class__ .__name__ } : Table '{ '.' .join (path )} ' does not exist, or has no columns" )
327+ raise RuntimeError (f"{ self .name } : Table '{ '.' .join (path )} ' does not exist, or has no columns" )
328+
329+ if filter_columns is not None :
330+ accept = {i .lower () for i in filter_columns }
331+ rows = [r for r in rows if r [0 ].lower () in accept ]
328332
329- # Return a dict of form {name: type} after canonizaation
333+ # Return a dict of form {name: type} after normalization
330334 return {row [0 ]: self ._parse_type (* row [1 :]) for row in rows }
331335
332336 # @lru_cache()
@@ -339,7 +343,7 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
339343 return self .default_schema , path [0 ]
340344 elif len (path ) != 2 :
341345 raise ValueError (
342- f"{ self .__class__ . __name__ } : Bad table path for { self } : '{ '.' .join (path )} '. Expected form: schema.table"
346+ f"{ self .name } : Bad table path for { self } : '{ '.' .join (path )} '. Expected form: schema.table"
343347 )
344348
345349 return path
@@ -407,6 +411,7 @@ class Postgres(ThreadedDatabase):
407411 "decimal" : Decimal ,
408412 "integer" : Integer ,
409413 "numeric" : Decimal ,
414+ "bigint" : Integer ,
410415 }
411416 ROUNDS_ON_PREC_LOSS = True
412417
0 commit comments