Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 1f9903a

Browse files
authored
Merge branch 'master' into issue_427
2 parents b49b93a + 48d918c commit 1f9903a

File tree

3 files changed

+176
-140
lines changed

3 files changed

+176
-140
lines changed

data_diff/dbt.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import rich
66
from dataclasses import dataclass
77
from packaging.version import parse as parse_version
8-
from typing import List, Optional, Dict
8+
from typing import List, Optional, Dict, Tuple
99

1010
import requests
1111

@@ -28,7 +28,7 @@ def import_dbt():
2828
send_event_json,
2929
is_tracking_enabled,
3030
)
31-
from .utils import run_as_daemon, truncate_error
31+
from .utils import get_from_dict_with_raise, run_as_daemon, truncate_error
3232
from . import connect_to_table, diff_tables, Algorithm
3333

3434
RUN_RESULTS_PATH = "/target/run_results.json"
@@ -85,7 +85,7 @@ def dbt_diff(
8585
+ "Skipped due to missing primary-key tag\n"
8686
)
8787

88-
rich.print("Diffs Complete!")
88+
rich.print("Diffs Complete!")
8989

9090

9191
def _get_diff_vars(
@@ -103,12 +103,12 @@ def _get_diff_vars(
103103
prod_schema = config_prod_schema if config_prod_schema else dev_schema
104104

105105
if dbt_parser.requires_upper:
106-
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.name]]
107-
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.name]]
106+
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias]]
107+
prod_qualified_list = [x.upper() for x in [prod_database, prod_schema, model.alias]]
108108
primary_keys = [x.upper() for x in primary_keys]
109109
else:
110-
dev_qualified_list = [dev_database, dev_schema, model.name]
111-
prod_qualified_list = [prod_database, prod_schema, model.name]
110+
dev_qualified_list = [dev_database, dev_schema, model.alias]
111+
prod_qualified_list = [prod_database, prod_schema, model.alias]
112112

113113
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection)
114114

@@ -297,18 +297,40 @@ def set_project_dict(self):
297297
with open(self.project_dir + PROJECT_FILE) as project:
298298
self.project_dict = self.yaml.safe_load(project)
299299

300-
def set_connection(self):
301-
with open(self.profiles_dir + PROFILES_FILE) as profiles:
300+
def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:
301+
profiles_path = self.profiles_dir + PROFILES_FILE
302+
with open(profiles_path) as profiles:
302303
profiles = self.yaml.safe_load(profiles)
303304

304305
dbt_profile = self.project_dict.get("profile")
305-
profile_outputs = profiles.get(dbt_profile)
306-
profile_target = profile_outputs.get("target")
307-
credentials = profile_outputs.get("outputs").get(profile_target)
308-
conn_type = credentials.get("type").lower()
306+
307+
profile_outputs = get_from_dict_with_raise(
308+
profiles, dbt_profile, f"No profile '{dbt_profile}' found in '{profiles_path}'."
309+
)
310+
profile_target = get_from_dict_with_raise(
311+
profile_outputs, "target", f"No target found in profile '{dbt_profile}' in '{profiles_path}'."
312+
)
313+
outputs = get_from_dict_with_raise(
314+
profile_outputs, "outputs", f"No outputs found in profile '{dbt_profile}' in '{profiles_path}'."
315+
)
316+
credentials = get_from_dict_with_raise(
317+
outputs,
318+
profile_target,
319+
f"No credentials found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
320+
)
321+
conn_type = get_from_dict_with_raise(
322+
credentials,
323+
"type",
324+
f"No type found for target '{profile_target}' in profile '{dbt_profile}' in '{profiles_path}'.",
325+
)
326+
conn_type = conn_type.lower()
309327

310328
# values can contain env_vars
311329
rendered_credentials = self.ProfileRenderer().render_data(credentials)
330+
return rendered_credentials, conn_type
331+
332+
def set_connection(self):
333+
rendered_credentials, conn_type = self._get_connection_creds()
312334

313335
if conn_type == "snowflake":
314336
if rendered_credentials.get("password") is None or rendered_credentials.get("private_key_path") is not None:

data_diff/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import re
3-
from typing import Iterable, Sequence
3+
from typing import Dict, Iterable, Sequence
44
from urllib.parse import urlparse
55
import operator
66
import threading
@@ -79,6 +79,13 @@ def truncate_error(error: str):
7979
return re.sub("'(.*?)'", "'***'", first_line)
8080

8181

82+
def get_from_dict_with_raise(dictionary: Dict, key: str, error_message: str):
83+
result = dictionary.get(key)
84+
if result is None:
85+
raise ValueError(error_message)
86+
return result
87+
88+
8289
class Vector(tuple):
8390

8491
"""Immutable implementation of a regular vector over any arithmetic value

0 commit comments

Comments
 (0)