|
5 | 5 | import logging |
6 | 6 | from itertools import islice |
7 | 7 |
|
8 | | -from .utils import remove_password_from_url |
| 8 | +from .utils import remove_password_from_url, safezip, match_like |
9 | 9 |
|
10 | 10 | from .diff_tables import ( |
11 | 11 | TableSegment, |
12 | 12 | TableDiffer, |
13 | 13 | DEFAULT_BISECTION_THRESHOLD, |
14 | 14 | DEFAULT_BISECTION_FACTOR, |
| 15 | + create_schema, |
15 | 16 | ) |
16 | 17 | from .databases.connect import connect |
17 | 18 | from .parse_time import parse_time_before_now, UNITS_STR, ParseError |
@@ -39,6 +40,11 @@ def _remove_passwords_in_dict(d: dict): |
39 | 40 | d[k] = remove_password_from_url(v) |
40 | 41 |
|
41 | 42 |
|
| 43 | +def _get_schema(pair): |
| 44 | + db, table_path = pair |
| 45 | + return db.query_table_schema(table_path) |
| 46 | + |
| 47 | + |
42 | 48 | @click.command() |
43 | 49 | @click.argument("database1", required=False) |
44 | 50 | @click.argument("table1", required=False) |
@@ -67,7 +73,11 @@ def _remove_passwords_in_dict(d: dict): |
67 | 73 | @click.option("--json", "json_output", is_flag=True, help="Print JSONL output for machine readability") |
68 | 74 | @click.option("-v", "--verbose", is_flag=True, help="Print extra info") |
69 | 75 | @click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug") |
70 | | -@click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.") |
| 76 | +@click.option( |
| 77 | + "--case-sensitive", |
| 78 | + is_flag=True, |
| 79 | + help="Column names are treated as case-sensitive. Otherwise, data-diff corrects their case according to schema.", |
| 80 | +) |
71 | 81 | @click.option( |
72 | 82 | "-j", |
73 | 83 | "--threads", |
@@ -111,7 +121,7 @@ def _main( |
111 | 121 | verbose, |
112 | 122 | interactive, |
113 | 123 | threads, |
114 | | - keep_column_case, |
| 124 | + case_sensitive, |
115 | 125 | json_output, |
116 | 126 | where, |
117 | 127 | threads1=None, |
@@ -158,35 +168,66 @@ def _main( |
158 | 168 |
|
159 | 169 | db1 = connect(database1, threads1 or threads) |
160 | 170 | db2 = connect(database2, threads2 or threads) |
| 171 | + dbs = db1, db2 |
161 | 172 |
|
162 | 173 | if interactive: |
163 | | - db1.enable_interactive() |
164 | | - db2.enable_interactive() |
| 174 | + for db in dbs: |
| 175 | + db.enable_interactive() |
165 | 176 |
|
166 | 177 | start = time.time() |
167 | 178 |
|
168 | 179 | try: |
169 | 180 | options = dict( |
170 | 181 | min_update=max_age and parse_time_before_now(max_age), |
171 | 182 | max_update=min_age and parse_time_before_now(min_age), |
172 | | - case_sensitive=keep_column_case, |
| 183 | + case_sensitive=case_sensitive, |
173 | 184 | where=where, |
174 | 185 | ) |
175 | 186 | except ParseError as e: |
176 | 187 | logging.error("Error while parsing age expression: %s" % e) |
177 | 188 | return |
178 | 189 |
|
179 | | - table1_seg = TableSegment(db1, db1.parse_table_name(table1), key_column, update_column, columns, **options) |
180 | | - table2_seg = TableSegment(db2, db2.parse_table_name(table2), key_column, update_column, columns, **options) |
181 | | - |
182 | 190 | differ = TableDiffer( |
183 | 191 | bisection_factor=bisection_factor, |
184 | 192 | bisection_threshold=bisection_threshold, |
185 | 193 | threaded=threaded, |
186 | 194 | max_threadpool_size=threads and threads * 2, |
187 | 195 | debug=debug, |
188 | 196 | ) |
189 | | - diff_iter = differ.diff_tables(table1_seg, table2_seg) |
| 197 | + |
| 198 | + table_names = table1, table2 |
| 199 | + table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)] |
| 200 | + |
| 201 | + schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths))) |
| 202 | + schema1, schema2 = schemas = [ |
| 203 | + create_schema(db, table_path, schema, case_sensitive) |
| 204 | + for db, table_path, schema in safezip(dbs, table_paths, schemas) |
| 205 | + ] |
| 206 | + |
| 207 | + mutual = schema1.keys() & schema2.keys() # Case-aware, according to case_sensitive |
| 208 | + logging.debug(f"Available mutual columns: {mutual}") |
| 209 | + |
| 210 | + expanded_columns = set() |
| 211 | + for c in columns: |
| 212 | + match = set(match_like(c, mutual)) |
| 213 | + if not match: |
| 214 | + m1 = None if any(match_like(c, schema1.keys())) else f"{db1}/{table1}" |
| 215 | + m2 = None if any(match_like(c, schema2.keys())) else f"{db2}/{table2}" |
| 216 | + not_matched = ", ".join(m for m in [m1, m2] if m) |
| 217 | + raise ValueError(f"Column {c} not found in: {not_matched}") |
| 218 | + |
| 219 | + expanded_columns |= match |
| 220 | + |
| 221 | + columns = tuple(expanded_columns - {key_column, update_column}) |
| 222 | + |
| 223 | + logging.info(f"Diffing columns: key={key_column} update={update_column} extra={columns}") |
| 224 | + |
| 225 | + segments = [ |
| 226 | + TableSegment(db, table_path, key_column, update_column, columns, **options)._with_raw_schema(raw_schema) |
| 227 | + for db, table_path, raw_schema in safezip(dbs, table_paths, schemas) |
| 228 | + ] |
| 229 | + |
| 230 | + diff_iter = differ.diff_tables(*segments) |
190 | 231 |
|
191 | 232 | if limit: |
192 | 233 | diff_iter = islice(diff_iter, int(limit)) |
|
0 commit comments