33import os
44import time
55import rich
6- import yaml
76from dataclasses import dataclass
87from packaging .version import parse as parse_version
9- from typing import List , Optional , Dict
8+ from typing import List , Optional , Dict , Tuple
109
1110import requests
12- from dbt_artifacts_parser .parser import parse_run_results , parse_manifest
13- from dbt .config .renderer import ProfileRenderer
11+
12+
13+ def import_dbt ():
14+ try :
15+ from dbt_artifacts_parser .parser import parse_run_results , parse_manifest
16+ from dbt .config .renderer import ProfileRenderer
17+ import yaml
18+ except ImportError :
19+ raise RuntimeError ("Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'." )
20+
21+ return parse_run_results , parse_manifest , ProfileRenderer , yaml
22+
1423
1524from .tracking import (
1625 set_entrypoint_name ,
1928 send_event_json ,
2029 is_tracking_enabled ,
2130)
22- from .utils import run_as_daemon , truncate_error
31+ from .utils import get_from_dict_with_raise , run_as_daemon , truncate_error
2332from . import connect_to_table , diff_tables , Algorithm
2433
2534RUN_RESULTS_PATH = "/target/run_results.json"
2635MANIFEST_PATH = "/target/manifest.json"
2736PROJECT_FILE = "/dbt_project.yml"
2837PROFILES_FILE = "/profiles.yml"
2938LOWER_DBT_V = "1.0.0"
30- UPPER_DBT_V = "1.5.0 "
39+ UPPER_DBT_V = "1.4.2 "
3140
3241
3342@dataclass
@@ -54,37 +63,29 @@ def dbt_diff(
5463 if not is_cloud :
5564 dbt_parser .set_connection ()
5665
57- if config_prod_database is None or config_prod_schema is None :
58- raise ValueError ("Expected a value for prod_database: or prod_schema: under \n vars:\n data_diff: " )
66+ if config_prod_database is None :
67+ raise ValueError (
68+ "Expected a value for prod_database: OR prod_database: AND prod_schema: under \n vars:\n data_diff: "
69+ )
5970
6071 for model in models :
6172 diff_vars = _get_diff_vars (dbt_parser , config_prod_database , config_prod_schema , model , datasource_id )
6273
6374 if is_cloud and len (diff_vars .primary_keys ) > 0 :
6475 _cloud_diff (diff_vars )
65- elif is_cloud :
66- rich .print (
67- "[red]"
68- + "." .join (diff_vars .dev_path )
69- + " <> "
70- + "." .join (diff_vars .prod_path )
71- + "[/] \n "
72- + "Skipped due to missing primary-key tag\n "
73- )
74-
75- if not is_cloud and len (diff_vars .primary_keys ) == 1 :
76+ elif not is_cloud and len (diff_vars .primary_keys ) > 0 :
7677 _local_diff (diff_vars )
77- elif not is_cloud :
78+ else :
7879 rich .print (
7980 "[red]"
80- + "." .join (diff_vars .dev_path )
81- + " <> "
8281 + "." .join (diff_vars .prod_path )
82+ + " <> "
83+ + "." .join (diff_vars .dev_path )
8384 + "[/] \n "
84- + "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs) \n "
85+ + "Skipped due to missing primary-key tag(s). \n "
8586 )
8687
87- rich .print ("Diffs Complete!" )
88+ rich .print ("Diffs Complete!" )
8889
8990
9091def _get_diff_vars (
@@ -102,12 +103,12 @@ def _get_diff_vars(
102103 prod_schema = config_prod_schema if config_prod_schema else dev_schema
103104
104105 if dbt_parser .requires_upper :
105- dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .name ]]
106- prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .name ]]
106+ dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .alias ]]
107+ prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .alias ]]
107108 primary_keys = [x .upper () for x in primary_keys ]
108109 else :
109- dev_qualified_list = [dev_database , dev_schema , model .name ]
110- prod_qualified_list = [prod_database , prod_schema , model .name ]
110+ dev_qualified_list = [dev_database , dev_schema , model .alias ]
111+ prod_qualified_list = [prod_database , prod_schema , model .alias ]
111112
112113 return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection )
113114
@@ -116,10 +117,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
116117 column_diffs_str = ""
117118 dev_qualified_string = "." .join (diff_vars .dev_path )
118119 prod_qualified_string = "." .join (diff_vars .prod_path )
119- primary_key = diff_vars .primary_keys [0 ]
120120
121- table1 = connect_to_table (diff_vars .connection , dev_qualified_string , primary_key )
122- table2 = connect_to_table (diff_vars .connection , prod_qualified_string , primary_key )
121+ table1 = connect_to_table (diff_vars .connection , dev_qualified_string , tuple ( diff_vars . primary_keys ) )
122+ table2 = connect_to_table (diff_vars .connection , prod_qualified_string , tuple ( diff_vars . primary_keys ) )
123123
124124 table1_columns = list (table1 .get_schema ())
125125 try :
@@ -129,9 +129,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
129129 logging .info (ex )
130130 rich .print (
131131 "[red]"
132- + dev_qualified_string
133- + " <> "
134132 + prod_qualified_string
133+ + " <> "
134+ + dev_qualified_string
135135 + "[/] \n "
136136 + column_diffs_str
137137 + "[green]New model or no access to prod table.[/] \n "
@@ -143,22 +143,22 @@ def _local_diff(diff_vars: DiffVars) -> None:
143143 table2_set_diff = list (set (table2_columns ) - set (table1_columns ))
144144
145145 if table1_set_diff :
146- column_diffs_str += "Columns exclusive to table A : " + str (table1_set_diff ) + "\n "
146+ column_diffs_str += "Column(s) added : " + str (table1_set_diff ) + "\n "
147147
148148 if table2_set_diff :
149- column_diffs_str += "Columns exclusive to table B : " + str (table2_set_diff ) + "\n "
149+ column_diffs_str += "Column(s) removed : " + str (table2_set_diff ) + "\n "
150150
151- mutual_set . discard ( primary_key )
151+ mutual_set = mutual_set - set ( diff_vars . primary_keys )
152152 extra_columns = tuple (mutual_set )
153153
154154 diff = diff_tables (table1 , table2 , threaded = True , algorithm = Algorithm .JOINDIFF , extra_columns = extra_columns )
155155
156156 if list (diff ):
157157 rich .print (
158158 "[red]"
159- + dev_qualified_string
160- + " <> "
161159 + prod_qualified_string
160+ + " <> "
161+ + dev_qualified_string
162162 + "[/] \n "
163163 + column_diffs_str
164164 + diff .get_stats_string (is_dbt = True )
@@ -167,9 +167,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
167167 else :
168168 rich .print (
169169 "[red]"
170- + dev_qualified_string
171- + " <> "
172170 + prod_qualified_string
171+ + " <> "
172+ + dev_qualified_string
173173 + "[/] \n "
174174 + column_diffs_str
175175 + "[green]No row differences[/] \n "
@@ -191,8 +191,8 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
191191 payload = {
192192 "data_source1_id" : diff_vars .datasource_id ,
193193 "data_source2_id" : diff_vars .datasource_id ,
194- "table1" : diff_vars .dev_path ,
195- "table2" : diff_vars .prod_path ,
194+ "table1" : diff_vars .prod_path ,
195+ "table2" : diff_vars .dev_path ,
196196 "pk_columns" : diff_vars .primary_keys ,
197197 }
198198
@@ -216,9 +216,9 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
216216 diff_url = f"https://app.datafold.com/datadiffs/{ diff_id } /overview"
217217 rich .print (
218218 "[red]"
219- + "." .join (diff_vars .dev_path )
220- + " <> "
221219 + "." .join (diff_vars .prod_path )
220+ + " <> "
221+ + "." .join (diff_vars .dev_path )
222222 + "[/] \n Diff in progress: \n "
223223 + diff_url
224224 + "\n "
@@ -261,13 +261,15 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
261261 self .project_dict = None
262262 self .requires_upper = False
263263
264+ self .parse_run_results , self .parse_manifest , self .ProfileRenderer , self .yaml = import_dbt ()
265+
264266 def get_datadiff_variables (self ) -> dict :
265267 return self .project_dict .get ("vars" ).get ("data_diff" )
266268
267269 def get_models (self ):
268270 with open (self .project_dir + RUN_RESULTS_PATH ) as run_results :
269271 run_results_dict = json .load (run_results )
270- run_results_obj = parse_run_results (run_results = run_results_dict )
272+ run_results_obj = self . parse_run_results (run_results = run_results_dict )
271273
272274 dbt_version = parse_version (run_results_obj .metadata .dbt_version )
273275
@@ -278,7 +280,7 @@ def get_models(self):
278280
279281 with open (self .project_dir + MANIFEST_PATH ) as manifest :
280282 manifest_dict = json .load (manifest )
281- manifest_obj = parse_manifest (manifest = manifest_dict )
283+ manifest_obj = self . parse_manifest (manifest = manifest_dict )
282284
283285 success_models = [x .unique_id for x in run_results_obj .results if x .status .name == "success" ]
284286 models = [manifest_obj .nodes .get (x ) for x in success_models ]
@@ -293,20 +295,42 @@ def get_primary_keys(self, model):
293295
294296 def set_project_dict (self ):
295297 with open (self .project_dir + PROJECT_FILE ) as project :
296- self .project_dict = yaml .safe_load (project )
298+ self .project_dict = self . yaml .safe_load (project )
297299
298- def set_connection (self ):
299- with open (self .profiles_dir + PROFILES_FILE ) as profiles :
300- profiles = yaml .safe_load (profiles )
300+ def _get_connection_creds (self ) -> Tuple [Dict [str , str ], str ]:
301+ profiles_path = self .profiles_dir + PROFILES_FILE
302+ with open (profiles_path ) as profiles :
303+ profiles = self .yaml .safe_load (profiles )
301304
302305 dbt_profile = self .project_dict .get ("profile" )
303- profile_outputs = profiles .get (dbt_profile )
304- profile_target = profile_outputs .get ("target" )
305- credentials = profile_outputs .get ("outputs" ).get (profile_target )
306- conn_type = credentials .get ("type" ).lower ()
306+
307+ profile_outputs = get_from_dict_with_raise (
308+ profiles , dbt_profile , f"No profile '{ dbt_profile } ' found in '{ profiles_path } '."
309+ )
310+ profile_target = get_from_dict_with_raise (
311+ profile_outputs , "target" , f"No target found in profile '{ dbt_profile } ' in '{ profiles_path } '."
312+ )
313+ outputs = get_from_dict_with_raise (
314+ profile_outputs , "outputs" , f"No outputs found in profile '{ dbt_profile } ' in '{ profiles_path } '."
315+ )
316+ credentials = get_from_dict_with_raise (
317+ outputs ,
318+ profile_target ,
319+ f"No credentials found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
320+ )
321+ conn_type = get_from_dict_with_raise (
322+ credentials ,
323+ "type" ,
324+ f"No type found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
325+ )
326+ conn_type = conn_type .lower ()
307327
308328 # values can contain env_vars
309- rendered_credentials = ProfileRenderer ().render_data (credentials )
329+ rendered_credentials = self .ProfileRenderer ().render_data (credentials )
330+ return rendered_credentials , conn_type
331+
332+ def set_connection (self ):
333+ rendered_credentials , conn_type = self ._get_connection_creds ()
310334
311335 if conn_type == "snowflake" :
312336 if rendered_credentials .get ("password" ) is None or rendered_credentials .get ("private_key_path" ) is not None :
@@ -337,6 +361,34 @@ def set_connection(self):
337361 conn_info = {
338362 "driver" : conn_type ,
339363 "filepath" : rendered_credentials .get ("path" ),
364+ elif conn_type == "redshift" :
365+ if rendered_credentials .get ("password" ) is None or rendered_credentials .get ("method" ) == "iam" :
366+ raise Exception ("Only password authentication is currently supported for Redshift." )
367+ conn_info = {
368+ "driver" : conn_type ,
369+ "host" : rendered_credentials .get ("host" ),
370+ "user" : rendered_credentials .get ("user" ),
371+ "password" : rendered_credentials .get ("password" ),
372+ "port" : rendered_credentials .get ("port" ),
373+ "dbname" : rendered_credentials .get ("dbname" ),
374+ }
375+ elif conn_type == "databricks" :
376+ conn_info = {
377+ "driver" : conn_type ,
378+ "catalog" : rendered_credentials .get ("catalog" ),
379+ "server_hostname" : rendered_credentials .get ("host" ),
380+ "http_path" : rendered_credentials .get ("http_path" ),
381+ "schema" : rendered_credentials .get ("schema" ),
382+ "access_token" : rendered_credentials .get ("token" ),
383+ }
384+ elif conn_type == "postgres" :
385+ conn_info = {
386+ "driver" : "postgresql" ,
387+ "host" : rendered_credentials .get ("host" ),
388+ "user" : rendered_credentials .get ("user" ),
389+ "password" : rendered_credentials .get ("password" ),
390+ "port" : rendered_credentials .get ("port" ),
391+ "dbname" : rendered_credentials .get ("dbname" ) or rendered_credentials .get ("database" ),
340392 }
341393 else :
342394 raise NotImplementedError (f"Provider { conn_type } is not yet supported for dbt diffs" )
0 commit comments