Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit e10e59d

Browse files
authored
Merge pull request #185 from datafold/wildcard_columns
Added support for auto-detecting mutual columns, and using patterns in -c
2 parents b1bebee + 18f65ee commit e10e59d

File tree

9 files changed

+195
-110
lines changed

9 files changed

+195
-110
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ If a database is not on the list, we'd still love to support it. Open an issue
144144
to discuss it.
145145

146146
Note: Because URLs allow many special characters, and may collide with the syntax of your command-line,
147-
it's recommended to surround them with quotes. Alternatively, you may provide them in a TOML file via the `--config` option.
147+
it's recommended to surround them with quotes. Alternatively, you may provide them in a TOML file via the `--config` option.
148148

149149

150150
# How to install
@@ -195,7 +195,7 @@ Options:
195195
- `--help` - Show help message and exit.
196196
- `-k` or `--key-column` - Name of the primary key column
197197
- `-t` or `--update-column` - Name of updated_at/last_updated column
198-
- `-c` or `--columns` - List of names of extra columns to compare
198+
- `-c` or `--columns` - Name or pattern of extra columns to compare. Pattern syntax is like SQL, e.g. `%foob.r%`.
199199
- `-l` or `--limit` - Maximum number of differences to find (limits maximum bandwidth and runtime)
200200
- `-s` or `--stats` - Print stats instead of a detailed diff
201201
- `-d` or `--debug` - Print debug info

data_diff/__main__.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import logging
66
from itertools import islice
77

8-
from .utils import remove_password_from_url
8+
from .utils import remove_password_from_url, safezip, match_like
99

1010
from .diff_tables import (
1111
TableSegment,
1212
TableDiffer,
1313
DEFAULT_BISECTION_THRESHOLD,
1414
DEFAULT_BISECTION_FACTOR,
15+
create_schema,
1516
)
1617
from .databases.connect import connect
1718
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
@@ -39,6 +40,11 @@ def _remove_passwords_in_dict(d: dict):
3940
d[k] = remove_password_from_url(v)
4041

4142

