
Commit c836415

Merge branch 'master' into dbeatty10/support-duckdb

2 parents: d421923 + 6c30cdf

12 files changed: +1000 -323 lines
data_diff/__init__.py (6 additions & 6 deletions)
```diff
@@ -1,14 +1,14 @@
 from typing import Sequence, Tuple, Iterator, Optional, Union
 
-from sqeleton.abcs import DbKey, DbTime, DbPath
+from sqeleton.abcs import DbTime, DbPath
 
 from .tracking import disable_tracking
 from .databases import connect
 from .diff_tables import Algorithm
 from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
 from .joindiff_tables import JoinDiffer, TABLE_WRITE_LIMIT
 from .table_segment import TableSegment
-from .utils import eval_name_template
+from .utils import eval_name_template, Vector
 
 
 def connect_to_table(
@@ -51,8 +51,8 @@ def diff_tables(
     # Extra columns to compare
     extra_columns: Tuple[str, ...] = None,
     # Start/end key_column values, used to restrict the segment
-    min_key: DbKey = None,
-    max_key: DbKey = None,
+    min_key: Vector = None,
+    max_key: Vector = None,
     # Start/end update_column values, used to restrict the segment
     min_update: DbTime = None,
     max_update: DbTime = None,
@@ -87,8 +87,8 @@ def diff_tables(
         update_column (str, optional): Name of updated column, which signals that rows changed.
             Usually updated_at or last_update. Used by `min_update` and `max_update`.
         extra_columns (Tuple[str, ...], optional): Extra columns to compare
-        min_key (:data:`DbKey`, optional): Lowest key value, used to restrict the segment
-        max_key (:data:`DbKey`, optional): Highest key value, used to restrict the segment
+        min_key (:data:`Vector`, optional): Lowest key value, used to restrict the segment
+        max_key (:data:`Vector`, optional): Highest key value, used to restrict the segment
         min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
         max_update (:data:`DbTime`, optional): Highest update_column value, used to restrict the segment
         threaded (bool): Enable/disable threaded diffing. Needed to take advantage of database threads.
```

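Note on the change above: replacing DbKey with Vector lets min_key/max_key carry one value per key column, so a segment can be bounded over a compound primary key. A minimal usage sketch follows; the connection string, table names, and key values are hypothetical:

```python
# Minimal sketch, assuming a reachable Postgres database and an integer `id` key.
from data_diff import connect_to_table, diff_tables

conn = "postgresql://user:password@localhost:5432/analytics"  # hypothetical DSN
table1 = connect_to_table(conn, "ratings", ("id",))
table2 = connect_to_table(conn, "ratings_backup", ("id",))

# Bound the diffed segment by key. With a compound key such as ("org_id", "id"),
# the bounds would be vectors like (42, 1000) and (42, 2000).
for sign, row in diff_tables(table1, table2, min_key=(1000,), max_key=(2000,)):
    print(sign, row)  # sign is "+" or "-"
```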
data_diff/dbt.py (107 additions & 55 deletions)
```diff
@@ -3,14 +3,23 @@
 import os
 import time
 import rich
-import yaml
 from dataclasses import dataclass
 from packaging.version import parse as parse_version
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, Tuple
 
 import requests
-from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
-from dbt.config.renderer import ProfileRenderer
+
+
+def import_dbt():
+    try:
+        from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
+        from dbt.config.renderer import ProfileRenderer
+        import yaml
+    except ImportError:
+        raise RuntimeError("Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'.")
+
+    return parse_run_results, parse_manifest, ProfileRenderer, yaml
+
 
 from .tracking import (
     set_entrypoint_name,
```
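The new import_dbt() defers the dbt-related imports until they are actually needed, so the base package can be installed and imported without the dbt extra. A quick sketch of the resulting behavior:

```python
# Sketch: when the dbt extra is missing, the deferred import surfaces as an
# actionable RuntimeError rather than an ImportError at module load time.
from data_diff.dbt import import_dbt

try:
    parse_run_results, parse_manifest, ProfileRenderer, yaml = import_dbt()
except RuntimeError as e:
    print(e)  # suggests: pip install 'data-diff[dbt]'
```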
```diff
@@ -19,15 +28,15 @@
     send_event_json,
     is_tracking_enabled,
 )
-from .utils import run_as_daemon, truncate_error
+from .utils import get_from_dict_with_raise, run_as_daemon, truncate_error
 from . import connect_to_table, diff_tables, Algorithm
 
 RUN_RESULTS_PATH = "/target/run_results.json"
 MANIFEST_PATH = "/target/manifest.json"
 PROJECT_FILE = "/dbt_project.yml"
 PROFILES_FILE = "/profiles.yml"
 LOWER_DBT_V = "1.0.0"
-UPPER_DBT_V = "1.5.0"
+UPPER_DBT_V = "1.4.2"
 
 
 @dataclass
```
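UPPER_DBT_V drops from 1.5.0 to 1.4.2. The version gate itself sits outside these hunks; below is a hedged sketch of how bounds like these are typically enforced with packaging, consistent with the parse_version usage in get_models() further down. The function name and message are assumptions, not the verbatim check:

```python
from packaging.version import parse as parse_version

LOWER_DBT_V = "1.0.0"
UPPER_DBT_V = "1.4.2"

def check_dbt_version(dbt_version: str) -> None:
    # Accept versions in the half-open range [LOWER_DBT_V, UPPER_DBT_V).
    version = parse_version(dbt_version)
    if version < parse_version(LOWER_DBT_V) or version >= parse_version(UPPER_DBT_V):
        raise Exception(f"dbt version {dbt_version} is outside the supported range [{LOWER_DBT_V}, {UPPER_DBT_V})")
```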
```diff
@@ -54,37 +63,29 @@ def dbt_diff(
     if not is_cloud:
         dbt_parser.set_connection()
 
-    if config_prod_database is None or config_prod_schema is None:
-        raise ValueError("Expected a value for prod_database: or prod_schema: under \nvars:\n data_diff: ")
+    if config_prod_database is None:
+        raise ValueError(
+            "Expected a value for prod_database: OR prod_database: AND prod_schema: under \nvars:\n data_diff: "
+        )
 
     for model in models:
         diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, datasource_id)
 
         if is_cloud and len(diff_vars.primary_keys) > 0:
             _cloud_diff(diff_vars)
-        elif is_cloud:
-            rich.print(
-                "[red]"
-                + ".".join(diff_vars.dev_path)
-                + " <> "
-                + ".".join(diff_vars.prod_path)
-                + "[/] \n"
-                + "Skipped due to missing primary-key tag\n"
-            )
-
-        if not is_cloud and len(diff_vars.primary_keys) == 1:
+        elif not is_cloud and len(diff_vars.primary_keys) > 0:
             _local_diff(diff_vars)
-        elif not is_cloud:
+        else:
             rich.print(
                 "[red]"
-                + ".".join(diff_vars.dev_path)
-                + " <> "
                 + ".".join(diff_vars.prod_path)
+                + " <> "
+                + ".".join(diff_vars.dev_path)
                 + "[/] \n"
-                + "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs)\n"
+                + "Skipped due to missing primary-key tag(s).\n"
             )
 
-        rich.print("Diffs Complete!")
+    rich.print("Diffs Complete!")
 
 
 def _get_diff_vars(
```
```diff
@@ -102,12 +103,12 @@ def _get_diff_vars(
     prod_schema = config_prod_schema if config_prod_schema else dev_schema
 
     if dbt_parser.requires_upper:
-        dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.name]]
-        prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.name]]
+        dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias]]
+        prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias]]
         primary_keys = [x.upper() for x in primary_keys]
     else:
-        dev_qualified_list = [dev_database, dev_schema, model.name]
-        prod_qualified_list = [prod_database, prod_schema, model.name]
+        dev_qualified_list = [dev_database, dev_schema, model.alias]
+        prod_qualified_list = [prod_database, prod_schema, model.alias]
 
     return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection)
```
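Using model.alias instead of model.name matches the relation dbt actually materializes when an alias is configured. For reference, a sketch of the DiffVars dataclass consistent with how it is constructed and consumed in this file; the field types are inferred from usage, and the real definition lives elsewhere in dbt.py:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class DiffVars:
    dev_path: List[str]           # [database, schema, alias] of the dev relation
    prod_path: List[str]          # [database, schema, alias] of the prod relation
    primary_keys: List[str]       # from the model's primary-key tags
    datasource_id: Optional[int]  # only meaningful for --cloud diffs
    connection: Dict[str, str]    # connection info derived from profiles.yml
```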
```diff
@@ -116,10 +117,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
     column_diffs_str = ""
     dev_qualified_string = ".".join(diff_vars.dev_path)
     prod_qualified_string = ".".join(diff_vars.prod_path)
-    primary_key = diff_vars.primary_keys[0]
 
-    table1 = connect_to_table(diff_vars.connection, dev_qualified_string, primary_key)
-    table2 = connect_to_table(diff_vars.connection, prod_qualified_string, primary_key)
+    table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys))
+    table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys))
 
     table1_columns = list(table1.get_schema())
     try:
```
```diff
@@ -129,9 +129,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
         logging.info(ex)
         rich.print(
             "[red]"
-            + dev_qualified_string
-            + " <> "
             + prod_qualified_string
+            + " <> "
+            + dev_qualified_string
             + "[/] \n"
             + column_diffs_str
             + "[green]New model or no access to prod table.[/] \n"
```
```diff
@@ -143,22 +143,22 @@ def _local_diff(diff_vars: DiffVars) -> None:
     table2_set_diff = list(set(table2_columns) - set(table1_columns))
 
     if table1_set_diff:
-        column_diffs_str += "Columns exclusive to table A: " + str(table1_set_diff) + "\n"
+        column_diffs_str += "Column(s) added: " + str(table1_set_diff) + "\n"
 
     if table2_set_diff:
-        column_diffs_str += "Columns exclusive to table B: " + str(table2_set_diff) + "\n"
+        column_diffs_str += "Column(s) removed: " + str(table2_set_diff) + "\n"
 
-    mutual_set.discard(primary_key)
+    mutual_set = mutual_set - set(diff_vars.primary_keys)
     extra_columns = tuple(mutual_set)
 
     diff = diff_tables(table1, table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=extra_columns)
 
     if list(diff):
         rich.print(
             "[red]"
-            + dev_qualified_string
-            + " <> "
             + prod_qualified_string
+            + " <> "
+            + dev_qualified_string
             + "[/] \n"
             + column_diffs_str
             + diff.get_stats_string(is_dbt=True)
```
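A worked example of the column bookkeeping above, with hypothetical schemas; primary keys are diffed implicitly, so they are subtracted before building extra_columns:

```python
# Hypothetical dev/prod column lists to illustrate the set arithmetic above.
table1_columns = ["id", "amount", "discount"]    # dev
table2_columns = ["id", "amount", "created_at"]  # prod
mutual_set = set(table1_columns) & set(table2_columns)

print(set(table1_columns) - set(table2_columns))  # {'discount'}   -> "Column(s) added"
print(set(table2_columns) - set(table1_columns))  # {'created_at'} -> "Column(s) removed"

extra_columns = tuple(mutual_set - {"id"})        # ('amount',) is compared value-by-value
```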
```diff
@@ -167,9 +167,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
     else:
         rich.print(
             "[red]"
-            + dev_qualified_string
-            + " <> "
             + prod_qualified_string
+            + " <> "
+            + dev_qualified_string
             + "[/] \n"
             + column_diffs_str
             + "[green]No row differences[/] \n"
```
```diff
@@ -191,8 +191,8 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
     payload = {
         "data_source1_id": diff_vars.datasource_id,
         "data_source2_id": diff_vars.datasource_id,
-        "table1": diff_vars.dev_path,
-        "table2": diff_vars.prod_path,
+        "table1": diff_vars.prod_path,
+        "table2": diff_vars.dev_path,
         "pk_columns": diff_vars.primary_keys,
     }
 
```
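Swapping table1 and table2 makes the cloud diff treat prod as the baseline and dev as the comparison, matching the prod <> dev ordering now used in the console output. For orientation only, a hedged sketch of submitting a payload of this shape; the endpoint path and auth header are assumptions, not taken from this diff:

```python
import requests

def create_cloud_diff(api_key: str, payload: dict) -> int:
    # Assumed endpoint and auth scheme; the real values live elsewhere in dbt.py.
    resp = requests.post(
        "https://app.datafold.com/api/v1/datadiffs",
        json=payload,
        headers={"Authorization": f"Key {api_key}"},
    )
    resp.raise_for_status()
    return resp.json()["id"]  # presumably the diff_id used for the overview URL below
```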
```diff
@@ -216,9 +216,9 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
     diff_url = f"https://app.datafold.com/datadiffs/{diff_id}/overview"
     rich.print(
         "[red]"
-        + ".".join(diff_vars.dev_path)
-        + " <> "
         + ".".join(diff_vars.prod_path)
+        + " <> "
+        + ".".join(diff_vars.dev_path)
         + "[/] \n Diff in progress: \n "
         + diff_url
         + "\n"
```
```diff
@@ -261,13 +261,15 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
         self.project_dict = None
         self.requires_upper = False
 
+        self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
+
     def get_datadiff_variables(self) -> dict:
         return self.project_dict.get("vars").get("data_diff")
 
     def get_models(self):
         with open(self.project_dir + RUN_RESULTS_PATH) as run_results:
             run_results_dict = json.load(run_results)
-            run_results_obj = parse_run_results(run_results=run_results_dict)
+            run_results_obj = self.parse_run_results(run_results=run_results_dict)
 
         dbt_version = parse_version(run_results_obj.metadata.dbt_version)
 
```
```diff
@@ -278,7 +280,7 @@ def get_models(self):
 
         with open(self.project_dir + MANIFEST_PATH) as manifest:
             manifest_dict = json.load(manifest)
-            manifest_obj = parse_manifest(manifest=manifest_dict)
+            manifest_obj = self.parse_manifest(manifest=manifest_dict)
 
         success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
         models = [manifest_obj.nodes.get(x) for x in success_models]
```
```diff
@@ -293,20 +295,42 @@ def get_primary_keys(self, model):
 
     def set_project_dict(self):
         with open(self.project_dir + PROJECT_FILE) as project:
-            self.project_dict = yaml.safe_load(project)
+            self.project_dict = self.yaml.safe_load(project)
 
-    def set_connection(self):
-        with open(self.profiles_dir + PROFILES_FILE) as profiles:
-            profiles = yaml.safe_load(profiles)
+    def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
+        profiles_path = self.profiles_dir + PROFILES_FILE
+        with open(profiles_path) as profiles:
+            profiles = self.yaml.safe_load(profiles)
 
         dbt_profile = self.project_dict.get("profile")
-        profile_outputs = profiles.get(dbt_profile)
-        profile_target = profile_outputs.get("target")
-        credentials = profile_outputs.get("outputs").get(profile_target)
-        conn_type = credentials.get("type").lower()
+
+        profile_outputs = get_from_dict_with_raise(
+            profiles, dbt_profile, f"No profile '{dbt_profile}' found in '{profiles_path}'."
+        )
+        profile_target = get_from_dict_with_raise(
+            profile_outputs, "target", f"No target found in profile '{dbt_profile}' in '{profiles_path}'."
+        )
+        outputs = get_from_dict_with_raise(
+            profile_outputs, "outputs", f"No outputs found in profile '{dbt_profile}' in '{profiles_path}'."
+        )
+        credentials = get_from_dict_with_raise(
+            outputs,
+            profile_target,
+            f"No credentials found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
+        )
+        conn_type = get_from_dict_with_raise(
+            credentials,
+            "type",
+            f"No type found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
+        )
+        conn_type = conn_type.lower()
 
         # values can contain env_vars
-        rendered_credentials = ProfileRenderer().render_data(credentials)
+        rendered_credentials = self.ProfileRenderer().render_data(credentials)
+        return rendered_credentials, conn_type
+
+    def set_connection(self):
+        rendered_credentials, conn_type = self._get_connection_creds()
 
         if conn_type == "snowflake":
             if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:
```
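Each nested profiles.yml lookup now fails with a precise message instead of an opaque AttributeError on None. A minimal sketch of what a helper like get_from_dict_with_raise (imported from .utils above) presumably does; the actual implementation lives in data_diff/utils.py:

```python
def get_from_dict_with_raise(data: dict, key: str, error_message: str):
    # Fail loudly with a descriptive message when the mapping or the key is
    # missing, instead of letting a later .get() on None blow up opaquely.
    if data is None:
        raise ValueError(error_message)
    result = data.get(key)
    if result is None:
        raise ValueError(error_message)
    return result
```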
```diff
@@ -337,6 +361,34 @@ def set_connection(self):
             conn_info = {
                 "driver": conn_type,
                 "filepath": rendered_credentials.get("path"),
+            }
+        elif conn_type == "redshift":
+            if rendered_credentials.get("password") is None or rendered_credentials.get("method") == "iam":
+                raise Exception("Only password authentication is currently supported for Redshift.")
+            conn_info = {
+                "driver": conn_type,
+                "host": rendered_credentials.get("host"),
+                "user": rendered_credentials.get("user"),
+                "password": rendered_credentials.get("password"),
+                "port": rendered_credentials.get("port"),
+                "dbname": rendered_credentials.get("dbname"),
+            }
+        elif conn_type == "databricks":
+            conn_info = {
+                "driver": conn_type,
+                "catalog": rendered_credentials.get("catalog"),
+                "server_hostname": rendered_credentials.get("host"),
+                "http_path": rendered_credentials.get("http_path"),
+                "schema": rendered_credentials.get("schema"),
+                "access_token": rendered_credentials.get("token"),
+            }
+        elif conn_type == "postgres":
+            conn_info = {
+                "driver": "postgresql",
+                "host": rendered_credentials.get("host"),
+                "user": rendered_credentials.get("user"),
+                "password": rendered_credentials.get("password"),
+                "port": rendered_credentials.get("port"),
+                "dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"),
             }
         else:
             raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")
```

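Each new branch maps a dbt profile target onto the connection dict data-diff expects. An end-to-end illustration for the postgres branch; the credential values are hypothetical, while the mapping mirrors the code above, including the driver rename and the dbname/database fallback:

```python
# Hypothetical rendered postgres target from profiles.yml.
rendered_credentials = {
    "type": "postgres",
    "host": "localhost",
    "user": "dbt_user",
    "password": "secret",
    "port": 5432,
    "dbname": "analytics",
}

# The mapping performed by the postgres branch above.
conn_info = {
    "driver": "postgresql",  # dbt says "postgres"; data-diff expects "postgresql"
    "host": rendered_credentials.get("host"),
    "user": rendered_credentials.get("user"),
    "password": rendered_credentials.get("password"),
    "port": rendered_credentials.get("port"),
    "dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"),
}
```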