Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 19a8ed6

Browse files
authored
Merge pull request #98 from datafold/refactor_june23
Move common ABCs and types to database_types.py; Fix type annotations
2 parents 1ee2f4b + 4807627 commit 19a8ed6

File tree

7 files changed

+278
-260
lines changed

7 files changed

+278
-260
lines changed

data_diff/database.py

Lines changed: 26 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33
from itertools import zip_longest
44
import re
55
from abc import ABC, abstractmethod
6-
from runtype import dataclass
76
import logging
8-
from typing import Sequence, Tuple, Optional, List
7+
from typing import Sequence, Tuple, Optional, List, Type
98
from concurrent.futures import ThreadPoolExecutor
109
import threading
1110
from typing import Dict
12-
1311
import dsnparse
1412
import sys
1513

14+
from runtype import dataclass
15+
1616
from .sql import DbPath, SqlOrStr, Compiler, Explain, Select
17+
from .database_types import *
1718

1819

1920
logger = logging.getLogger("database")
@@ -109,149 +110,6 @@ def _query_conn(conn, sql_code: str) -> list:
109110
return c.fetchall()
110111

111112

112-
class ColType:
113-
pass
114-
115-
116-
@dataclass
117-
class PrecisionType(ColType):
118-
precision: Optional[int]
119-
rounds: bool
120-
121-
122-
class TemporalType(PrecisionType):
123-
pass
124-
125-
126-
class Timestamp(TemporalType):
127-
pass
128-
129-
130-
class TimestampTZ(TemporalType):
131-
pass
132-
133-
134-
class Datetime(TemporalType):
135-
pass
136-
137-
138-
@dataclass
139-
class NumericType(ColType):
140-
# 'precision' signifies how many fractional digits (after the dot) we want to compare
141-
precision: int
142-
143-
144-
class Float(NumericType):
145-
pass
146-
147-
148-
class Decimal(NumericType):
149-
pass
150-
151-
152-
@dataclass
153-
class Integer(Decimal):
154-
def __post_init__(self):
155-
assert self.precision == 0
156-
157-
158-
@dataclass
159-
class UnknownColType(ColType):
160-
text: str
161-
162-
163-
class AbstractDatabase(ABC):
164-
@abstractmethod
165-
def quote(self, s: str):
166-
"Quote SQL name (implementation specific)"
167-
...
168-
169-
@abstractmethod
170-
def to_string(self, s: str) -> str:
171-
"Provide SQL for casting a column to string"
172-
...
173-
174-
@abstractmethod
175-
def md5_to_int(self, s: str) -> str:
176-
"Provide SQL for computing md5 and returning an int"
177-
...
178-
179-
@abstractmethod
180-
def _query(self, sql_code: str) -> list:
181-
"Send query to database and return result"
182-
...
183-
184-
@abstractmethod
185-
def select_table_schema(self, path: DbPath) -> str:
186-
"Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"
187-
...
188-
189-
@abstractmethod
190-
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
191-
"Query the table for its schema for table in 'path', and return {column: type}"
192-
...
193-
194-
@abstractmethod
195-
def parse_table_name(self, name: str) -> DbPath:
196-
"Parse the given table name into a DbPath"
197-
...
198-
199-
@abstractmethod
200-
def close(self):
201-
"Close connection(s) to the database instance. Querying will stop functioning."
202-
...
203-
204-
@abstractmethod
205-
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
206-
"""Creates an SQL expression, that converts 'value' to a normalized timestamp.
207-
208-
The returned expression must accept any SQL datetime/timestamp, and return a string.
209-
210-
Date format: "YYYY-MM-DD HH:mm:SS.FFFFFF"
211-
212-
Precision of dates should be rounded up/down according to coltype.rounds
213-
"""
214-
...
215-
216-
@abstractmethod
217-
def normalize_number(self, value: str, coltype: ColType) -> str:
218-
"""Creates an SQL expression, that converts 'value' to a normalized number.
219-
220-
The returned expression must accept any SQL int/numeric/float, and return a string.
221-
222-
- Floats/Decimals are expected in the format
223-
"I.P"
224-
225-
Where I is the integer part of the number (as many digits as necessary),
226-
and must be at least one digit (0).
227-
P is the fractional digits, the amount of which is specified with
228-
coltype.precision. Trailing zeroes may be necessary.
229-
If P is 0, the dot is omitted.
230-
231-
Note: This precision is different than the one used by databases. For decimals,
232-
it's the same as ``numeric_scale``, and for floats, who use binary precision,
233-
it can be calculated as ``log10(2**numeric_precision)``.
234-
"""
235-
...
236-
237-
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
238-
"""Creates an SQL expression, that converts 'value' to a normalized representation.
239-
240-
The returned expression must accept any SQL value, and return a string.
241-
242-
The default implementation dispatches to a method according to ``coltype``:
243-
244-
TemporalType -> normalize_timestamp()
245-
NumericType -> normalize_number()
246-
-else- -> to_string()
247-
248-
"""
249-
if isinstance(coltype, TemporalType):
250-
return self.normalize_timestamp(value, coltype)
251-
elif isinstance(coltype, NumericType):
252-
return self.normalize_number(value, coltype)
253-
return self.to_string(f"{value}")
254-
255113

