Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 9798d79

Browse files
committed
Merge branch 'dbeatty10/dbt-profiles-dbt-project' of github.com:dbeatty10/data-diff into dbeatty10/dbt-profiles-dbt-project
2 parents 6df559f + 82c40f2 commit 9798d79

File tree

5 files changed

+287
-267
lines changed

5 files changed

+287
-267
lines changed

data_diff/dbt.py

Lines changed: 76 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import rich
66
from dataclasses import dataclass
77
from packaging.version import parse as parse_version
8-
from typing import List, Optional, Dict
8+
from typing import List, Optional, Dict, Tuple
99
from pathlib import Path
1010

1111
import requests
@@ -29,15 +29,15 @@ def import_dbt():
2929
send_event_json,
3030
is_tracking_enabled,
3131
)
32-
from .utils import run_as_daemon, truncate_error
32+
from .utils import get_from_dict_with_raise, run_as_daemon, truncate_error
3333
from . import connect_to_table, diff_tables, Algorithm
3434

3535
RUN_RESULTS_PATH = "target/run_results.json"
3636
MANIFEST_PATH = "target/manifest.json"
3737
PROJECT_FILE = "dbt_project.yml"
3838
PROFILES_FILE = "profiles.yml"
3939
LOWER_DBT_V = "1.0.0"
40-
UPPER_DBT_V = "1.5.0"
40+
UPPER_DBT_V = "1.4.2"
4141

4242

4343
# https://github.com/dbt-labs/dbt-core/blob/c952d44ec5c2506995fbad75320acbae49125d3d/core/dbt/cli/resolvers.py#L6
@@ -90,29 +90,19 @@ def dbt_diff(
9090

9191
if is_cloud and len(diff_vars.primary_keys) > 0:
9292
_cloud_diff(diff_vars)
93-
elif is_cloud:
94-
rich.print(
95-
"[red]"
96-
+ ".".join(diff_vars.prod_path)
97-
+ " <> "
98-
+ ".".join(diff_vars.dev_path)
99-
+ "[/] \n"
100-
+ "Skipped due to missing primary-key tag\n"
101-
)
102-
103-
if not is_cloud and len(diff_vars.primary_keys) == 1:
93+
elif not is_cloud and len(diff_vars.primary_keys) > 0:
10494
_local_diff(diff_vars)
105-
elif not is_cloud:
95+
else:
10696
rich.print(
10797
"[red]"
10898
+ ".".join(diff_vars.prod_path)
10999
+ " <> "
110100
+ ".".join(diff_vars.dev_path)
111101
+ "[/] \n"
112-
+ "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs)\n"
102+
+ "Skipped due to missing primary-key tag(s).\n"
113103
)
114104

115-
rich.print("Diffs Complete!")
105+
rich.print("Diffs Complete!")
116106

117107

118108
def _get_diff_vars(
@@ -130,12 +120,12 @@ def _get_diff_vars(
130120
prod_schema = config_prod_schema if config_prod_schema else dev_schema
131121

132122
if dbt_parser.requires_upper:
133-
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.name]]
134-
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.name]]
123+
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias]]
124+
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias]]
135125
primary_keys = [x.upper() for x in primary_keys]
136126
else:
137-
dev_qualified_list = [dev_database, dev_schema, model.name]
138-
prod_qualified_list = [prod_database, prod_schema, model.name]
127+
dev_qualified_list = [dev_database, dev_schema, model.alias]
128+
prod_qualified_list = [prod_database, prod_schema, model.alias]
139129

140130
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection)
141131

@@ -144,10 +134,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
144134
column_diffs_str = ""
145135
dev_qualified_string = ".".join(diff_vars.dev_path)
146136
prod_qualified_string = ".".join(diff_vars.prod_path)
147-
primary_key = diff_vars.primary_keys[0]
148137

149-
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, primary_key)
150-
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, primary_key)
138+
table1 = connect_to_table(diff_vars.connection, dev_qualified_string, tuple(diff_vars.primary_keys))
139+
table2 = connect_to_table(diff_vars.connection, prod_qualified_string, tuple(diff_vars.primary_keys))
151140

152141
table1_columns = list(table1.get_schema())
153142
try:
@@ -176,7 +165,7 @@ def _local_diff(diff_vars: DiffVars) -> None:
176165
if table2_set_diff:
177166
column_diffs_str += "Column(s) removed: " + str(table2_set_diff) + "\n"
178167

179-
mutual_set.discard(primary_key)
168+
mutual_set = mutual_set - set(diff_vars.primary_keys)
180169
extra_columns = tuple(mutual_set)
181170

182171
diff = diff_tables(table1, table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=extra_columns)
@@ -325,18 +314,40 @@ def set_project_dict(self):
325314
with open(self.project_dir / PROJECT_FILE) as project:
326315
self.project_dict = self.yaml.safe_load(project)
327316

