Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 78fc4ab

Browse files
authored
Merge branch 'master' into dbeatty10/dbt-profiles-dbt-project
2 parents 490886c + d2d7849 commit 78fc4ab

File tree

10 files changed

+723
-153
lines changed

10 files changed

+723
-153
lines changed

data_diff/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from typing import Sequence, Tuple, Iterator, Optional, Union
22

3-
from sqeleton.abcs import DbKey, DbTime, DbPath
3+
from sqeleton.abcs import DbTime, DbPath
44

55
from .tracking import disable_tracking
66
from .databases import connect
77
from .diff_tables import Algorithm
88
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
99
from .joindiff_tables import JoinDiffer, TABLE_WRITE_LIMIT
1010
from .table_segment import TableSegment
11-
from .utils import eval_name_template
11+
from .utils import eval_name_template, Vector
1212

1313

1414
def connect_to_table(
@@ -51,8 +51,8 @@ def diff_tables(
5151
# Extra columns to compare
5252
extra_columns: Tuple[str, ...] = None,
5353
# Start/end key_column values, used to restrict the segment
54-
min_key: DbKey = None,
55-
max_key: DbKey = None,
54+
min_key: Vector = None,
55+
max_key: Vector = None,
5656
# Start/end update_column values, used to restrict the segment
5757
min_update: DbTime = None,
5858
max_update: DbTime = None,
@@ -87,8 +87,8 @@ def diff_tables(
8787
update_column (str, optional): Name of updated column, which signals that rows changed.
8888
Usually updated_at or last_update. Used by `min_update` and `max_update`.
8989
extra_columns (Tuple[str, ...], optional): Extra columns to compare
90-
min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment
91-
max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment
90+
min_key (:data:`Vector`, optional): Lowest key value, used to restrict the segment
91+
max_key (:data:`Vector`, optional): Highest key value, used to restrict the segment
9292
min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
9393
max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment
9494
threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads.

data_diff/dbt.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,24 @@
33
import os
44
import time
55
import rich
6-
import yaml
76
from dataclasses import dataclass
87
from packaging.version import parse as parse_version
98
from typing import List, Optional, Dict
109
from pathlib import Path
1110

1211
import requests
13-
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
14-
from dbt.config.renderer import ProfileRenderer
12+
13+
14+
def import_dbt():
15+
try:
16+
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
17+
from dbt.config.renderer import ProfileRenderer
18+
import yaml
19+
except ImportError:
20+
raise RuntimeError("Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'.")
21+
22+
return parse_run_results, parse_manifest, ProfileRenderer, yaml
23+
1524

1625
from .tracking import (
1726
set_entrypoint_name,
@@ -158,10 +167,10 @@ def _local_diff(diff_vars: DiffVars) -> None:
158167
table2_set_diff = list(set(table2_columns) - set(table1_columns))
159168

160169
if table1_set_diff:
161-
column_diffs_str += "Columns exclusive to table A: " + str(table1_set_diff) + "\n"
170+
column_diffs_str += "Column(s) added: " + str(table1_set_diff) + "\n"
162171

163172
if table2_set_diff:
164-
column_diffs_str += "Columns exclusive to table B: " + str(table2_set_diff) + "\n"
173+
column_diffs_str += "Column(s) removed: " + str(table2_set_diff) + "\n"
165174

166175
mutual_set.discard(primary_key)
167176
extra_columns = tuple(mutual_set)
@@ -274,13 +283,15 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
274283
self.project_dict = None
275284
self.requires_upper = False
276285

286+
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
287+
277288
def get_datadiff_variables(self) -> dict:
278289
return self.project_dict.get("vars").get("data_diff")
279290

280291
def get_models(self):
281292
with open(self.project_dir / RUN_RESULTS_PATH) as run_results:
282293
run_results_dict = json.load(run_results)
283-
run_results_obj = parse_run_results(run_results=run_results_dict)
294+
run_results_obj = self.parse_run_results(run_results=run_results_dict)
284295

285296
dbt_version = parse_version(run_results_obj.metadata.dbt_version)
286297

@@ -291,7 +302,7 @@ def get_models(self):
291302

292303
with open(self.project_dir / MANIFEST_PATH) as manifest:
293304
manifest_dict = json.load(manifest)
294-
manifest_obj = parse_manifest(manifest=manifest_dict)
305+
manifest_obj = self.parse_manifest(manifest=manifest_dict)
295306

296307
success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
297308
models = [manifest_obj.nodes.get(x) for x in success_models]
@@ -306,11 +317,11 @@ def get_primary_keys(self, model):
306317

307318
def set_project_dict(self):
308319
with open(self.project_dir / PROJECT_FILE) as project:
309-
self.project_dict = yaml.safe_load(project)
320+
self.project_dict = self.yaml.safe_load(project)
310321

311322
def set_connection(self):
312323
with open(self.profiles_dir / PROFILES_FILE) as profiles:
313-
profiles = yaml.safe_load(profiles)
324+
profiles = self.yaml.safe_load(profiles)
314325

315326
dbt_profile = self.project_dict.get("profile")
316327
profile_outputs = profiles.get(dbt_profile)
@@ -319,7 +330,7 @@ def set_connection(self):
319330
conn_type = credentials.get("type").lower()
320331

321332
# values can contain env_vars
322-
rendered_credentials = ProfileRenderer().render_data(credentials)
333+
rendered_credentials = self.ProfileRenderer().render_data(credentials)
323334

324335
if conn_type == "snowflake":
325336
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:

data_diff/diff_tables.py

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
from data_diff.info_tree import InfoTree, SegmentInfo
1616

17-
from .utils import run_as_daemon, safezip, getLogger, truncate_error
17+
from .utils import run_as_daemon, safezip, getLogger, truncate_error, Vector
1818
from .thread_utils import ThreadedYielder
19-
from .table_segment import TableSegment
19+
from .table_segment import TableSegment, create_mesh_from_points
2020
from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
2121
from sqeleton.abcs import IKey
2222

@@ -135,7 +135,6 @@ def _get_stats(self, is_dbt: bool = False) -> DiffStats:
135135

136136
return DiffStats(diff_by_sign, table1_count, table2_count, unchanged, diff_percent, extra_column_diffs)
137137

138-
139138
def get_stats_string(self, is_dbt: bool = False):
140139
diff_stats = self._get_stats(is_dbt)
141140

@@ -271,63 +270,75 @@ def _diff_segments(
271270
):
272271
...
273272

274-
def _bisect_and_diff_tables(self, table1, table2, info_tree):
275-
if len(table1.key_columns) > 1:
276-
raise NotImplementedError("Composite key not supported yet!")
277-
if len(table2.key_columns) > 1:
278-
raise NotImplementedError("Composite key not supported yet!")
273+
def _bisect_and_diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree):
279274
if len(table1.key_columns) != len(table2.key_columns):
280275
raise ValueError("Tables should have an equivalent number of key columns!")
281-
(key1,) = table1.key_columns
282-
(key2,) = table2.key_columns
283-
284-
key_type = table1._schema[key1]
285-
key_type2 = table2._schema[key2]
286-
if not isinstance(key_type, IKey):
287-
raise NotImplementedError(f"Cannot use column of type {key_type} as a key")
288-
if not isinstance(key_type2, IKey):
289-
raise NotImplementedError(f"Cannot use column of type {key_type2} as a key")
290-
if key_type.python_type is not key_type2.python_type:
291-
raise TypeError(f"Incompatible key types: {key_type} and {key_type2}")
276+
277+
key_types1 = [table1._schema[i] for i in table1.key_columns]
278+
key_types2 = [table2._schema[i] for i in table2.key_columns]
279+
280+
for kt in key_types1 + key_types2:
281+
if not isinstance(kt, IKey):
282+
raise NotImplementedError(f"Cannot use a column of type {kt} as a key")
283+
284+
for kt1, kt2 in safezip(key_types1, key_types2):
285+
if kt1.python_type is not kt2.python_type:
286+
raise TypeError(f"Incompatible key types: {kt1} and {kt2}")
292287

293288
# Query min/max values
294289
key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2])
295290

296291
# Start with the first completed value, so we don't waste time waiting
297-
min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges))
292+
min_key1, max_key1 = self._parse_key_range_result(key_types1, next(key_ranges))
298293