256114
class Database(AbstractDatabase):
257115
"""Base abstract class for databases.
@@ -261,8 +119,8 @@ class Database(AbstractDatabase):
261119
Instanciated using :meth:`~data_diff.connect_to_uri`
262120
"""
263121

264-
DATETIME_TYPES = {}
265-
default_schema = None
122+
DATETIME_TYPES: Dict[str, type] = {}
123+
default_schema: str = None
266124

267125
@property
268126
def name(self):
@@ -412,9 +270,6 @@ def _query_in_worker(self, sql_code: str):
412270
raise self._init_error
413271
return _query_conn(self.thread_local.conn, sql_code)
414272

415-
def close(self):
416-
self._queue.shutdown(True)
417-
418273
@abstractmethod
419274
def create_connection(self):
420275
...
@@ -481,7 +336,7 @@ def md5_to_int(self, s: str) -> str:
481336
def to_string(self, s: str):
482337
return f"{s}::varchar"
483338

484-
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
339+
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
485340
if coltype.rounds:
486341
return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"
487342

@@ -490,7 +345,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
490345
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
491346
)
492347

493-
def normalize_number(self, value: str, coltype: ColType) -> str:
348+
def normalize_number(self, value: str, coltype: NumericType) -> str:
494349
return self.to_string(f"{value}::decimal(38, {coltype.precision})")
495350

496351

@@ -531,7 +386,7 @@ def _query(self, sql_code: str) -> list:
531386
def close(self):
532387
self._conn.close()
533388

534-
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
389+
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
535390
# TODO
536391
if coltype.rounds:
537392
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
@@ -540,7 +395,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
540395

541396
return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
542397

543-
def normalize_number(self, value: str, coltype: ColType) -> str:
398+
def normalize_number(self, value: str, coltype: NumericType) -> str:
544399
return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")
545400

546401
def select_table_schema(self, path: DbPath) -> str:
@@ -554,11 +409,11 @@ def select_table_schema(self, path: DbPath) -> str:
554409
def _parse_type(
555410
self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None
556411
) -> ColType:
557-
regexps = {
412+
timestamp_regexps = {
558413
r"timestamp\((\d)\)": Timestamp,
559414
r"timestamp\((\d)\) with time zone": TimestampTZ,
560415
}
561-
for regexp, cls in regexps.items():
416+
for regexp, cls in timestamp_regexps.items():
562417
m = re.match(regexp + "$", type_repr)
563418
if m:
564419
datetime_precision = int(m.group(1))
@@ -567,8 +422,8 @@ def _parse_type(
567422
rounds=False,
568423
)
569424

570-
regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
571-
for regexp, cls in regexps.items():
425+
number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
426+
for regexp, cls in number_regexps.items():
572427
m = re.match(regexp + "$", type_repr)
573428
if m:
574429
prec, scale = map(int, m.groups())
@@ -632,14 +487,14 @@ def md5_to_int(self, s: str) -> str:
632487
def to_string(self, s: str):
633488
return f"cast({s} as char)"
634489

635-
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
490+
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
636491
if coltype.rounds:
637492
return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))")
638493

