Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit f4af2cd

Browse files
committed
Merge branch 'master' into addl_db_support
2 parents: ab57394 + fbcb805 · commit f4af2cd

File tree

13 files changed

+1078
-342
lines changed

13 files changed

+1078
-342
lines changed

data_diff/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from typing import Sequence, Tuple, Iterator, Optional, Union
22

3-
from sqeleton.abcs import DbKey, DbTime, DbPath
3+
from sqeleton.abcs import DbTime, DbPath
44

55
from .tracking import disable_tracking
66
from .databases import connect
77
from .diff_tables import Algorithm
88
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
99
from .joindiff_tables import JoinDiffer, TABLE_WRITE_LIMIT
1010
from .table_segment import TableSegment
11-
from .utils import eval_name_template
11+
from .utils import eval_name_template, Vector
1212

1313

1414
def connect_to_table(
@@ -51,8 +51,8 @@ def diff_tables(
5151
# Extra columns to compare
5252
extra_columns: Tuple[str, ...] = None,
5353
# Start/end key_column values, used to restrict the segment
54-
min_key: DbKey = None,
55-
max_key: DbKey = None,
54+
min_key: Vector = None,
55+
max_key: Vector = None,
5656
# Start/end update_column values, used to restrict the segment
5757
min_update: DbTime = None,
5858
max_update: DbTime = None,
@@ -87,8 +87,8 @@ def diff_tables(
8787
update_column (str, optional): Name of updated column, which signals that rows changed.
8888
Usually updated_at or last_update. Used by `min_update` and `max_update`.
8989
extra_columns (Tuple[str, ...], optional): Extra columns to compare
90-
min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment
91-
max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment
90+
min_key (:data:`Vector`, optional): Lowest key value, used to restrict the segment
91+
max_key (:data:`Vector`, optional): Highest key value, used to restrict the segment
9292
min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
9393
max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment
9494
threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads.

data_diff/dbt.py

Lines changed: 128 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,34 @@
11
import json
22
import logging
33
import os
4+
import time
45
import rich
5-
import yaml
66
from dataclasses import dataclass
77
from packaging.version import parse as parse_version
8-
from typing import List, Optional, Dict
8+
from typing import List, Optional, Dict, Tuple
99

1010
import requests
11-
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
12-
from dbt.config.renderer import ProfileRenderer
1311

12+
13+
def import_dbt():
14+
try:
15+
from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
16+
from dbt.config.renderer import ProfileRenderer
17+
import yaml
18+
except ImportError:
19+
raise RuntimeError("Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'.")
20+
21+
return parse_run_results, parse_manifest, ProfileRenderer, yaml
22+
23+
24+
from .tracking import (
25+
set_entrypoint_name,
26+
create_end_event_json,
27+
create_start_event_json,
28+
send_event_json,
29+
is_tracking_enabled,
30+
)
31+
from .utils import get_from_dict_with_raise, run_as_daemon, truncate_error
1432
from . import connect_to_table, diff_tables, Algorithm
1533

1634
RUN_RESULTS_PATH = "/target/run_results.json"
@@ -33,6 +51,7 @@ class DiffVars:
3351
def dbt_diff(
3452
profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False
3553
) -> None:
54+
set_entrypoint_name("CLI-dbt")
3655
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, is_cloud)
3756
models = dbt_parser.get_models()
3857
dbt_parser.set_project_dict()
@@ -44,8 +63,10 @@ def dbt_diff(
4463
if not is_cloud:
4564
dbt_parser.set_connection()
4665

47-
if config_prod_database is None or config_prod_schema is None:
48-
raise ValueError("Expected a value for prod_database: or prod_schema: under \nvars:\n data_diff: ")
66+
if config_prod_database is None:
67+
raise ValueError(
68+
"Expected a value for prod_database: OR prod_database: AND prod_schema: under \nvars:\n data_diff: "
69+
)
4970

5071
for model in models:
5172
diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, datasource_id)
@@ -55,9 +76,9 @@ def dbt_diff(
5576
elif is_cloud:
5677
rich.print(
5778
"[red]"
58-
+ ".".join(diff_vars.dev_path)
59-
+ " <> "
6079
+ ".".join(diff_vars.prod_path)
80+
+ " <> "
81+
+ ".".join(diff_vars.dev_path)
6182
+ "[/] \n"
6283
+ "Skipped due to missing primary-key tag\n"
6384
)
@@ -67,14 +88,14 @@ def dbt_diff(
6788
elif not is_cloud:
6889
rich.print(
6990
"[red]"
70-
+ ".".join(diff_vars.dev_path)
71-
+ " <> "
7291
+ ".".join(diff_vars.prod_path)
92+
+ " <> "
93+
+ ".".join(diff_vars.dev_path)
7394
+ "[/] \n"
7495
+ "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs)\n"
7596
)
7697

77-
rich.print("Diffs Complete!")
98+
rich.print("Diffs Complete!")
7899

79100

80101
def _get_diff_vars(
@@ -92,12 +113,12 @@ def _get_diff_vars(
92113
prod_schema = config_prod_schema if config_prod_schema else dev_schema
93114

94115
if dbt_parser.requires_upper:
95-
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.name]]
96-
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.name]]
116+
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias]]
117+
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias]]
97118
primary_keys = [x.upper() for x in primary_keys]
98119
else:
99-
dev_qualified_list = [dev_database, dev_schema, model.name]
100-
prod_qualified_list = [prod_database, prod_schema, model.name]
120+
dev_qualified_list = [dev_database, dev_schema, model.alias]
121+
prod_qualified_list = [prod_database, prod_schema, model.alias]
101122