43+
def _get_schema(pair):
44+
db, table_path = pair
45+
return db.query_table_schema(table_path)
46+
47+
4248
@click.command()
4349
@click.argument("database1", required=False)
4450
@click.argument("table1", required=False)
@@ -67,7 +73,11 @@ def _remove_passwords_in_dict(d: dict):
6773
@click.option("--json", "json_output", is_flag=True, help="Print JSONL output for machine readability")
6874
@click.option("-v", "--verbose", is_flag=True, help="Print extra info")
6975
@click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug")
70-
@click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.")
76+
@click.option(
77+
"--case-sensitive",
78+
is_flag=True,
79+
help="Column names are treated as case-sensitive. Otherwise, data-diff corrects their case according to schema.",
80+
)
7181
@click.option(
7282
"-j",
7383
"--threads",
@@ -111,7 +121,7 @@ def _main(
111121
verbose,
112122
interactive,
113123
threads,
114-
keep_column_case,
124+
case_sensitive,
115125
json_output,
116126
where,
117127
threads1=None,
@@ -158,35 +168,66 @@ def _main(
158168

159169
db1 = connect(database1, threads1 or threads)
160170
db2 = connect(database2, threads2 or threads)
171+
dbs = db1, db2
161172

162173
if interactive:
163-
db1.enable_interactive()
164-
db2.enable_interactive()
174+
for db in dbs:
175+
db.enable_interactive()
165176

166177
start = time.time()
167178

168179
try:
169180
options = dict(
170181
min_update=max_age and parse_time_before_now(max_age),
171182
max_update=min_age and parse_time_before_now(min_age),
172-
case_sensitive=keep_column_case,
183+
case_sensitive=case_sensitive,
173184
where=where,
174185
)
175186
except ParseError as e:
176187
logging.error("Error while parsing age expression: %s" % e)
177188
return
178189

179-
table1_seg = TableSegment(db1, db1.parse_table_name(table1), key_column, update_column, columns, **options)
180-
table2_seg = TableSegment(db2, db2.parse_table_name(table2), key_column, update_column, columns, **options)
181-
182190
differ = TableDiffer(
183191
bisection_factor=bisection_factor,
184192
bisection_threshold=bisection_threshold,
185193
threaded=threaded,
186194
max_threadpool_size=threads and threads * 2,
187195
debug=debug,
188196
)
189-
diff_iter = differ.diff_tables(table1_seg, table2_seg)
197+
198+
table_names = table1, table2
199+
table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)]
200+
201+
schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
202+
schema1, schema2 = schemas = [
203+
create_schema(db, table_path, schema, case_sensitive)
204+
for db, table_path, schema in safezip(dbs, table_paths, schemas)
205+
]
206+
207+
mutual = schema1.keys() & schema2.keys() # Case-aware, according to case_sensitive
208+
logging.debug(f"Available mutual columns: {mutual}")
209+
210+
expanded_columns = set()
211+
for c in columns:
212+
match = set(match_like(c, mutual))
213+
if not match:
214+
m1 = None if any(match_like(c, schema1.keys())) else f"{db1}/{table1}"
215+
m2 = None if any(match_like(c, schema2.keys())) else f"{db2}/{table2}"
216+
not_matched = ", ".join(m for m in [m1, m2] if m)
217+
raise ValueError(f"Column {c} not found in: {not_matched}")
218+
219+
expanded_columns |= match
220+
221+
columns = tuple(expanded_columns - {key_column, update_column})
222+
223+
logging.info(f"Diffing columns: key={key_column} update={update_column} extra={columns}")
224+
225+
segments = [
226+
TableSegment(db, table_path, key_column, update_column, columns, **options)._with_raw_schema(raw_schema)
227+
for db, table_path, raw_schema in safezip(dbs, table_paths, schemas)
228+
]
229+
230+
diff_iter = differ.diff_tables(*segments)
190231

191232
if limit:
192233
diff_iter = islice(diff_iter, int(limit))

data_diff/databases/base.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import threading
88
from abc import abstractmethod
99

10-
from data_diff.utils import is_uuid, safezip
10+
from data_diff.utils import CaseAwareMapping, is_uuid, safezip
1111
from .database_types import (
1212
AbstractDatabase,
1313
ColType,
@@ -180,16 +180,19 @@ def select_table_schema(self, path: DbPath) -> str:
180180
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
181181
)
182182

183-
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
183+
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
184184
rows = self.query(self.select_table_schema(path), list)
185185
if not rows:
186186
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
187187

188-
if filter_columns is not None:
189-
accept = {i.lower() for i in filter_columns}
190-
rows = [r for r in rows if r[0].lower() in accept]
188+
d = {r[0]: r for r in rows}
189+
assert len(d) == len(rows)
190+
return d
191191

192-
col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in rows}
192+
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str]):
193+
accept = {i.lower() for i in filter_columns}
194+
195+
col_dict = {name: self._parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept}
193196

194197
self._refine_coltypes(path, col_dict)
195198

data_diff/databases/database_types.py

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from runtype import dataclass
77

8-
from data_diff.utils import ArithAlphanumeric, ArithUUID, ArithString
8+
from data_diff.utils import ArithAlphanumeric, ArithUUID, CaseAwareMapping
99

1010

1111
DbPath = Tuple[str, ...]
@@ -171,10 +171,23 @@ def select_table_schema(self, path: DbPath) -> str:
171171
...
172172

173173
@abstractmethod
174-
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
175-
"Query the table for its schema for table in 'path', and return {column: type}"
174+
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
175+
"""Query the table for its schema for table in 'path', and return {column: tuple}
176+
where the tuple is (col_name, type_repr, datetime_precision?, numeric_precision?, numeric_scale?)
177+
"""
176178
...
177179

180+
@abstractmethod
181+
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str]):
182+
"""Process the result of query_table_schema().
183+
184+
Done in a separate step, to minimize the amount of processed columns.
185+
Needed because processing each column may:
186+
* throw errors and warnings
187+
* query the database to sample values
188+
189+
"""
190+
178191
@abstractmethod
179192
def parse_table_name(self, name: str) -> DbPath:
180193
"Parse the given table name into a DbPath"
@@ -254,44 +267,4 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
254267
...
255268

256269