639494
s = self.to_string(f"cast({value} as datetime(6))")
640495
return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
641496

642-
def normalize_number(self, value: str, coltype: ColType) -> str:
497+
def normalize_number(self, value: str, coltype: NumericType) -> str:
643498
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
644499

645500

@@ -685,10 +540,10 @@ def select_table_schema(self, path: DbPath) -> str:
685540
f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'"
686541
)
687542

688-
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
543+
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
689544
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"
690545

691-
def normalize_number(self, value: str, coltype: ColType) -> str:
546+
def normalize_number(self, value: str, coltype: NumericType) -> str:
692547
# FM999.9990
693548
format_str = "FM" + "9" * (38 - coltype.precision)
694549
if coltype.precision:
@@ -749,7 +604,7 @@ class Redshift(PostgreSQL):
749604
def md5_to_int(self, s: str) -> str:
750605
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"
751606

752-
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
607+
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
753608
if coltype.rounds:
754609
timestamp = f"{value}::timestamp(6)"
755610
# Get seconds since epoch. Redshift doesn't support milli- or micro-seconds.
@@ -769,7 +624,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
769624
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
770625
)
771626

772-
def normalize_number(self, value: str, coltype: ColType) -> str:
627+
def normalize_number(self, value: str, coltype: NumericType) -> str:
773628
return self.to_string(f"{value}::decimal(38,{coltype.precision})")
774629

775630
def select_table_schema(self, path: DbPath) -> str:
@@ -870,7 +725,7 @@ def select_table_schema(self, path: DbPath) -> str:
870725
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
871726
)
872727

873-
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
728+
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
874729
if coltype.rounds:
875730
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
876731
return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})"
@@ -885,7 +740,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
885740
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
886741
)
887742

888-
def normalize_number(self, value: str, coltype: ColType) -> str:
743+
def normalize_number(self, value: str, coltype: NumericType) -> str:
889744
if isinstance(coltype, Integer):
890745
return self.to_string(value)
891746
return f"format('%.{coltype.precision}f', {value})"
@@ -962,21 +817,21 @@ def select_table_schema(self, path: DbPath) -> str:
962817
schema, table = self._normalize_table_path(path)
963818
return super().select_table_schema((schema, table))
964819

965-
def normalize_timestamp(self, value: str, coltype: ColType) -> str:
820+
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
966821
if coltype.rounds:
967822
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))"
968823
else:
969824
timestamp = f"cast({value} as timestamp({coltype.precision}))"
970825

971826
return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')"
972827

973-
def normalize_number(self, value: str, coltype: ColType) -> str:
828+
def normalize_number(self, value: str, coltype: NumericType) -> str:
974829
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")
975830

976831

977832
@dataclass
978833
class MatchUriPath:
979-
database_cls: type
834+
database_cls: Type[Database]
980835
params: List[str]
981836
kwparams: List[str] = []
982837
help_str: str
@@ -1027,7 +882,7 @@ def match_path(self, dsn):
1027882
"postgresql": MatchUriPath(PostgreSQL, ["database?"], help_str="postgresql://<user>:<pass>@<host>/<database>"),
1028883
"mysql": MatchUriPath(MySQL, ["database?"], help_str="mysql://<user>:<pass>@<host>/<database>"),
1029884
"oracle": MatchUriPath(Oracle, ["database?"], help_str="oracle://<user>:<pass>@<host>/<database>"),
1030-
"mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://<user>:<pass>@<host>/<database>"),
885+
# "mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://<user>:<pass>@<host>/<database>"),
1031886
"redshift": MatchUriPath(Redshift, ["database?"], help_str="redshift://<user>:<pass>@<host>/<database>"),
1032887
"snowflake": MatchUriPath(
1033888
Snowflake,
@@ -1055,7 +910,6 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
1055910
Supported schemes:
1056911
- postgresql
1057912
- mysql
1058-
- mssql
1059913
- oracle
1060914
- snowflake
1061915
- bigquery

0 commit comments

Comments
 (0)