5	5	import rich
6	6	from dataclasses import dataclass
7	7	from packaging.version import parse as parse_version
8	 	- from typing import List, Optional, Dict
 	8	+ from typing import List, Optional, Dict, Tuple
9	9	from pathlib import Path
10	10	
11	11	import requests
@@ -29,15 +29,15 @@ def import_dbt():
2929 send_event_json ,
3030 is_tracking_enabled ,
3131)
32- from .utils import run_as_daemon , truncate_error
32+ from .utils import get_from_dict_with_raise , run_as_daemon , truncate_error
3333from . import connect_to_table , diff_tables , Algorithm
3434
3535RUN_RESULTS_PATH = "target/run_results.json"
3636MANIFEST_PATH = "target/manifest.json"
3737PROJECT_FILE = "dbt_project.yml"
3838PROFILES_FILE = "profiles.yml"
3939LOWER_DBT_V = "1.0.0"
40	 	- UPPER_DBT_V = "1.5.0"
 	40	+ UPPER_DBT_V = "1.4.2"
4141
4242
4343# https://github.com/dbt-labs/dbt-core/blob/c952d44ec5c2506995fbad75320acbae49125d3d/core/dbt/cli/resolvers.py#L6
@@ -90,29 +90,19 @@ def dbt_diff(
9090
9191 if is_cloud and len (diff_vars .primary_keys ) > 0 :
9292 _cloud_diff (diff_vars )
93- elif is_cloud :
94- rich .print (
95- "[red]"
96- + "." .join (diff_vars .prod_path )
97- + " <> "
98- + "." .join (diff_vars .dev_path )
99- + "[/] \n "
100- + "Skipped due to missing primary-key tag\n "
101- )
102-
103- if not is_cloud and len (diff_vars .primary_keys ) == 1 :
93+ elif not is_cloud and len (diff_vars .primary_keys ) > 0 :
10494 _local_diff (diff_vars )
105- elif not is_cloud :
95+ else :
10696 rich .print (
10797 "[red]"
10898 + "." .join (diff_vars .prod_path )
10999 + " <> "
110100 + "." .join (diff_vars .dev_path )
111101 + "[/] \n "
112- + "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs) \n "
102+ + "Skipped due to missing primary-key tag(s). \n "
113103 )
114104
115- rich .print ("Diffs Complete!" )
105+ rich .print ("Diffs Complete!" )
116106
117107
118108def _get_diff_vars (
@@ -130,12 +120,12 @@ def _get_diff_vars(
130120 prod_schema = config_prod_schema if config_prod_schema else dev_schema
131121
132122 if dbt_parser .requires_upper :
133- dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .name ]]
134- prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .name ]]
123+ dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .alias ]]
124+ prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .alias ]]
135125 primary_keys = [x .upper () for x in primary_keys ]
136126 else :
137- dev_qualified_list = [dev_database , dev_schema , model .name ]
138- prod_qualified_list = [prod_database , prod_schema , model .name ]
127+ dev_qualified_list = [dev_database , dev_schema , model .alias ]
128+ prod_qualified_list = [prod_database , prod_schema , model .alias ]
139129
140130 return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection )
141131
@@ -144,10 +134,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
144134 column_diffs_str = ""
145135 dev_qualified_string = "." .join (diff_vars .dev_path )
146136 prod_qualified_string = "." .join (diff_vars .prod_path )
147- primary_key = diff_vars .primary_keys [0 ]
148137
149- table1 = connect_to_table (diff_vars .connection , dev_qualified_string , primary_key )
150- table2 = connect_to_table (diff_vars .connection , prod_qualified_string , primary_key )
138+ table1 = connect_to_table (diff_vars .connection , dev_qualified_string , tuple ( diff_vars . primary_keys ) )
139+ table2 = connect_to_table (diff_vars .connection , prod_qualified_string , tuple ( diff_vars . primary_keys ) )
151140
152141 table1_columns = list (table1 .get_schema ())
153142 try :
@@ -176,7 +165,7 @@ def _local_diff(diff_vars: DiffVars) -> None:
176165 if table2_set_diff :
177166 column_diffs_str += "Column(s) removed: " + str (table2_set_diff ) + "\n "
178167
179- mutual_set . discard ( primary_key )
168+ mutual_set = mutual_set - set ( diff_vars . primary_keys )
180169 extra_columns = tuple (mutual_set )
181170
182171 diff = diff_tables (table1 , table2 , threaded = True , algorithm = Algorithm .JOINDIFF , extra_columns = extra_columns )
@@ -325,18 +314,40 @@ def set_project_dict(self):
325314 with open (self .project_dir / PROJECT_FILE ) as project :
326315 self .project_dict = self .yaml .safe_load (project )
327316
328- def set_connection (self ):
329- with open (self .profiles_dir / PROFILES_FILE ) as profiles :
317+ def _get_connection_creds (self ) -> Tuple [Dict [str , str ], str ]:
318+ profiles_path = self .profiles_dir / PROFILES_FILE
319+ with open (profiles_path ) as profiles :
330320 profiles = self .yaml .safe_load (profiles )
331321
332322 dbt_profile = self .project_dict .get ("profile" )
333- profile_outputs = profiles .get (dbt_profile )
334- profile_target = profile_outputs .get ("target" )
335- credentials = profile_outputs .get ("outputs" ).get (profile_target )
336- conn_type = credentials .get ("type" ).lower ()
323+
324+ profile_outputs = get_from_dict_with_raise (
325+ profiles , dbt_profile , f"No profile '{ dbt_profile } ' found in '{ profiles_path } '."
326+ )
327+ profile_target = get_from_dict_with_raise (
328+ profile_outputs , "target" , f"No target found in profile '{ dbt_profile } ' in '{ profiles_path } '."
329+ )
330+ outputs = get_from_dict_with_raise (
331+ profile_outputs , "outputs" , f"No outputs found in profile '{ dbt_profile } ' in '{ profiles_path } '."
332+ )
333+ credentials = get_from_dict_with_raise (
334+ outputs ,
335+ profile_target ,
336+ f"No credentials found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
337+ )
338+ conn_type = get_from_dict_with_raise (
339+ credentials ,
340+ "type" ,
341+ f"No type found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
342+ )
343+ conn_type = conn_type .lower ()
337344
338345 # values can contain env_vars
339346 rendered_credentials = self .ProfileRenderer ().render_data (credentials )
347+ return rendered_credentials , conn_type
348+
349+ def set_connection (self ):
350+ rendered_credentials , conn_type = self ._get_connection_creds ()
340351
341352 if conn_type == "snowflake" :
342353 if rendered_credentials .get ("password" ) is None or rendered_credentials .get ("private_key_path" ) is not None :
@@ -363,6 +374,40 @@ def set_connection(self):
363374 "project" : rendered_credentials .get ("project" ),
364375 "dataset" : rendered_credentials .get ("dataset" ),
365376 }
377+ elif conn_type == "duckdb" :
378+ conn_info = {
379+ "driver" : conn_type ,
380+ "filepath" : rendered_credentials .get ("path" ),
381+ }
382+ elif conn_type == "redshift" :
383+ if rendered_credentials .get ("password" ) is None or rendered_credentials .get ("method" ) == "iam" :
384+ raise Exception ("Only password authentication is currently supported for Redshift." )
385+ conn_info = {
386+ "driver" : conn_type ,
387+ "host" : rendered_credentials .get ("host" ),
388+ "user" : rendered_credentials .get ("user" ),
389+ "password" : rendered_credentials .get ("password" ),
390+ "port" : rendered_credentials .get ("port" ),
391+ "dbname" : rendered_credentials .get ("dbname" ),
392+ }
393+ elif conn_type == "databricks" :
394+ conn_info = {
395+ "driver" : conn_type ,
396+ "catalog" : rendered_credentials .get ("catalog" ),
397+ "server_hostname" : rendered_credentials .get ("host" ),
398+ "http_path" : rendered_credentials .get ("http_path" ),
399+ "schema" : rendered_credentials .get ("schema" ),
400+ "access_token" : rendered_credentials .get ("token" ),
401+ }
402+ elif conn_type == "postgres" :
403+ conn_info = {
404+ "driver" : "postgresql" ,
405+ "host" : rendered_credentials .get ("host" ),
406+ "user" : rendered_credentials .get ("user" ),
407+ "password" : rendered_credentials .get ("password" ),
408+ "port" : rendered_credentials .get ("port" ),
409+ "dbname" : rendered_credentials .get ("dbname" ) or rendered_credentials .get ("database" ),
410+ }
366411 else :
367412 raise NotImplementedError (f"Provider { conn_type } is not yet supported for dbt diffs" )
368413
0 commit comments