Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e3fda00

Browse files
committed
Refactor _parse_type()
1 parent 6ed2a0e commit e3fda00

File tree

9 files changed

+63
-72
lines changed

9 files changed

+63
-72
lines changed

data_diff/databases/base.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
import math
22
import sys
33
import logging
4-
from typing import Dict, Tuple, Optional, Sequence
4+
from typing import Dict, Tuple, Optional, Sequence, Type
55
from functools import lru_cache, wraps
66
from concurrent.futures import ThreadPoolExecutor
77
import threading
88
from abc import abstractmethod
99

10-
from .database_types import AbstractDatabase, ColType, Integer, Decimal, Float, UnknownColType
10+
from .database_types import (
11+
AbstractDatabase,
12+
ColType,
13+
Integer,
14+
Decimal,
15+
Float,
16+
PrecisionType,
17+
TemporalType,
18+
UnknownColType,
19+
)
1120
from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select
1221

1322
logger = logging.getLogger("database")
@@ -62,7 +71,7 @@ class Database(AbstractDatabase):
6271
Instantiated using :meth:`~data_diff.connect_to_uri`
6372
"""
6473

65-
DATETIME_TYPES: Dict[str, type] = {}
74+
TYPE_CLASSES: Dict[str, type] = {}
6675
default_schema: str = None
6776

6877
@property
@@ -109,6 +118,9 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
109118
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
110119
return math.floor(math.log(2**p, 10))
111120

121+
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
122+
return self.TYPE_CLASSES.get(type_repr)
123+
112124
def _parse_type(
113125
self,
114126
col_name: str,
@@ -119,36 +131,35 @@ def _parse_type(
119131
) -> ColType:
120132
""" """
121133

122-
cls = self.DATETIME_TYPES.get(type_repr)
123-
if cls:
134+
cls = self._parse_type_repr(type_repr)
135+
if not cls:
136+
return UnknownColType(type_repr)
137+
138+
if issubclass(cls, TemporalType):
124139
return cls(
125140
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
126141
rounds=self.ROUNDS_ON_PREC_LOSS,
127142
)
128143

129-
cls = self.NUMERIC_TYPES.get(type_repr)
130-
if cls:
131-
if issubclass(cls, Integer):
132-
# Some DBs have a constant numeric_scale, so they don't report it.
133-
# We fill in the constant, so we need to ignore it for integers.
134-
return cls(precision=0)
135-
136-
elif issubclass(cls, Decimal):
137-
if numeric_scale is None:
138-
raise ValueError(
139-
f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}."
140-
)
141-
return cls(precision=numeric_scale)
142-
143-
assert issubclass(cls, Float)
144+
elif issubclass(cls, Integer):
145+
return cls()
146+
147+
elif issubclass(cls, Decimal):
148+
if numeric_scale is None:
149+
raise ValueError(
150+
f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}."
151+
)
152+
return cls(precision=numeric_scale)
153+
154+
elif issubclass(cls, Float):
144155
# assert numeric_scale is None
145156
return cls(
146157
precision=self._convert_db_precision_to_digits(
147158
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
148159
)
149160
)
150161

151-
return UnknownColType(type_repr)
162+
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
152163

153164
def select_table_schema(self, path: DbPath) -> str:
154165
schema, table = self._normalize_table_path(path)

data_diff/databases/bigquery.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ def import_bigquery():
1111

1212

1313
class BigQuery(Database):
14-
DATETIME_TYPES = {
14+
TYPE_CLASSES = {
15+
# Dates
1516
"TIMESTAMP": Timestamp,
1617
"DATETIME": Datetime,
17-
}
18-
NUMERIC_TYPES = {
18+
# Numbers
1919
"INT64": Integer,
2020
"INT32": Integer,
2121
"NUMERIC": Decimal,

data_diff/databases/database_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ class Decimal(FractionalType):
5353
pass
5454

5555

56+
class StringType(ColType):
57+
pass
58+
59+
60+
class UUID(StringType):
61+
pass
62+
63+
5664
@dataclass
5765
class Integer(NumericType):
5866
def __post_init__(self):

data_diff/databases/mysql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ def import_mysql():
1111

1212

1313
class MySQL(ThreadedDatabase):
14-
DATETIME_TYPES = {
14+
TYPE_CLASSES = {
15+
# Dates
1516
"datetime": Datetime,
1617
"timestamp": Timestamp,
17-
}
18-
NUMERIC_TYPES = {
18+
# Numbers
1919
"double": Float,
2020
"float": Float,
2121
"decimal": Decimal,

data_diff/databases/oracle.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ def import_oracle():
1313

1414

1515
class Oracle(ThreadedDatabase):
16+
TYPE_CLASSES: Dict[str, type] = {
17+
"NUMBER": Decimal,
18+
"FLOAT": Float,
19+
}
1620
ROUNDS_ON_PREC_LOSS = True
1721

1822
def __init__(self, host, port, user, password, *, database, thread_count, **kw):
@@ -72,7 +76,6 @@ def _parse_type(
7276
numeric_precision: int = None,
7377
numeric_scale: int = None,
7478
) -> ColType:
75-
""" """
7679
regexps = {
7780
r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
7881
r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
@@ -86,20 +89,4 @@ def _parse_type(
8689
rounds=self.ROUNDS_ON_PREC_LOSS,
8790
)
8891

89-
n_cls = {
90-
"NUMBER": Decimal,
91-
"FLOAT": Float,
92-
}.get(type_repr, None)
93-
if n_cls:
94-
if issubclass(n_cls, Decimal):
95-
assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale)
96-
return n_cls(precision=numeric_scale)
97-
98-
assert issubclass(n_cls, Float)
99-
return n_cls(
100-
precision=self._convert_db_precision_to_digits(
101-
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
102-
)
103-
)
104-
105-
return UnknownColType(type_repr)
92+
return super()._parse_type(type_repr, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)

data_diff/databases/postgresql.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@ def import_postgresql():
1313

1414

1515
class PostgreSQL(ThreadedDatabase):
16-
DATETIME_TYPES = {
16+
TYPE_CLASSES = {
17+
# Timestamps
1718
"timestamp with time zone": TimestampTZ,
1819
"timestamp without time zone": Timestamp,
1920
"timestamp": Timestamp,
20-
# "datetime": Datetime,
21-
}
22-
NUMERIC_TYPES = {
21+
# Numbers
2322
"double precision": Float,
2423
"real": Float,
2524
"decimal": Decimal,

data_diff/databases/presto.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ def import_presto():
2020

2121
class Presto(Database):
2222
default_schema = "public"
23-
DATETIME_TYPES = {
23+
TYPE_CLASSES = {
24+
# Timestamps
2425
"timestamp with time zone": TimestampTZ,
2526
"timestamp without time zone": Timestamp,
2627
"timestamp": Timestamp,
27-
# "datetime": Datetime,
28-
}
29-
NUMERIC_TYPES = {
28+
# Numbers
3029
"integer": Integer,
3130
"real": Float,
3231
"double": Float,
@@ -104,17 +103,4 @@ def _parse_type(
104103
prec, scale = map(int, m.groups())
105104
return n_cls(scale)
106105

107-
n_cls = self.NUMERIC_TYPES.get(type_repr)
108-
if n_cls:
109-
if issubclass(n_cls, Integer):
110-
assert numeric_precision is not None
111-
return n_cls(0)
112-
113-
assert issubclass(n_cls, Float)
114-
return n_cls(
115-
precision=self._convert_db_precision_to_digits(
116-
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
117-
)
118-
)
119-
120-
return UnknownColType(type_repr)
106+
return super()._parse_type(type_repr)

data_diff/databases/redshift.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44

55
class Redshift(PostgreSQL):
6-
NUMERIC_TYPES = {
7-
**PostgreSQL.NUMERIC_TYPES,
6+
TYPE_CLASSES = {
7+
**PostgreSQL.TYPE_CLASSES,
88
"double": Float,
99
"real": Float,
1010
}

data_diff/databases/snowflake.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ def import_snowflake():
1212

1313

1414
class Snowflake(Database):
15-
DATETIME_TYPES = {
15+
TYPE_CLASSES = {
16+
# Timestamps
1617
"TIMESTAMP_NTZ": Timestamp,
1718
"TIMESTAMP_LTZ": Timestamp,
1819
"TIMESTAMP_TZ": TimestampTZ,
19-
}
20-
NUMERIC_TYPES = {
20+
# Numbers
2121
"NUMBER": Decimal,
2222
"FLOAT": Float,
2323
}

0 commit comments

Comments
 (0)