299-
table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)]
294+
btable1, btable2 = [t.new_key_bounds(min_key=min_key1, max_key=max_key1) for t in (table1, table2)]
300295

301296
logger.info(
302-
f"Diffing segments at key-range: {table1.min_key}..{table2.max_key}. "
303-
f"size: table1 <= {table1.approximate_size()}, table2 <= {table2.approximate_size()}"
297+
f"Diffing segments at key-range: {btable1.min_key}..{btable2.max_key}. "
298+
f"size: table1 <= {btable1.approximate_size()}, table2 <= {btable2.approximate_size()}"
304299
)
305300

306301
ti = ThreadedYielder(self.max_threadpool_size)
307302
# Bisect (split) the table into segments, and diff them recursively.
308-
ti.submit(self._bisect_and_diff_segments, ti, table1, table2, info_tree)
303+
ti.submit(self._bisect_and_diff_segments, ti, btable1, btable2, info_tree)
309304

310305
# Now we check for the second min-max, to diff the portions we "missed".
311-
min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges))
306+
# This is achieved by subtracting the table ranges, and dividing the resulting space into aligned boxes.
307+
# For example, given tables A & B, and a 2D compound key, where A was queried first for key-range,
308+
# the regions of B we need to diff in this second pass are marked by B1..8:
309+
# ┌──┬──────┬──┐
310+
# │B1│ B2 │B3│
311+
# ├──┼──────┼──┤
312+
# │B4│ A │B5│
313+
# ├──┼──────┼──┤
314+
# │B6│ B7 │B8│
315+
# └──┴──────┴──┘
316+
# Overall, the max number of new regions in this 2nd pass is 3^|k| - 1
312317