257-
class Schema(ABC):
258-
@abstractmethod
259-
def get_key(self, key: str) -> str:
260-
...
261-
262-
@abstractmethod
263-
def __getitem__(self, key: str) -> ColType:
264-
...
265-
266-
@abstractmethod
267-
def __setitem__(self, key: str, value):
268-
...
269-
270-
@abstractmethod
271-
def __contains__(self, key: str) -> bool:
272-
...
273-
274-
275-
class Schema_CaseSensitive(dict, Schema):
276-
def get_key(self, key):
277-
return key
278-
279-
280-
class Schema_CaseInsensitive(Schema):
281-
def __init__(self, initial):
282-
self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()}
283-
284-
def get_key(self, key: str) -> str:
285-
return self._dict[key.lower()][0]
286-
287-
def __getitem__(self, key: str) -> ColType:
288-
return self._dict[key.lower()][1]
289-
290-
def __setitem__(self, key: str, value):
291-
k = key.lower()
292-
if k in self._dict:
293-
key = self._dict[k][0]
294-
self._dict[k] = key, value
295-
296-
def __contains__(self, key):
297-
return key.lower() in self._dict
270+
Schema = CaseAwareMapping

data_diff/databases/databricks.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
6868
# Subtracting 1 due to weird precision issues
6969
return max(super()._convert_db_precision_to_digits(p) - 1, 0)
7070

71-
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
71+
def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
7272
# Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
7373
# https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
7474
# So, to obtain information about schema, we should use another approach.
@@ -80,35 +80,40 @@ def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str
8080
if not rows:
8181
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
8282

83-
if filter_columns is not None:
84-
accept = {i.lower() for i in filter_columns}
85-
rows = [r for r in rows if r.COLUMN_NAME.lower() in accept]
83+
d = {r.COLUMN_NAME: r for r in rows}
84+
assert len(d) == len(rows)
85+
return d
8686

87-
resulted_rows = []
88-
for row in rows:
89-
row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME
90-
type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType)
87+
def _process_table_schema(self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str]):
88+
accept = {i.lower() for i in filter_columns}
89+
rows = [row for name, row in raw_schema.items() if name.lower() in accept]
9190

92-
if issubclass(type_cls, Integer):
93-
row = (row.COLUMN_NAME, row_type, None, None, 0)
91+
resulted_rows = []
92+
for row in rows:
93+
row_type = "DECIMAL" if row.DATA_TYPE == 3 else row.TYPE_NAME
94+
type_cls = self.TYPE_CLASSES.get(row_type, UnknownColType)
9495

95-
elif issubclass(type_cls, Float):
96-
numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS)
97-
row = (row.COLUMN_NAME, row_type, None, numeric_precision, None)
96+
if issubclass(type_cls, Integer):
97+
row = (row.COLUMN_NAME, row_type, None, None, 0)
9898

99-
elif issubclass(type_cls, Decimal):
100-
# TYPE_NAME has a format DECIMAL(x,y)
101-
items = row.TYPE_NAME[8:].rstrip(")").split(",")
102-
numeric_precision, numeric_scale = int(items[0]), int(items[1])
103-
row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale)
99+
elif issubclass(type_cls, Float):
100+
numeric_precision = self._convert_db_precision_to_digits(row.DECIMAL_DIGITS)
101+
row = (row.COLUMN_NAME, row_type, None, numeric_precision, None)
104102

105-
elif issubclass(type_cls, Timestamp):
106-
row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None)
103+
elif issubclass(type_cls, Decimal):
104+
# TYPE_NAME has a format DECIMAL(x,y)
105+
items = row.TYPE_NAME[8:].rstrip(")").split(",")
106+
numeric_precision, numeric_scale = int(items[0]), int(items[1])
107+
row = (row.COLUMN_NAME, row_type, None, numeric_precision, numeric_scale)
107108

108-
else:
109-
row = (row.COLUMN_NAME, row_type, None, None, None)
109+
elif issubclass(type_cls, Timestamp):
110+
row = (row.COLUMN_NAME, row_type, row.DECIMAL_DIGITS, None, None)
111+
112+
else:
113+
row = (row.COLUMN_NAME, row_type, None, None, None)
114+
115+
resulted_rows.append(row)
110116

111-
resulted_rows.append(row)
112117
col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in resulted_rows}
113118

114119
self._refine_coltypes(path, col_dict)

data_diff/databases/mysql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def create_connection(self):
4747
elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR:
4848
raise ConnectError("Database does not exist") from e
4949
else:
50-
raise ConnectError(*e._args) from e
50+
raise ConnectError(*e) from e
5151

5252
def quote(self, s: str):
5353
return f"`{s}`"

0 commit comments

Comments
 (0)