102123
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection)
103124

@@ -119,9 +140,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
119140
logging.info(ex)
120141
rich.print(
121142
"[red]"
122-
+ dev_qualified_string
123-
+ " <> "
124143
+ prod_qualified_string
144+
+ " <> "
145+
+ dev_qualified_string
125146
+ "[/] \n"
126147
+ column_diffs_str
127148
+ "[green]New model or no access to prod table.[/] \n"
@@ -133,10 +154,10 @@ def _local_diff(diff_vars: DiffVars) -> None:
133154
table2_set_diff = list(set(table2_columns) - set(table1_columns))
134155

135156
if table1_set_diff:
136-
column_diffs_str += "Columns exclusive to table A: " + str(table1_set_diff) + "\n"
157+
column_diffs_str += "Column(s) added: " + str(table1_set_diff) + "\n"
137158

138159
if table2_set_diff:
139-
column_diffs_str += "Columns exclusive to table B: " + str(table2_set_diff) + "\n"
160+
column_diffs_str += "Column(s) removed: " + str(table2_set_diff) + "\n"
140161

141162
mutual_set.discard(primary_key)
142163
extra_columns = tuple(mutual_set)
@@ -146,20 +167,20 @@ def _local_diff(diff_vars: DiffVars) -> None:
146167
if list(diff):
147168
rich.print(
148169
"[red]"
149-
+ dev_qualified_string
150-
+ " <> "
151170
+ prod_qualified_string
171+
+ " <> "
172+
+ dev_qualified_string
152173
+ "[/] \n"
153174
+ column_diffs_str
154-
+ diff.get_stats_string()
175+
+ diff.get_stats_string(is_dbt=True)
155176
+ "\n"
156177
)
157178
else:
158179
rich.print(
159180
"[red]"
160-
+ dev_qualified_string
161-
+ " <> "
162181
+ prod_qualified_string
182+
+ " <> "
183+
+ dev_qualified_string
163184
+ "[/] \n"
164185
+ column_diffs_str
165186
+ "[green]No row differences[/] \n"
@@ -181,31 +202,62 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
181202
payload = {
182203
"data_source1_id": diff_vars.datasource_id,
183204
"data_source2_id": diff_vars.datasource_id,
184-
"table1": diff_vars.dev_path,
185-
"table2": diff_vars.prod_path,
205+
"table1": diff_vars.prod_path,
206+
"table2": diff_vars.dev_path,
186207
"pk_columns": diff_vars.primary_keys,
187208
}
188209

189210
headers = {
190211
"Authorization": f"Key {api_key}",
191212
"Content-Type": "application/json",
192213
}
214+
if is_tracking_enabled():
215+
event_json = create_start_event_json({"is_cloud": True, "datasource_id": diff_vars.datasource_id})
216+
run_as_daemon(send_event_json, event_json)
217+
218+
start = time.monotonic()
219+
error = None
220+
diff_id = None
221+
try:
222+
response = requests.request("POST", url, headers=headers, json=payload, timeout=30)
223+
response.raise_for_status()
224+
data = response.json()
225+
diff_id = data["id"]
226+
# TODO in future we should support self hosted datafold
227+
diff_url = f"https://app.datafold.com/datadiffs/{diff_id}/overview"
228+
rich.print(
229+
"[red]"
230+
+ ".".join(diff_vars.prod_path)
231+
+ " <> "
232+
+ ".".join(diff_vars.dev_path)
233+
+ "[/] \n Diff in progress: \n "
234+
+ diff_url
235+
+ "\n"
236+
)
237+
except BaseException as ex: # Catch KeyboardInterrupt too
238+
error = ex
239+
finally:
240+
# we don't currently have much of this information
241+
# but I imagine a future iteration of this _cloud method
242+
# will poll for results
243+
if is_tracking_enabled():
244+
err_message = truncate_error(repr(error))
245+
event_json = create_end_event_json(
246+
is_success=error is None,
247+
runtime_seconds=time.monotonic() - start,
248+
data_source_1_type="",
249+
data_source_2_type="",
250+
table1_count=0,
251+
table2_count=0,
252+
diff_count=0,
253+
error=err_message,
254+
diff_id=diff_id,
255+
is_cloud=True,
256+
)
257+
send_event_json(event_json)
193258

194-
response = requests.request("POST", url, headers=headers, json=payload, timeout=30)
195-
response.raise_for_status()
196-
data = response.json()
197-
diff_id = data["id"]
198-
# TODO in future we should support self hosted datafold
199-
diff_url = f"https://app.datafold.com/datadiffs/{diff_id}/overview"
200-
rich.print(
201-
"[red]"
202-
+ ".".join(diff_vars.dev_path)
203-
+ " <> "
204-
+ ".".join(diff_vars.prod_path)
205-
+ "[/] \n Diff in progress: \n "
206-
+ diff_url
207-
+ "\n"
208-
)
259+
if error:
260+
raise error
209261

210262

211263
class DbtParser:
@@ -220,25 +272,26 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
220272
self.project_dict = None
221273
self.requires_upper = False
222274

275+
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
276+
223277
def get_datadiff_variables(self) -> dict:
224278
return self.project_dict.get("vars").get("data_diff")
225279

226280
def get_models(self):
227281
with open(self.project_dir + RUN_RESULTS_PATH) as run_results:
228282
run_results_dict = json.load(run_results)
229-
run_results_obj = parse_run_results(run_results=run_results_dict)
283+
run_results_obj = self.parse_run_results(run_results=run_results_dict)
230284

231285
dbt_version = parse_version(run_results_obj.metadata.dbt_version)
232286

233-
# TODO 1.4 support
234287
if dbt_version < parse_version(LOWER_DBT_V) or dbt_version >= parse_version(UPPER_DBT_V):
235288
raise Exception(
236289
f"Found dbt: v{dbt_version} Expected the dbt project's version to be >= {LOWER_DBT_V} and < {UPPER_DBT_V}"
237290
)
238291

239292
with open(self.project_dir + MANIFEST_PATH) as manifest:
240293
manifest_dict = json.load(manifest)
241-
manifest_obj = parse_manifest(manifest=manifest_dict)
294+
manifest_obj = self.parse_manifest(manifest=manifest_dict)
242295

243296
success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
244297
models = [manifest_obj.nodes.get(x) for x in success_models]
@@ -253,20 +306,42 @@ def get_primary_keys(self, model):
253306

254307
def set_project_dict(self):
255308
with open(self.project_dir + PROJECT_FILE) as project:
256-
self.project_dict = yaml.safe_load(project)
309+
self.project_dict = self.yaml.safe_load(project)
257310

258-
def set_connection(self):
259-
with open(self.profiles_dir + PROFILES_FILE) as profiles:
260-
profiles = yaml.safe_load(profiles)
311+
def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
312+
profiles_path = self.profiles_dir + PROFILES_FILE
313+
with open(profiles_path) as profiles:
314+
profiles = self.yaml.safe_load(profiles)
261315

262316
dbt_profile = self.project_dict.get("profile")
263-
profile_outputs = profiles.get(dbt_profile)
264-
profile_target = profile_outputs.get("target")
265-
credentials = profile_outputs.get("outputs").get(profile_target)
266-
conn_type = credentials.get("type").lower()
317+
318+
profile_outputs = get_from_dict_with_raise(
319+
profiles, dbt_profile, f"No profile '{dbt_profile}' found in '{profiles_path}'."
320+
)
321+
profile_target = get_from_dict_with_raise(
322+
profile_outputs, "target", f"No target found in profile '{dbt_profile}' in '{profiles_path}'."
323+
)
324+
outputs = get_from_dict_with_raise(
325+
profile_outputs, "outputs", f"No outputs found in profile '{dbt_profile}' in '{profiles_path}'."
326+
)
327+
credentials = get_from_dict_with_raise(
328+
outputs,
329+
profile_target,
330+
f"No credentials found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
331+
)
332+
conn_type = get_from_dict_with_raise(
333+
credentials,
334+
"type",
335+
f"No type found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
336+
)
337+
conn_type = conn_type.lower()
267338

268339
# values can contain env_vars
269-
rendered_credentials = ProfileRenderer().render_data(credentials)
340+
rendered_credentials = self.ProfileRenderer().render_data(credentials)
341+
return rendered_credentials, conn_type
342+
343+
def set_connection(self):
344+
rendered_credentials, conn_type = self._get_connection_creds()
270345

271346
# this whole block should be refactored/extracted to method(s)
272347
if conn_type == "snowflake":

0 commit comments

Comments (0)