+import json
 import time
 from typing import List, Optional, Union, overload
 
@@ -50,14 +51,33 @@ def _validate_temp_schema(temp_schema: str):
         raise ValueError("Temporary schema should have a format <database>.<schema>")
 
 
+def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]:
+    diff_vars = dbt_parser.get_datadiff_variables()
+    config_prod_database = diff_vars.get("prod_database")
+    config_prod_schema = diff_vars.get("prod_schema")
+    if config_prod_database is not None and config_prod_schema is not None:
+        temp_schema = f"{config_prod_database}.{config_prod_schema}"
+        if db_type == "snowflake":
+            return temp_schema.upper()
+        elif db_type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}:
+            return temp_schema.lower()
+        return temp_schema
+    return None
+
+
 def create_ds_config(
     ds_config: TCloudApiDataSourceConfigSchema,
     data_source_name: str,
     dbt_parser: Optional[DbtParser] = None,
 ) -> TDsConfig:
     options = _parse_ds_credentials(ds_config=ds_config, only_basic_settings=True, dbt_parser=dbt_parser)
 
-    temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
+    temp_schema = _get_temp_schema(dbt_parser=dbt_parser, db_type=ds_config.db_type) if dbt_parser else None
+    if temp_schema:
+        temp_schema = TemporarySchemaPrompt.ask("Temporary schema", default=temp_schema)
+    else:
+        temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
+
     float_tolerance = FloatPrompt.ask("Float tolerance", default=0.000001)
 
     return TDsConfig(
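
The `_get_temp_schema` helper above derives a default temporary schema from the `prod_database` and `prod_schema` dbt variables, normalizing identifier case per database, and `create_ds_config` then offers it as the prompt default so a bare Enter accepts it. A minimal, self-contained sketch of the same casing rules, using a hypothetical `FakeDbtParser` stub in place of the real `DbtParser`:

```python
from typing import Optional


class FakeDbtParser:
    """Hypothetical stand-in for DbtParser; only mimics get_datadiff_variables()."""

    def get_datadiff_variables(self) -> dict:
        return {"prod_database": "Analytics", "prod_schema": "Prod"}


def get_temp_schema(parser: FakeDbtParser, db_type: str) -> Optional[str]:
    # Same normalization rules as the new helper: Snowflake identifiers are
    # upper-cased; Postgres-family and Redshift identifiers are lower-cased.
    diff_vars = parser.get_datadiff_variables()
    database = diff_vars.get("prod_database")
    schema = diff_vars.get("prod_schema")
    if database is None or schema is None:
        return None
    temp_schema = f"{database}.{schema}"
    if db_type == "snowflake":
        return temp_schema.upper()
    if db_type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}:
        return temp_schema.lower()
    return temp_schema


print(get_temp_schema(FakeDbtParser(), "snowflake"))  # ANALYTICS.PROD
print(get_temp_schema(FakeDbtParser(), "redshift"))   # analytics.prod
print(get_temp_schema(FakeDbtParser(), "bigquery"))   # Analytics.Prod (unchanged)
```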
@@ -92,6 +112,37 @@ def _cast_value(value: str, type_: str) -> Union[bool, int, str]:
     return value
 
 
+def _get_data_from_bigquery_json(path: str):
+    with open(path, "r") as file:
+        return json.load(file)
+
+
+def _align_dbt_cred_params_with_datafold_params(dbt_creds: dict) -> dict:
+    db_type = dbt_creds["type"]
+    if db_type == "bigquery":
+        method = dbt_creds["method"]
+        if method == "service-account":
+            data = _get_data_from_bigquery_json(path=dbt_creds["keyfile"])
+            dbt_creds["jsonKeyFile"] = json.dumps(data)
+        elif method == "service-account-json":
+            dbt_creds["jsonKeyFile"] = json.dumps(dbt_creds["keyfile_json"])
+        else:
+            rich.print(
+                f'[red]Cannot extract BigQuery credentials from profiles.yml for the "{method}" method. '
+                'If you want to provide credentials via profiles.yml, '
+                'please use "service-account" or "service-account-json" '
+                "(more in the dbt docs: https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup). "
+                "Otherwise, you can provide a path to a JSON key file or the JSON key data itself as input."
+            )
+        dbt_creds["projectId"] = dbt_creds["project"]
+    elif db_type == "snowflake":
+        dbt_creds["default_db"] = dbt_creds["database"]
+    elif db_type == "databricks":
+        dbt_creds["http_password"] = dbt_creds["token"]
+        dbt_creds["database"] = dbt_creds.get("catalog")
+    return dbt_creds
+
+
 def _parse_ds_credentials(
     ds_config: TCloudApiDataSourceConfigSchema, only_basic_settings: bool = True, dbt_parser: Optional[DbtParser] = None
 ):
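
`_align_dbt_cred_params_with_datafold_params` bridges the naming gap between dbt's `profiles.yml` fields and the parameter names the Datafold data-source form expects, and for BigQuery it also normalizes both key-delivery methods into a single JSON string. A small illustration with hypothetical credential dicts (all values are made up):

```python
import json

# Hypothetical Snowflake creds as dbt's profiles.yml would yield them:
# Datafold expects the database under "default_db".
snowflake_creds = {"type": "snowflake", "account": "org-acct", "database": "ANALYTICS"}
snowflake_creds["default_db"] = snowflake_creds["database"]

# Hypothetical Databricks creds: the token doubles as the HTTP password,
# and the catalog maps onto Datafold's "database" field.
databricks_creds = {"type": "databricks", "token": "dapi-example-token", "catalog": "main"}
databricks_creds["http_password"] = databricks_creds["token"]
databricks_creds["database"] = databricks_creds.get("catalog")

# For BigQuery, both key-delivery methods end up as a JSON *string* under
# "jsonKeyFile" (hence json.dumps): "service-account" reads the keyfile from
# disk, while "service-account-json" already has the parsed key inline.
bigquery_creds = {
    "type": "bigquery",
    "method": "service-account-json",
    "project": "my-project",  # copied to "projectId" for Datafold
    "keyfile_json": {"type": "service_account", "project_id": "my-project"},
}
bigquery_creds["jsonKeyFile"] = json.dumps(bigquery_creds["keyfile_json"])
bigquery_creds["projectId"] = bigquery_creds["project"]
```

Note the function only adds the Datafold-flavored aliases and leaves the original dbt keys in place, so downstream code can read whichever name it knows.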
@@ -101,6 +152,7 @@ def _parse_ds_credentials(
         use_dbt_data = Confirm.ask("Would you like to extract database credentials from dbt profiles.yml?")
         try:
             creds = dbt_parser.get_connection_creds()[0]
+            creds = _align_dbt_cred_params_with_datafold_params(dbt_creds=creds)
         except Exception as e:
             rich.print(f"[red]Cannot parse database credentials from dbt profiles.yml. Reason: {e}")
 
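
The single added line above threads dbt-sourced credentials through the alignment step before they pre-fill the credential prompts. A condensed sketch of the surrounding control flow as a hypothetical helper; the `creds = {}` fallback after a failed parse is an assumption based on the except clause shown, and `DbtParser` / `_align_dbt_cred_params_with_datafold_params` come from this module:

```python
from typing import Optional

import rich
from rich.prompt import Confirm


def _collect_dbt_creds(dbt_parser: Optional["DbtParser"]) -> dict:
    """Hypothetical condensation of the dbt branch in _parse_ds_credentials."""
    creds: dict = {}
    if dbt_parser is None:
        return creds
    if not Confirm.ask("Would you like to extract database credentials from dbt profiles.yml?"):
        return creds
    try:
        creds = dbt_parser.get_connection_creds()[0]
        # The new line: rename dbt's field names to the ones Datafold expects
        # before the values are used to pre-fill the credential prompts.
        creds = _align_dbt_cred_params_with_datafold_params(dbt_creds=creds)
    except Exception as e:
        rich.print(f"[red]Cannot parse database credentials from dbt profiles.yml. Reason: {e}")
        creds = {}  # assumed: fall back to prompting for every field
    return creds
```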