11import json
22import logging
33import os
4+ import time
45import rich
5- import yaml
66from dataclasses import dataclass
77from packaging .version import parse as parse_version
8- from typing import List , Optional , Dict
8+ from typing import List , Optional , Dict , Tuple
99
1010import requests
11- from dbt_artifacts_parser .parser import parse_run_results , parse_manifest
12- from dbt .config .renderer import ProfileRenderer
1311
12+
def import_dbt():
    """Lazily import the optional dbt-related dependencies.

    Keeping these imports out of module scope lets the rest of the package
    work without the 'dbt' extra installed.

    Returns:
        Tuple of (parse_run_results, parse_manifest, ProfileRenderer, yaml).

    Raises:
        RuntimeError: if any of the dbt dependencies are missing.
    """
    try:
        from dbt_artifacts_parser.parser import parse_run_results, parse_manifest
        from dbt.config.renderer import ProfileRenderer
        import yaml
    except ImportError as import_error:
        # Chain the original ImportError explicitly so the name of the
        # missing module is preserved in the traceback cause.
        raise RuntimeError(
            "Could not import 'dbt' package. You can install it using: pip install 'data-diff[dbt]'."
        ) from import_error

    return parse_run_results, parse_manifest, ProfileRenderer, yaml
22+
23+
24+ from .tracking import (
25+ set_entrypoint_name ,
26+ create_end_event_json ,
27+ create_start_event_json ,
28+ send_event_json ,
29+ is_tracking_enabled ,
30+ )
31+ from .utils import get_from_dict_with_raise , run_as_daemon , truncate_error
1432from . import connect_to_table , diff_tables , Algorithm
1533
1634RUN_RESULTS_PATH = "/target/run_results.json"
@@ -33,6 +51,7 @@ class DiffVars:
3351def dbt_diff (
3452 profiles_dir_override : Optional [str ] = None , project_dir_override : Optional [str ] = None , is_cloud : bool = False
3553) -> None :
54+ set_entrypoint_name ("CLI-dbt" )
3655 dbt_parser = DbtParser (profiles_dir_override , project_dir_override , is_cloud )
3756 models = dbt_parser .get_models ()
3857 dbt_parser .set_project_dict ()
@@ -44,8 +63,10 @@ def dbt_diff(
4463 if not is_cloud :
4564 dbt_parser .set_connection ()
4665
47- if config_prod_database is None or config_prod_schema is None :
48- raise ValueError ("Expected a value for prod_database: or prod_schema: under \n vars:\n data_diff: " )
66+ if config_prod_database is None :
67+ raise ValueError (
68+ "Expected a value for prod_database: OR prod_database: AND prod_schema: under \n vars:\n data_diff: "
69+ )
4970
5071 for model in models :
5172 diff_vars = _get_diff_vars (dbt_parser , config_prod_database , config_prod_schema , model , datasource_id )
@@ -55,9 +76,9 @@ def dbt_diff(
5576 elif is_cloud :
5677 rich .print (
5778 "[red]"
58- + "." .join (diff_vars .dev_path )
59- + " <> "
6079 + "." .join (diff_vars .prod_path )
80+ + " <> "
81+ + "." .join (diff_vars .dev_path )
6182 + "[/] \n "
6283 + "Skipped due to missing primary-key tag\n "
6384 )
@@ -67,14 +88,14 @@ def dbt_diff(
6788 elif not is_cloud :
6889 rich .print (
6990 "[red]"
70- + "." .join (diff_vars .dev_path )
71- + " <> "
7291 + "." .join (diff_vars .prod_path )
92+ + " <> "
93+ + "." .join (diff_vars .dev_path )
7394 + "[/] \n "
7495 + "Skipped due to missing primary-key tag or multi-column primary-key (unsupported for non --cloud diffs)\n "
7596 )
7697
77- rich .print ("Diffs Complete!" )
98+ rich .print ("Diffs Complete!" )
7899
79100
80101def _get_diff_vars (
@@ -92,12 +113,12 @@ def _get_diff_vars(
92113 prod_schema = config_prod_schema if config_prod_schema else dev_schema
93114
94115 if dbt_parser .requires_upper :
95- dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .name ]]
96- prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .name ]]
116+ dev_qualified_list = [x .upper () for x in [dev_database , dev_schema , model .alias ]]
117+ prod_qualified_list = [x .upper () for x in [prod_database , prod_schema , model .alias ]]
97118 primary_keys = [x .upper () for x in primary_keys ]
98119 else :
99- dev_qualified_list = [dev_database , dev_schema , model .name ]
100- prod_qualified_list = [prod_database , prod_schema , model .name ]
120+ dev_qualified_list = [dev_database , dev_schema , model .alias ]
121+ prod_qualified_list = [prod_database , prod_schema , model .alias ]
101122
102123 return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection )
103124
@@ -119,9 +140,9 @@ def _local_diff(diff_vars: DiffVars) -> None:
119140 logging .info (ex )
120141 rich .print (
121142 "[red]"
122- + dev_qualified_string
123- + " <> "
124143 + prod_qualified_string
144+ + " <> "
145+ + dev_qualified_string
125146 + "[/] \n "
126147 + column_diffs_str
127148 + "[green]New model or no access to prod table.[/] \n "
@@ -133,10 +154,10 @@ def _local_diff(diff_vars: DiffVars) -> None:
133154 table2_set_diff = list (set (table2_columns ) - set (table1_columns ))
134155
135156 if table1_set_diff :
136- column_diffs_str += "Columns exclusive to table A : " + str (table1_set_diff ) + "\n "
157+ column_diffs_str += "Column(s) added : " + str (table1_set_diff ) + "\n "
137158
138159 if table2_set_diff :
139- column_diffs_str += "Columns exclusive to table B : " + str (table2_set_diff ) + "\n "
160+ column_diffs_str += "Column(s) removed : " + str (table2_set_diff ) + "\n "
140161
141162 mutual_set .discard (primary_key )
142163 extra_columns = tuple (mutual_set )
@@ -146,20 +167,20 @@ def _local_diff(diff_vars: DiffVars) -> None:
146167 if list (diff ):
147168 rich .print (
148169 "[red]"
149- + dev_qualified_string
150- + " <> "
151170 + prod_qualified_string
171+ + " <> "
172+ + dev_qualified_string
152173 + "[/] \n "
153174 + column_diffs_str
154- + diff .get_stats_string ()
175+ + diff .get_stats_string (is_dbt = True )
155176 + "\n "
156177 )
157178 else :
158179 rich .print (
159180 "[red]"
160- + dev_qualified_string
161- + " <> "
162181 + prod_qualified_string
182+ + " <> "
183+ + dev_qualified_string
163184 + "[/] \n "
164185 + column_diffs_str
165186 + "[green]No row differences[/] \n "
@@ -181,31 +202,62 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
181202 payload = {
182203 "data_source1_id" : diff_vars .datasource_id ,
183204 "data_source2_id" : diff_vars .datasource_id ,
184- "table1" : diff_vars .dev_path ,
185- "table2" : diff_vars .prod_path ,
205+ "table1" : diff_vars .prod_path ,
206+ "table2" : diff_vars .dev_path ,
186207 "pk_columns" : diff_vars .primary_keys ,
187208 }
188209
189210 headers = {
190211 "Authorization" : f"Key { api_key } " ,
191212 "Content-Type" : "application/json" ,
192213 }
214+ if is_tracking_enabled ():
215+ event_json = create_start_event_json ({"is_cloud" : True , "datasource_id" : diff_vars .datasource_id })
216+ run_as_daemon (send_event_json , event_json )
217+
218+ start = time .monotonic ()
219+ error = None
220+ diff_id = None
221+ try :
222+ response = requests .request ("POST" , url , headers = headers , json = payload , timeout = 30 )
223+ response .raise_for_status ()
224+ data = response .json ()
225+ diff_id = data ["id" ]
226+ # TODO in future we should support self hosted datafold
227+ diff_url = f"https://app.datafold.com/datadiffs/{ diff_id } /overview"
228+ rich .print (
229+ "[red]"
230+ + "." .join (diff_vars .prod_path )
231+ + " <> "
232+ + "." .join (diff_vars .dev_path )
233+ + "[/] \n Diff in progress: \n "
234+ + diff_url
235+ + "\n "
236+ )
237+ except BaseException as ex : # Catch KeyboardInterrupt too
238+ error = ex
239+ finally :
240+ # we don't currently have much of this information
241+ # but I imagine a future iteration of this _cloud method
242+ # will poll for results
243+ if is_tracking_enabled ():
244+ err_message = truncate_error (repr (error ))
245+ event_json = create_end_event_json (
246+ is_success = error is None ,
247+ runtime_seconds = time .monotonic () - start ,
248+ data_source_1_type = "" ,
249+ data_source_2_type = "" ,
250+ table1_count = 0 ,
251+ table2_count = 0 ,
252+ diff_count = 0 ,
253+ error = err_message ,
254+ diff_id = diff_id ,
255+ is_cloud = True ,
256+ )
257+ send_event_json (event_json )
193258
194- response = requests .request ("POST" , url , headers = headers , json = payload , timeout = 30 )
195- response .raise_for_status ()
196- data = response .json ()
197- diff_id = data ["id" ]
198- # TODO in future we should support self hosted datafold
199- diff_url = f"https://app.datafold.com/datadiffs/{ diff_id } /overview"
200- rich .print (
201- "[red]"
202- + "." .join (diff_vars .dev_path )
203- + " <> "
204- + "." .join (diff_vars .prod_path )
205- + "[/] \n Diff in progress: \n "
206- + diff_url
207- + "\n "
208- )
259+ if error :
260+ raise error
209261
210262
211263class DbtParser :
@@ -220,25 +272,26 @@ def __init__(self, profiles_dir_override: str, project_dir_override: str, is_clo
220272 self .project_dict = None
221273 self .requires_upper = False
222274
275+ self .parse_run_results , self .parse_manifest , self .ProfileRenderer , self .yaml = import_dbt ()
276+
223277 def get_datadiff_variables (self ) -> dict :
224278 return self .project_dict .get ("vars" ).get ("data_diff" )
225279
226280 def get_models (self ):
227281 with open (self .project_dir + RUN_RESULTS_PATH ) as run_results :
228282 run_results_dict = json .load (run_results )
229- run_results_obj = parse_run_results (run_results = run_results_dict )
283+ run_results_obj = self . parse_run_results (run_results = run_results_dict )
230284
231285 dbt_version = parse_version (run_results_obj .metadata .dbt_version )
232286
233- # TODO 1.4 support
234287 if dbt_version < parse_version (LOWER_DBT_V ) or dbt_version >= parse_version (UPPER_DBT_V ):
235288 raise Exception (
236289 f"Found dbt: v{ dbt_version } Expected the dbt project's version to be >= { LOWER_DBT_V } and < { UPPER_DBT_V } "
237290 )
238291
239292 with open (self .project_dir + MANIFEST_PATH ) as manifest :
240293 manifest_dict = json .load (manifest )
241- manifest_obj = parse_manifest (manifest = manifest_dict )
294+ manifest_obj = self . parse_manifest (manifest = manifest_dict )
242295
243296 success_models = [x .unique_id for x in run_results_obj .results if x .status .name == "success" ]
244297 models = [manifest_obj .nodes .get (x ) for x in success_models ]
@@ -253,20 +306,42 @@ def get_primary_keys(self, model):
253306
254307 def set_project_dict (self ):
255308 with open (self .project_dir + PROJECT_FILE ) as project :
256- self .project_dict = yaml .safe_load (project )
309+ self .project_dict = self . yaml .safe_load (project )
257310
258- def set_connection (self ):
259- with open (self .profiles_dir + PROFILES_FILE ) as profiles :
260- profiles = yaml .safe_load (profiles )
311+ def _get_connection_creds (self ) -> Tuple [Dict [str , str ], str ]:
312+ profiles_path = self .profiles_dir + PROFILES_FILE
313+ with open (profiles_path ) as profiles :
314+ profiles = self .yaml .safe_load (profiles )
261315
262316 dbt_profile = self .project_dict .get ("profile" )
263- profile_outputs = profiles .get (dbt_profile )
264- profile_target = profile_outputs .get ("target" )
265- credentials = profile_outputs .get ("outputs" ).get (profile_target )
266- conn_type = credentials .get ("type" ).lower ()
317+
318+ profile_outputs = get_from_dict_with_raise (
319+ profiles , dbt_profile , f"No profile '{ dbt_profile } ' found in '{ profiles_path } '."
320+ )
321+ profile_target = get_from_dict_with_raise (
322+ profile_outputs , "target" , f"No target found in profile '{ dbt_profile } ' in '{ profiles_path } '."
323+ )
324+ outputs = get_from_dict_with_raise (
325+ profile_outputs , "outputs" , f"No outputs found in profile '{ dbt_profile } ' in '{ profiles_path } '."
326+ )
327+ credentials = get_from_dict_with_raise (
328+ outputs ,
329+ profile_target ,
330+ f"No credentials found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
331+ )
332+ conn_type = get_from_dict_with_raise (
333+ credentials ,
334+ "type" ,
335+ f"No type found for target '{ profile_target } ' in profile '{ dbt_profile } ' in '{ profiles_path } '." ,
336+ )
337+ conn_type = conn_type .lower ()
267338
268339 # values can contain env_vars
269- rendered_credentials = ProfileRenderer ().render_data (credentials )
340+ rendered_credentials = self .ProfileRenderer ().render_data (credentials )
341+ return rendered_credentials , conn_type
342+
343+ def set_connection (self ):
344+ rendered_credentials , conn_type = self ._get_connection_creds ()
270345
271346 # this whole block should be refactored/extracted to method(s)
272347 if conn_type == "snowflake" :
0 commit comments