Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit d7ae027

Browse files
authored
Merge pull request #119 from datafold/string_key_column
Support for UUID key column
2 parents 30f5c23 + 9fcef78 commit d7ae027

File tree

14 files changed

+359
-100
lines changed

14 files changed

+359
-100
lines changed

data_diff/databases/base.py

Lines changed: 77 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,27 @@
1+
from uuid import UUID
12
import math
23
import sys
34
import logging
4-
from typing import Dict, Tuple, Optional, Sequence
5+
from typing import Dict, Tuple, Optional, Sequence, Type, List
56
from functools import lru_cache, wraps
67
from concurrent.futures import ThreadPoolExecutor
78
import threading
89
from abc import abstractmethod
910

10-
from .database_types import AbstractDatabase, ColType, Integer, Decimal, Float, UnknownColType
11-
from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select
11+
from data_diff.utils import is_uuid, safezip
12+
from .database_types import (
13+
ColType_UUID,
14+
AbstractDatabase,
15+
ColType,
16+
Integer,
17+
Decimal,
18+
Float,
19+
PrecisionType,
20+
TemporalType,
21+
UnknownColType,
22+
Text,
23+
)
24+
from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select, TableName
1225

1326
logger = logging.getLogger("database")
1427

@@ -62,7 +75,7 @@ class Database(AbstractDatabase):
6275
Instantiated using :meth:`~data_diff.connect_to_uri`
6376
"""
6477

65-
DATETIME_TYPES: Dict[str, type] = {}
78+
TYPE_CLASSES: Dict[str, type] = {}
6679
default_schema: str = None
6780

6881
@property
@@ -93,7 +106,7 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
93106
assert len(res) == 1, (sql_code, res)
94107
return res[0]
95108
elif getattr(res_type, "__origin__", None) is list and len(res_type.__args__) == 1:
96-
if res_type.__args__ == (int,):
109+
if res_type.__args__ == (int,) or res_type.__args__ == (str,):
97110
return [_one(row) for row in res]
98111
elif res_type.__args__ == (Tuple,):
99112
return [tuple(row) for row in res]
@@ -109,8 +122,12 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
109122
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
110123
return math.floor(math.log(2**p, 10))
111124

125+
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
126+
# Look up the class registered under this type name in TYPE_CLASSES.
# dict.get yields None when the name has no entry.
return self.TYPE_CLASSES.get(type_repr)
127+
112128
def _parse_type(
113129
self,
130+
table_path: DbPath,
114131
col_name: str,
115132
type_repr: str,
116133
datetime_precision: int = None,
@@ -119,36 +136,38 @@ def _parse_type(
119136
) -> ColType:
120137
""" """
121138

122-
cls = self.DATETIME_TYPES.get(type_repr)
123-
if cls:
139+
cls = self._parse_type_repr(type_repr)
140+
if not cls:
141+
return UnknownColType(type_repr)
142+
143+
if issubclass(cls, TemporalType):
124144
return cls(
125145
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
126146
rounds=self.ROUNDS_ON_PREC_LOSS,
127147
)
128148

129-
cls = self.NUMERIC_TYPES.get(type_repr)
130-
if cls:
131-
if issubclass(cls, Integer):
132-
# Some DBs have a constant numeric_scale, so they don't report it.
133-
# We fill in the constant, so we need to ignore it for integers.
134-
return cls(precision=0)
135-
136-
elif issubclass(cls, Decimal):
137-
if numeric_scale is None:
138-
raise ValueError(
139-
f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}."
140-
)
141-
return cls(precision=numeric_scale)
149+
elif issubclass(cls, Integer):
150+
return cls()
151+
152+
elif issubclass(cls, Decimal):
153+
if numeric_scale is None:
154+
raise ValueError(
155+
f"{self.name}: Unexpected numeric_scale is NULL, for column {'.'.join(table_path)}.{col_name} of type {type_repr}."
156+
)
157+
return cls(precision=numeric_scale)
142158

143-
assert issubclass(cls, Float)
159+
elif issubclass(cls, Float):
144160
# assert numeric_scale is None
145161
return cls(
146162
precision=self._convert_db_precision_to_digits(
147163
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
148164
)
149165
)
150166

151-
return UnknownColType(type_repr)
167+
elif issubclass(cls, Text):
168+
return cls()
169+
170+
raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")
152171

153172
def select_table_schema(self, path: DbPath) -> str:
154173
schema, table = self._normalize_table_path(path)
@@ -167,8 +186,34 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str
167186
accept = {i.lower() for i in filter_columns}
168187
rows = [r for r in rows if r[0].lower() in accept]
169188

189+
col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in rows}
190+
191+
self._refine_coltypes(path, col_dict)
192+
170193
# Return a dict of form {name: type} after normalization
171-
return {row[0]: self._parse_type(*row) for row in rows}
194+
return col_dict
195+
196+
def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
    """Refine the types in the column dict by sampling values from the database.

    Every column currently typed as ``Text`` is sampled (at most 16 rows).
    A column whose sampled values are all UUIDs is re-typed in place as
    ``ColType_UUID``. A column that mixes UUID and non-UUID values keeps its
    ``Text`` type and a warning is logged.

    Args:
        table_path: Path of the table to sample.
        col_dict: Mapping of column name to parsed type; mutated in place.
    """
    text_columns = [name for name, coltype in col_dict.items() if isinstance(coltype, Text)]
    if not text_columns:
        return

    fields = [self.normalize_uuid(c, ColType_UUID()) for c in text_columns]
    samples_by_row = self.query(Select(fields, TableName(table_path), limit=16), list)
    if not samples_by_row:
        # No rows means zip(*rows) produces no columns at all, and pairing
        # that with the non-empty text_columns would fail on the length
        # mismatch. Nothing can be inferred, so leave the types untouched.
        logger.warning(
            f"Table {'.'.join(table_path)} returned no rows, skipping UUID detection for its text columns."
        )
        return

    samples_by_col = list(zip(*samples_by_row))
    for col_name, samples in safezip(text_columns, samples_by_col):
        # Falsy samples (NULL comes back as None) are never UUIDs; skip them
        # before calling is_uuid, which presumably expects a string.
        uuid_samples = [s for s in samples if s and is_uuid(s)]
        if not uuid_samples:
            continue

        if len(uuid_samples) != len(samples):
            logger.warning(
                f"Mixed UUID/Non-UUID values detected in column {'.'.join(table_path)}.{col_name}, disabling UUID support."
            )
        else:
            assert col_name in col_dict
            col_dict[col_name] = ColType_UUID()
172217

173218
# @lru_cache()
174219
# def get_table_schema(self, path: DbPath) -> Dict[str, ColType]:
@@ -186,6 +231,15 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
186231
def parse_table_name(self, name: str) -> DbPath:
187232
# The bare name resolves to a module-level parse_table_name (defined outside
# this view), not to this method, so this is delegation rather than recursion.
return parse_table_name(name)
188233

234+
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    """Return the SQL clause that restricts a query to ``limit`` rows.

    Args:
        offset: Row offset. Any truthy value is rejected.
        limit: Maximum number of rows, or None for no limit.

    Returns:
        ``"LIMIT <n>"``, or an empty string when ``limit`` is None.

    Raises:
        NotImplementedError: If a truthy ``offset`` is given.
    """
    if offset:
        raise NotImplementedError("No support for OFFSET in query")

    if limit is None:
        # Interpolating the default would emit the invalid clause "LIMIT None".
        return ""

    return f"LIMIT {limit}"
239+
240+
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
241+
# Builds a SQL expression that wraps `value` in TRIM(); `coltype` is not
# consulted by this implementation.
return f"TRIM({value})"
242+
189243

190244
class ThreadedDatabase(Database):
191245
"""Access the database through singleton threads.

