55import rich
66from dataclasses import dataclass
77from packaging .version import parse as parse_version
8- from typing import List , Optional , Dict
8+ from typing import List , Optional , Dict , Tuple
99
1010import requests
1111
@@ -28,7 +28,7 @@ def import_dbt():
2828 send_event_json ,
2929 is_tracking_enabled ,
3030)
31- from .utils import run_as_daemon , truncate_error
31+ from .utils import get_from_dict_with_raise , run_as_daemon , truncate_error
3232from . import connect_to_table , diff_tables , Algorithm
3333
3434RUN_RESULTS_PATH = "/target/run_results.json"
@@ -85,7 +85,7 @@ def dbt_diff(
8585 + "Skipped due to missing primary-key tag(s).\n "
8686 )
8787
88- rich .print ("Diffs Complete!" )
88+ rich .print ("Diffs Complete!" )
8989
9090
9191def _get_diff_vars (
@@ -103,12 +103,12 @@ def _get_diff_vars(
103103 prod_schema = config_prod_schema if config_prod_schema else dev_schema
104104
105105 if dbt_parser .requires_upper :
106- dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .name ]]
107- prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .name ]]
106+ dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .alias ]]
107+ prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .alias ]]
108108 primary_keys = [x .upper () for x in primary_keys ]
109109 else :
110- dev_qualified_list = [dev_database , dev_schema , model .name ]
111- prod_qualified_list = [prod_database , prod_schema , model .name ]
110+ dev_qualified_list = [dev_database , dev_schema , model .alias ]
111+ prod_qualified_list = [prod_database , prod_schema , model .alias ]
112112
113113 return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection )
114114
@@ -297,18 +297,40 @@ def set_project_dict(self):
297297 with open (self .project_dir + PROJECT_FILE ) as project :
298298 self .project_dict = self .yaml .safe_load (project )
299299
300- def set_connection (self ):
301- with open (self .profiles_dir + PROFILES_FILE ) as profiles :
300+ def _get_connection_creds (self ) -> Tuple [Dict [str , str ], str ]:
301+ profiles_path = self .profiles_dir + PROFILES_FILE
302+ with open (profiles_path ) as profiles :
302303 profiles = self .yaml .safe_load (profiles )
303304
304305 dbt_profile = self .project_dict .get ("profile" )
305- profile_outputs = profiles .get (dbt_profile )
306- profile_target = profile_outputs .get ("target" )
307- credentials = profile_outputs .get ("outputs" ).get (profile_target )
308- conn_type = credentials .get ("type" ).lower ()
306+
307+ profile_outputs = get_from_dict_with_raise (
308+ profiles , dbt_profile , f"No profile '{ dbt_profile } ' found in '{ profiles_path } '."
309+ )
310+ profile_target = get_from_dict_with_raise (
311+ profile_outputs , "target" , f"No target found in profile '{ dbt_profile } ' in '{ profiles_path } '."
312+ )
313+ outputs = get_from_dict_with_raise (
314+ profile_outputs , "outputs" , f"No outputs found in profile '{ dbt_profile } ' in '{ profiles_path } '."
315+ )
316+ credentials = get_from_dict_with_raise (
317+ outputs ,
318+ profile_target ,
319+ f"No credentials found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
320+ )
321+ conn_type = get_from_dict_with_raise (
322+ credentials ,
323+ "type" ,
324+ f"No type found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
325+ )
326+ conn_type = conn_type .lower ()
309327
310328 # values can contain env_vars
311329 rendered_credentials = self .ProfileRenderer ().render_data (credentials )
330+ return rendered_credentials , conn_type
331+
332+ def set_connection (self ):
333+ rendered_credentials , conn_type = self ._get_connection_creds ()
312334
313335 if conn_type == "snowflake" :
314336 if rendered_credentials .get ("password" ) is None or rendered_credentials .get ("private_key_path" ) is not None :
0 commit comments