Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 4e4958c

Browse files
committed
Better errors for missing imports
1 parent b88972a commit 4e4958c

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

data_diff/database.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from functools import lru_cache
2+
from functools import lru_cache, wraps
33
from itertools import zip_longest
44
import re
55
from abc import ABC, abstractmethod
@@ -23,6 +23,21 @@ def parse_table_name(t):
2323
return tuple(t.split("."))
2424

2525

26+
def import_helper(s: str):
27+
def dec(f):
28+
@wraps(f)
29+
def _inner():
30+
try:
31+
return f()
32+
except ModuleNotFoundError as e:
33+
raise ModuleNotFoundError(f"{e}\n\nYou can install it using 'pip install data-diff[{s}]'.")
34+
35+
return _inner
36+
37+
return dec
38+
39+
40+
@import_helper("pgsql")
2641
def import_postgres():
2742
import psycopg2
2843
import psycopg2.extras
@@ -31,12 +46,14 @@ def import_postgres():
3146
return psycopg2
3247

3348

49+
@import_helper("mysql")
3450
def import_mysql():
3551
import mysql.connector
3652

3753
return mysql.connector
3854

3955

56+
@import_helper("snowflake")
4057
def import_snowflake():
4158
import snowflake.connector
4259

@@ -55,6 +72,7 @@ def import_oracle():
5572
return cx_Oracle
5673

5774

75+
@import_helper("presto")
5876
def import_presto():
5977
import prestodb
6078

@@ -344,7 +362,6 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
344362

345363
return path
346364

347-
348365
def parse_table_name(self, name: str) -> DbPath:
349366
return parse_table_name(name)
350367

@@ -356,19 +373,25 @@ class ThreadedDatabase(Database):
356373
"""
357374

358375
def __init__(self, thread_count=1):
376+
self._init_error = None
359377
self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn)
360378
self.thread_local = threading.local()
361379

362380
def set_conn(self):
363381
assert not hasattr(self.thread_local, "conn")
364-
self.thread_local.conn = self.create_connection()
382+
try:
383+
self.thread_local.conn = self.create_connection()
384+
except ModuleNotFoundError as e:
385+
self._init_error = e
365386

366387
def _query(self, sql_code: str):
367388
r = self._queue.submit(self._query_in_worker, sql_code)
368389
return r.result()
369390

370391
def _query_in_worker(self, sql_code: str):
371392
"This method runs in a worker thread"
393+
if self._init_error:
394+
raise self._init_error
372395
return _query_conn(self.thread_local.conn, sql_code)
373396

374397
def close(self):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ mysql = ["mysql-connector-python"]
5050
pgsql = ["psycopg2"]
5151
snowflake = ["snowflake-connector-python"]
5252
presto = ["presto-python-client"]
53+
oracle = ["cx_Oracle"]
5354

5455
[build-system]
5556
requires = ["poetry-core>=1.0.0"]

0 commit comments

Comments
 (0)