data_diff/databases/bigquery.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@ def import_bigquery():
1111

1212

1313
class BigQuery(Database):
14-
DATETIME_TYPES = {
14+
TYPE_CLASSES = {
15+
# Dates
1516
"TIMESTAMP": Timestamp,
1617
"DATETIME": Datetime,
17-
}
18-
NUMERIC_TYPES = {
18+
# Numbers
1919
"INT64": Integer,
2020
"INT32": Integer,
2121
"NUMERIC": Decimal,
2222
"BIGNUMERIC": Decimal,
2323
"FLOAT64": Float,
2424
"FLOAT32": Float,
25+
# Text
26+
"STRING": Text,
2527
}
2628
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation
2729

data_diff/databases/database_types.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1+
import decimal
12
from abc import ABC, abstractmethod
2-
from typing import Sequence, Optional, Tuple, Union, Dict
3+
from typing import Sequence, Optional, Tuple, Union, Dict, Any
34
from datetime import datetime
45

56
from runtype import dataclass
67

8+
from data_diff.utils import ArithUUID
9+
10+
711
DbPath = Tuple[str, ...]
8-
DbKey = Union[int, str, bytes]
12+
DbKey = Union[int, str, bytes, ArithUUID]
913
DbTime = datetime
1014

1115

1216
class ColType:
17+
# Class-level flag, True by default; Text and UnknownColType set it to False.
supported = True
1318
pass
1419

1520

@@ -50,11 +55,36 @@ class Float(FractionalType):
5055

5156

5257
class Decimal(FractionalType):
58+
@property
59+
def python_type(self) -> type:
60+
# A precision of 0 maps to int; any other precision maps to decimal.Decimal.
if self.precision == 0:
61+
return int
62+
return decimal.Decimal
63+
64+
65+
class StringType(ColType):
5366
pass
5467

5568

69+
class IKey(ABC):
70+
"Interface for ColType, for using a column as a key in data-diff"
71+
python_type: type
72+
73+
74+
class ColType_UUID(StringType, IKey):
75+
# Supplies the `python_type` attribute declared by the IKey interface.
python_type = ArithUUID
76+
77+
78+
@dataclass
79+
class Text(StringType):
80+
# Overrides the ColType default of True.
supported = False
81+
82+
5683
@dataclass
57-
class Integer(NumericType):
84+
class Integer(NumericType, IKey):
85+
precision: int = 0
86+
python_type: type = int
87+
5888
def __post_init__(self):
5989
assert self.precision == 0
6090

@@ -63,6 +93,8 @@ def __post_init__(self):
6393
class UnknownColType(ColType):
6494
text: str
6595

96+
supported = False
97+
6698

6799
class AbstractDatabase(ABC):
68100
@abstractmethod
@@ -80,6 +112,10 @@ def md5_to_int(self, s: str) -> str:
80112
"Provide SQL for computing md5 and returning an int"
81113
...
82114

115+
@abstractmethod
116+
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
117+
# Abstract hook with no body. The concrete implementations visible for
# Database and Oracle return a row-limiting SQL clause as a string and
# raise NotImplementedError for a truthy offset.
...
118+
83119
@abstractmethod
84120
def _query(self, sql_code: str) -> list:
85121
"Send query to database and return result"
@@ -138,6 +174,14 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
138174
"""
139175
...
140176

177+
@abstractmethod
178+
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
179+
"""Creates an SQL expression, that converts 'value' to a normalized uuid.
180+
181+
i.e. just makes sure there is no trailing whitespace.
182+
"""
183+
...
184+
141185
def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
142186
"""Creates an SQL expression, that converts 'value' to a normalized representation.
143187
@@ -158,6 +202,8 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
158202
return self.normalize_timestamp(value, coltype)
159203
elif isinstance(coltype, FractionalType):
160204
return self.normalize_number(value, coltype)
205+
elif isinstance(coltype, ColType_UUID):
206+
return self.normalize_uuid(value, coltype)
161207
return self.to_string(value)
162208

163209
def _normalize_table_path(self, path: DbPath) -> DbPath:

data_diff/databases/mysql.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,19 @@ def import_mysql():
1111

1212

1313
class MySQL(ThreadedDatabase):
14-
DATETIME_TYPES = {
14+
TYPE_CLASSES = {
15+
# Dates
1516
"datetime": Datetime,
1617
"timestamp": Timestamp,
17-
}
18-
NUMERIC_TYPES = {
18+
# Numbers
1919
"double": Float,
2020
"float": Float,
2121
"decimal": Decimal,
2222
"int": Integer,
2323
"bigint": Integer,
24+
# Text
25+
"varchar": Text,
26+
"char": Text,
2427
}
2528
ROUNDS_ON_PREC_LOSS = True
2629

data_diff/databases/oracle.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ def import_oracle():
1313

1414

1515
class Oracle(ThreadedDatabase):
16+
TYPE_CLASSES: Dict[str, type] = {
17+
"NUMBER": Decimal,
18+
"FLOAT": Float,
19+
# Text
20+
"CHAR": Text,
21+
"NCHAR": Text,
22+
"NVARCHAR2": Text,
23+
"VARCHAR2": Text,
24+
}
1625
ROUNDS_ON_PREC_LOSS = True
1726

1827
def __init__(self, host, port, user, password, *, database, thread_count, **kw):
@@ -67,13 +76,13 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
6776

6877
def _parse_type(
6978
self,
79+
table_name: DbPath,
7080
col_name: str,
7181
type_repr: str,
7282
datetime_precision: int = None,
7383
numeric_precision: int = None,
7484
numeric_scale: int = None,
7585
) -> ColType:
76-
""" """
7786
regexps = {
7887
r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
7988
r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
@@ -87,20 +96,14 @@ def _parse_type(
8796
rounds=self.ROUNDS_ON_PREC_LOSS,
8897
)
8998

90-
n_cls = {
91-
"NUMBER": Decimal,
92-
"FLOAT": Float,
93-
}.get(type_repr, None)
94-
if n_cls:
95-
if issubclass(n_cls, Decimal):
96-
assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale)
97-
return n_cls(precision=numeric_scale)
98-
99-
assert issubclass(n_cls, Float)
100-
return n_cls(
101-
precision=self._convert_db_precision_to_digits(
102-
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
103-
)
104-
)
99+
return super()._parse_type(table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)
100+
101+
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
    """Return the Oracle row-limiting clause for ``limit`` rows.

    Args:
        offset: Row offset. Any truthy value is rejected.
        limit: Maximum number of rows, or None for no limit.

    Returns:
        ``"FETCH NEXT <n> ROWS ONLY"``, or an empty string when ``limit``
        is None.

    Raises:
        NotImplementedError: If a truthy ``offset`` is given.
    """
    if offset:
        raise NotImplementedError("No support for OFFSET in query")

    if limit is None:
        # Interpolating the default would emit "FETCH NEXT None ROWS ONLY".
        return ""

    return f"FETCH NEXT {limit} ROWS ONLY"
105106

106-
return UnknownColType(type_repr)
107+
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
108+
# Builds a SQL expression: TRIM the value, then CAST it to VARCHAR(36)
# (presumably the 36-character canonical UUID text form). `coltype` is unused.
# Cast is necessary for correct MD5 (trimming not enough)
109+
return f"CAST(TRIM({value}) AS VARCHAR(36))"

0 commit comments

Comments
 (0)