99from dataclasses import dataclass
1010from packaging .version import parse as parse_version
1111from typing import List , Optional , Dict , Tuple , Set
12- from .utils import getLogger
12+ from .utils import dbt_diff_string_template , getLogger
1313from .version import __version__
1414from pathlib import Path
1515
@@ -69,16 +69,16 @@ class DiffVars:
6969 dev_path : List [str ]
7070 prod_path : List [str ]
7171 primary_keys : List [str ]
72- datasource_id : str
7372 connection : Dict [str , str ]
7473 threads : Optional [int ]
7574
7675
7776def dbt_diff (
7877 profiles_dir_override : Optional [str ] = None , project_dir_override : Optional [str ] = None , is_cloud : bool = False
7978) -> None :
79+ diff_threads = []
8080 set_entrypoint_name ("CLI-dbt" )
81- dbt_parser = DbtParser (profiles_dir_override , project_dir_override , is_cloud )
81+ dbt_parser = DbtParser (profiles_dir_override , project_dir_override )
8282 models = dbt_parser .get_models ()
8383 datadiff_variables = dbt_parser .get_datadiff_variables ()
8484 config_prod_database = datadiff_variables .get ("prod_database" )
@@ -89,7 +89,17 @@ def dbt_diff(
8989 custom_schemas = True if custom_schemas is None else custom_schemas
9090 set_dbt_user_id (dbt_parser .dbt_user_id )
9191
92- if not is_cloud :
92+ if is_cloud :
93+ if datasource_id is None :
94+ raise ValueError (
95+ "Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \n vars:\n data_diff:\n datasource_id: 1234"
96+ )
97+ datafold_host , url , api_key = _setup_cloud_diff ()
98+
99+ # exit so the user can set the key
100+ if not api_key :
101+ return
102+ else :
93103 dbt_parser .set_connection ()
94104
95105 if config_prod_database is None :
@@ -98,14 +108,14 @@ def dbt_diff(
98108 )
99109
100110 for model in models :
101- diff_vars = _get_diff_vars (
102- dbt_parser , config_prod_database , config_prod_schema , model , datasource_id , custom_schemas
103- )
104-
105- if is_cloud and len ( diff_vars . primary_keys ) > 0 :
106- _cloud_diff ( diff_vars )
107- elif not is_cloud and len ( diff_vars . primary_keys ) > 0 :
108- _local_diff (diff_vars )
111+ diff_vars = _get_diff_vars (dbt_parser , config_prod_database , config_prod_schema , model , custom_schemas )
112+
113+ if diff_vars . primary_keys :
114+ if is_cloud :
115+ diff_thread = run_as_daemon ( _cloud_diff , diff_vars , datasource_id , datafold_host , url , api_key )
116+ diff_threads . append ( diff_thread )
117+ else :
118+ _local_diff (diff_vars )
109119 else :
110120 rich .print (
111121 "[red]"
@@ -116,6 +126,11 @@ def dbt_diff(
116126 + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
117127 )
118128
129+ # wait for all threads
130+ if diff_threads :
131+ for thread in diff_threads :
132+ thread .join ()
133+
119134 rich .print ("Diffs Complete!" )
120135
121136
@@ -124,7 +139,6 @@ def _get_diff_vars(
124139 config_prod_database : Optional [str ],
125140 config_prod_schema : Optional [str ],
126141 model ,
127- datasource_id : int ,
128142 custom_schemas : bool ,
129143) -> DiffVars :
130144 dev_database = model .database
@@ -149,9 +163,7 @@ def _get_diff_vars(
149163 dev_qualified_list = [dev_database , dev_schema , model .alias ]
150164 prod_qualified_list = [prod_database , prod_schema , model .alias ]
151165
152- return DiffVars (
153- dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection , dbt_parser .threads
154- )
166+ return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , dbt_parser .connection , dbt_parser .threads )
155167
156168
157169def _local_diff (diff_vars : DiffVars ) -> None :
@@ -221,33 +233,10 @@ def _local_diff(diff_vars: DiffVars) -> None:
221233 )
222234
223235
224- def _cloud_diff (diff_vars : DiffVars ) -> None :
225- datafold_host = os .environ .get ("DATAFOLD_HOST" )
226- if datafold_host is None :
227- datafold_host = "https://app.datafold.com"
228- datafold_host = datafold_host .rstrip ("/" )
229- rich .print (f"Cloud datafold host: { datafold_host } " )
230-
231- api_key = os .environ .get ("DATAFOLD_API_KEY" )
232- if not api_key :
233- rich .print ("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY." )
234- yes_or_no = Confirm .ask ("Would you like to generate a new API key?" )
235- if yes_or_no :
236- webbrowser .open (f"{ datafold_host } /login?next={ datafold_host } /users/me" )
237- return
238- else :
239- raise ValueError ("Cannot diff because the API key is not provided" )
240-
241- if diff_vars .datasource_id is None :
242- raise ValueError (
243- "Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \n vars:\n data_diff:\n datasource_id: 1234"
244- )
245-
246- url = f"{ datafold_host } /api/v1/datadiffs"
247-
236+ def _cloud_diff (diff_vars : DiffVars , datasource_id : int , datafold_host : str , url : str , api_key : str ) -> None :
248237 payload = {
249- "data_source1_id" : diff_vars . datasource_id ,
250- "data_source2_id" : diff_vars . datasource_id ,
238+ "data_source1_id" : datasource_id ,
239+ "data_source2_id" : datasource_id ,
251240 "table1" : diff_vars .prod_path ,
252241 "table2" : diff_vars .dev_path ,
253242 "pk_columns" : diff_vars .primary_keys ,
@@ -258,27 +247,60 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
258247 "Content-Type" : "application/json" ,
259248 }
260249 if is_tracking_enabled ():
261- event_json = create_start_event_json ({"is_cloud" : True , "datasource_id" : diff_vars . datasource_id })
250+ event_json = create_start_event_json ({"is_cloud" : True , "datasource_id" : datasource_id })
262251 run_as_daemon (send_event_json , event_json )
263252
264253 start = time .monotonic ()
265254 error = None
266255 diff_id = None
256+ diff_url = None
267257 try :
268- response = requests . request ( "POST" , url , headers = headers , json = payload , timeout = 30 )
269- response . raise_for_status ()
270- data = response . json ( )
271- diff_id = data [ "id" ]
258+ diff_id = _cloud_submit_diff ( url , payload , headers )
259+ summary_url = f" { url } / { diff_id } /summary_results"
260+ diff_results = _cloud_poll_and_get_summary_results ( summary_url , headers )
261+
272262 diff_url = f"{ datafold_host } /datadiffs/{ diff_id } /overview"
273- rich .print (
274- "[red]"
275- + "." .join (diff_vars .prod_path )
276- + " <> "
277- + "." .join (diff_vars .dev_path )
278- + "[/] \n Diff in progress: \n "
279- + diff_url
280- + "\n "
281- )
263+
264+ rows_added_count = diff_results ["pks" ]["exclusives" ][1 ]
265+ rows_removed_count = diff_results ["pks" ]["exclusives" ][0 ]
266+
267+ rows_updated = diff_results ["values" ]["rows_with_differences" ]
268+ total_rows = diff_results ["values" ]["total_rows" ]
269+ rows_unchanged = int (total_rows ) - int (rows_updated )
270+ diff_percent_list = {
271+ x ["column_name" ]: str (x ["match" ]) + "%"
272+ for x in diff_results ["values" ]["columns_diff_stats" ]
273+ if x ["match" ] != 100.0
274+ }
275+
276+ if any ([rows_added_count , rows_removed_count , rows_updated ]):
277+ diff_output = dbt_diff_string_template (
278+ rows_added_count ,
279+ rows_removed_count ,
280+ rows_updated ,
281+ str (rows_unchanged ),
282+ diff_percent_list ,
283+ "Value Match Percent:" ,
284+ )
285+ rich .print (
286+ "[red]"
287+ + "." .join (diff_vars .prod_path )
288+ + " <> "
289+ + "." .join (diff_vars .dev_path )
290+ + f"[/]\n { diff_url } \n "
291+ + diff_output
292+ + "\n "
293+ )
294+ else :
295+ rich .print (
296+ "[red]"
297+ + "." .join (diff_vars .prod_path )
298+ + " <> "
299+ + "." .join (diff_vars .dev_path )
300+ + f"[/]\n { diff_url } \n "
301+ + "[green]No row differences[/] \n "
302+ )
303+
282304 except BaseException as ex : # Catch KeyboardInterrupt too
283305 error = ex
284306 finally :
@@ -302,15 +324,81 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
302324 send_event_json (event_json )
303325
304326 if error :
305- raise error
327+ rich .print (
328+ "[red]"
329+ + "." .join (diff_vars .prod_path )
330+ + " <> "
331+ + "." .join (diff_vars .dev_path ) + "[/]\n "
332+ )
333+ if diff_id :
334+ diff_url = f"{ datafold_host } /datadiffs/{ diff_id } /overview"
335+ rich .print (f"{ diff_url } \n " )
336+ logger .error (error )
337+
338+
def _setup_cloud_diff() -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Resolve the Datafold host, the diff API endpoint and the API key.

    Reads ``DATAFOLD_HOST`` (falling back to ``https://app.datafold.com`` when
    it is unset or empty) and ``DATAFOLD_API_KEY`` from the environment.

    Returns:
        ``(datafold_host, url, api_key)`` when an API key is available, or
        ``(None, None, None)`` when the key is missing and the user agreed to
        generate one in the browser (the caller is expected to stop then).

    Raises:
        ValueError: if the API key is missing and the user declines to
            generate a new one.
    """
    # `or` instead of an `is None` check: an empty DATAFOLD_HOST would
    # otherwise produce an empty host and a schemeless, unusable API URL.
    datafold_host = (os.environ.get("DATAFOLD_HOST") or "https://app.datafold.com").rstrip("/")
    rich.print(f"Cloud datafold host: {datafold_host}\n")
    url = f"{datafold_host}/api/v1/datadiffs"

    api_key = os.environ.get("DATAFOLD_API_KEY")
    if api_key:
        return datafold_host, url, api_key

    rich.print("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
    if Confirm.ask("Would you like to generate a new API key?"):
        # Send the user to their profile page, where a key can be created.
        webbrowser.open(f"{datafold_host}/login?next={datafold_host}/users/me")
        return None, None, None
    raise ValueError("Cannot diff because the API key is not provided")
358+
359+
def _cloud_submit_diff(url, payload, headers) -> str:
    """Submit a diff request and return the id of the created diff.

    Args:
        url: Endpoint the diff request is POSTed to.
        payload: JSON-serialisable request body.
        headers: HTTP headers sent with the request.

    Returns:
        The ``id`` field of the JSON response, converted to ``str``.

    Raises:
        Exception: if the response JSON carries no ``id`` (missing or null).
        Whatever ``response.raise_for_status()`` raises for an HTTP error
        status is propagated unchanged.
    """
    response = requests.request("POST", url, headers=headers, json=payload, timeout=30)
    response.raise_for_status()
    response_json = response.json()

    # Validate the raw value *before* converting: str(None) is the string
    # "None", so a None-check performed after str() can never trigger.
    raw_diff_id = response_json.get("id")
    if raw_diff_id is None:
        raise Exception(f"Api response did not contain a diff_id: {str(response_json)}")
    return str(raw_diff_id)
369+
370+
def _cloud_poll_and_get_summary_results(url, headers):
    """Poll the summary-results endpoint until the diff succeeds or fails.

    The endpoint is queried repeatedly; between attempts the function sleeps,
    starting at 5 seconds and doubling each time up to a 60 second cap.

    Args:
        url: Summary-results endpoint to GET.
        headers: HTTP headers sent with every request.

    Returns:
        The decoded JSON response whose ``status`` is ``"success"``.

    Raises:
        Exception: if the response ``status`` is ``"failed"``, or if more than
            300 seconds elapse without a successful result.
        Whatever ``response.raise_for_status()`` raises for an HTTP error
        status is propagated unchanged.
    """
    # Monotonic clock: elapsed-time measurement must not be affected by
    # system clock adjustments.
    start_time = time.monotonic()
    sleep_interval = 5  # starts at 5 sec
    max_sleep_interval = 60
    max_wait_time = 300

    while True:
        response = requests.request("GET", url, headers=headers, timeout=30)
        response.raise_for_status()
        response_json = response.json()

        status = response_json["status"]
        if status == "success":
            # Return right away: no timeout check and no extra sleep once
            # the results are already in hand.
            return response_json
        if status == "failed":
            raise Exception(f"Diff failed: {str(response_json)}")

        if time.monotonic() - start_time > max_wait_time:
            raise Exception("Timed out waiting for diff results")

        time.sleep(sleep_interval)
        sleep_interval = min(sleep_interval * 2, max_sleep_interval)
306395
307396
308397class DbtParser :
309- def __init__ (self , profiles_dir_override : str , project_dir_override : str , is_cloud : bool ) -> None :
398+ def __init__ (self , profiles_dir_override : str , project_dir_override : str ) -> None :
310399 self .parse_run_results , self .parse_manifest , self .ProfileRenderer , self .yaml = import_dbt ()
311400 self .profiles_dir = Path (profiles_dir_override or default_profiles_dir ())
312401 self .project_dir = Path (project_dir_override or default_project_dir ())
313- self .is_cloud = is_cloud
314402 self .connection = None
315403 self .project_dict = self .get_project_dict ()
316404 self .manifest_obj = self .get_manifest_obj ()
0 commit comments