328-
def set_connection(self):
329-
with open(self.profiles_dir / PROFILES_FILE) as profiles:
317+
def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
318+
profiles_path = self.profiles_dir / PROFILES_FILE
319+
with open(profiles_path) as profiles:
330320
profiles = self.yaml.safe_load(profiles)
331321

332322
dbt_profile = self.project_dict.get("profile")
333-
profile_outputs = profiles.get(dbt_profile)
334-
profile_target = profile_outputs.get("target")
335-
credentials = profile_outputs.get("outputs").get(profile_target)
336-
conn_type = credentials.get("type").lower()
323+
324+
profile_outputs = get_from_dict_with_raise(
325+
profiles, dbt_profile, f"No profile '{dbt_profile}' found in '{profiles_path}'."
326+
)
327+
profile_target = get_from_dict_with_raise(
328+
profile_outputs, "target", f"No target found in profile '{dbt_profile}' in '{profiles_path}'."
329+
)
330+
outputs = get_from_dict_with_raise(
331+
profile_outputs, "outputs", f"No outputs found in profile '{dbt_profile}' in '{profiles_path}'."
332+
)
333+
credentials = get_from_dict_with_raise(
334+
outputs,
335+
profile_target,
336+
f"No credentials found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
337+
)
338+
conn_type = get_from_dict_with_raise(
339+
credentials,
340+
"type",
341+
f"No type found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
342+
)
343+
conn_type = conn_type.lower()
337344

338345
# values can contain env_vars
339346
rendered_credentials = self.ProfileRenderer().render_data(credentials)
347+
return rendered_credentials, conn_type
348+
349+
def set_connection(self):
350+
rendered_credentials, conn_type = self._get_connection_creds()
340351

341352
if conn_type == "snowflake":
342353
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:
@@ -363,6 +374,40 @@ def set_connection(self):
363374
"project": rendered_credentials.get("project"),
364375
"dataset": rendered_credentials.get("dataset"),
365376
}
377+
elif conn_type == "duckdb":
378+
conn_info = {
379+
"driver": conn_type,
380+
"filepath": rendered_credentials.get("path"),
381+
}
382+
elif conn_type == "redshift":
383+
if rendered_credentials.get("password") is None or rendered_credentials.get("method") == "iam":
384+
raise Exception("Only password authentication is currently supported for Redshift.")
385+
conn_info = {
386+
"driver": conn_type,
387+
"host": rendered_credentials.get("host"),
388+
"user": rendered_credentials.get("user"),
389+
"password": rendered_credentials.get("password"),
390+
"port": rendered_credentials.get("port"),
391+
"dbname": rendered_credentials.get("dbname"),
392+
}
393+
elif conn_type == "databricks":
394+
conn_info = {
395+
"driver": conn_type,
396+
"catalog": rendered_credentials.get("catalog"),
397+
"server_hostname": rendered_credentials.get("host"),
398+
"http_path": rendered_credentials.get("http_path"),
399+
"schema": rendered_credentials.get("schema"),
400+
"access_token": rendered_credentials.get("token"),
401+
}
402+
elif conn_type == "postgres":
403+
conn_info = {
404+
"driver": "postgresql",
405+
"host": rendered_credentials.get("host"),
406+
"user": rendered_credentials.get("user"),
407+
"password": rendered_credentials.get("password"),
408+
"port": rendered_credentials.get("port"),
409+
"dbname": rendered_credentials.get("dbname") or rendered_credentials.get("database"),
410+
}
366411
else:
367412
raise NotImplementedError(f"Provider {conn_type} is not yet supported for dbt diffs")
368413

data_diff/utils.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import logging
22
import re
3-
from typing import Iterable, Sequence
3+
from typing import Dict, Iterable, Sequence
44
from urllib.parse import urlparse
55
import operator
66
import threading
@@ -79,6 +79,13 @@ def truncate_error(error: str):
7979
return re.sub("'(.*?)'", "'***'", first_line)
8080

8181

82+
def get_from_dict_with_raise(dictionary: Dict, key: str, error_message: str):
    """Return ``dictionary[key]``, raising ``ValueError`` when it is unavailable.

    Args:
        dictionary: Mapping to read from. A non-mapping value (for example
            ``None``) is treated the same as a missing key.
        key: Key to look up.
        error_message: Message of the ``ValueError`` raised on failure.

    Returns:
        The value stored under ``key``.

    Raises:
        ValueError: If ``dictionary`` has no usable ``get`` method, ``key`` is
            absent, or the stored value is ``None``.
    """
    # Duck-type the container: without this guard a non-mapping input would
    # escape as a bare AttributeError and the caller's descriptive message
    # would be lost.
    getter = getattr(dictionary, "get", None)
    result = getter(key) if callable(getter) else None
    if result is None:
        raise ValueError(error_message)
    return result
87+
88+
8289
class Vector(tuple):
8390

8491
"""Immutable implementation of a regular vector over any arithmetic value

0 commit comments

Comments
 (0)