11import math
2- from functools import lru_cache
2+ from functools import lru_cache , wraps
33from itertools import zip_longest
44import re
55from abc import ABC , abstractmethod
@@ -23,6 +23,21 @@ def parse_table_name(t):
2323 return tuple (t .split ("." ))
2424
2525
26+ def import_helper (s : str ):
27+ def dec (f ):
28+ @wraps (f )
29+ def _inner ():
30+ try :
31+ return f ()
32+ except ModuleNotFoundError as e :
33+ raise ModuleNotFoundError (f"{ e } \n \n You can install it using 'pip install data-diff[{ s } ]'." )
34+
35+ return _inner
36+
37+ return dec
38+
39+
40+ @import_helper ("pgsql" )
2641def import_postgres ():
2742 import psycopg2
2843 import psycopg2 .extras
@@ -31,12 +46,14 @@ def import_postgres():
3146 return psycopg2
3247
3348
49+ @import_helper ("mysql" )
3450def import_mysql ():
3551 import mysql .connector
3652
3753 return mysql .connector
3854
3955
56+ @import_helper ("snowflake" )
4057def import_snowflake ():
4158 import snowflake .connector
4259
@@ -55,6 +72,7 @@ def import_oracle():
5572 return cx_Oracle
5673
5774
75+ @import_helper ("presto" )
5876def import_presto ():
5977 import prestodb
6078
@@ -344,7 +362,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
344362
345363 return path
346364
347-
348365 def parse_table_name (self , name : str ) -> DbPath :
349366 return parse_table_name (name )
350367
@@ -356,19 +373,25 @@ class ThreadedDatabase(Database):
356373 """
357374
358375 def __init__ (self , thread_count = 1 ):
376+ self ._init_error = None
359377 self ._queue = ThreadPoolExecutor (thread_count , initializer = self .set_conn )
360378 self .thread_local = threading .local ()
361379
362380 def set_conn (self ):
363381 assert not hasattr (self .thread_local , "conn" )
364- self .thread_local .conn = self .create_connection ()
382+ try :
383+ self .thread_local .conn = self .create_connection ()
384+ except ModuleNotFoundError as e :
385+ self ._init_error = e
365386
366387 def _query (self , sql_code : str ):
367388 r = self ._queue .submit (self ._query_in_worker , sql_code )
368389 return r .result ()
369390
370391 def _query_in_worker (self , sql_code : str ):
371392 "This method runs in a worker thread"
393+ if self ._init_error :
394+ raise self ._init_error
372395 return _query_conn (self .thread_local .conn , sql_code )
373396
374397 def close (self ):
0 commit comments