313-
if min_key2 < min_key1:
314-
pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)]
315-
ti.submit(self._bisect_and_diff_segments, ti, *pre_tables, info_tree)
318+
min_key2, max_key2 = self._parse_key_range_result(key_types1, next(key_ranges))
316319

317-
if max_key2 > max_key1:
318-
post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)]
319-
ti.submit(self._bisect_and_diff_segments, ti, *post_tables, info_tree)
320+
points = [list(sorted(p)) for p in safezip(min_key1, min_key2, max_key1, max_key2)]
321+
box_mesh = create_mesh_from_points(*points)
322+
323+
new_regions = [(p1, p2) for p1, p2 in box_mesh if p1 < p2 and not (p1 >= min_key1 and p2 <= max_key1)]
324+
325+
for p1, p2 in new_regions:
326+
extra_tables = [t.new_key_bounds(min_key=p1, max_key=p2) for t in (table1, table2)]
327+
ti.submit(self._bisect_and_diff_segments, ti, *extra_tables, info_tree)
320328

321329
return ti
322330

323-
def _parse_key_range_result(self, key_type, key_range):
324-
mn, mx = key_range
325-
cls = key_type.make_value
331+
def _parse_key_range_result(self, key_types, key_range) -> Tuple[Vector, Vector]:
332+
min_key_values, max_key_values = key_range
333+
326334
# We add 1 because our ranges are exclusive of the end (like in Python)
327335
try:
328-
return cls(mn), cls(mx) + 1
336+
min_key = Vector(key_type.make_value(mn) for key_type, mn in safezip(key_types, min_key_values))
337+
max_key = Vector(key_type.make_value(mx) + 1 for key_type, mx in safezip(key_types, max_key_values))
329338
except (TypeError, ValueError) as e:
330-
raise type(e)(f"Cannot apply {key_type} to '{mn}', '{mx}'.") from e
339+
raise type(e)(f"Cannot apply {key_types} to '{min_key_values}', '{max_key_values}'.") from e
340+
341+
return min_key, max_key
331342

332343
def _bisect_and_diff_segments(
333344
self,

0 commit comments

Comments
 (0)