77from dataclasses import dataclass
88from packaging .version import parse as parse_version
99from typing import List , Optional , Dict
10+ from pathlib import Path
1011
1112import requests
1213from dbt_artifacts_parser .parser import parse_run_results , parse_manifest
2223from .utils import run_as_daemon , truncate_error
2324from . import connect_to_table , diff_tables , Algorithm
2425
25- RUN_RESULTS_PATH = "/ target/run_results.json"
26- MANIFEST_PATH = "/ target/manifest.json"
27- PROJECT_FILE = "/ dbt_project.yml"
28- PROFILES_FILE = "/ profiles.yml"
26+ RUN_RESULTS_PATH = "target/run_results.json"
27+ MANIFEST_PATH = "target/manifest.json"
28+ PROJECT_FILE = "dbt_project.yml"
29+ PROFILES_FILE = "profiles.yml"
2930LOWER_DBT_V = "1.0.0"
3031UPPER_DBT_V = "1.5.0"
3132
3233
34+ # https://github.com/dbt-labs/dbt-core/blob/c952d44ec5c2506995fbad75320acbae49125d3d/core/dbt/cli/resolvers.py#L6
35+ def default_project_dir () -> Path :
36+ paths = list (Path .cwd ().parents )
37+ paths .insert (0 , Path .cwd ())
38+ return next ((x for x in paths if (x / PROJECT_FILE ).exists ()), Path .cwd ())
39+
40+
41+ # https://github.com/dbt-labs/dbt-core/blob/c952d44ec5c2506995fbad75320acbae49125d3d/core/dbt/cli/resolvers.py#L12
42+ def default_profiles_dir () -> Path :
43+ return Path .cwd () if (Path .cwd () / PROFILES_FILE ).exists () else Path .home () / ".dbt"
44+
45+
3346@dataclass
3447class DiffVars :
3548 dev_path : List [str ]
@@ -252,12 +265,10 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
252265
253266
254267class DbtParser :
255- DEFAULT_PROFILES_DIR = os .path .expanduser ("~" ) + "/.dbt"
256- DEFAULT_PROJECT_DIR = os .getcwd ()
257268
258269 def __init__ (self , profiles_dir_override : str , project_dir_override : str , is_cloud : bool ) -> None :
259- self .profiles_dir = profiles_dir_override or self . DEFAULT_PROFILES_DIR
260- self .project_dir = project_dir_override or self . DEFAULT_PROJECT_DIR
270+ self .profiles_dir = Path ( profiles_dir_override or default_profiles_dir ())
271+ self .project_dir = Path ( project_dir_override or default_project_dir ())
261272 self .is_cloud = is_cloud
262273 self .connection = None
263274 self .project_dict = None
@@ -267,7 +278,7 @@ def get_datadiff_variables(self) -> dict:
267278 return self .project_dict .get ("vars" ).get ("data_diff" )
268279
269280 def get_models (self ):
270- with open (self .project_dir + RUN_RESULTS_PATH ) as run_results :
281+ with open (self .project_dir / RUN_RESULTS_PATH ) as run_results :
271282 run_results_dict = json .load (run_results )
272283 run_results_obj = parse_run_results (run_results = run_results_dict )
273284
@@ -278,7 +289,7 @@ def get_models(self):
278289 f"Found dbt: v{ dbt_version } Expected the dbt project's version to be >= { LOWER_DBT_V } and < { UPPER_DBT_V } "
279290 )
280291
281- with open (self .project_dir + MANIFEST_PATH ) as manifest :
292+ with open (self .project_dir / MANIFEST_PATH ) as manifest :
282293 manifest_dict = json .load (manifest )
283294 manifest_obj = parse_manifest (manifest = manifest_dict )
284295
@@ -294,11 +305,11 @@ def get_primary_keys(self, model):
294305 return list ((x .name for x in model .columns .values () if "primary-key" in x .tags ))
295306
296307 def set_project_dict (self ):
297- with open (self .project_dir + PROJECT_FILE ) as project :
308+ with open (self .project_dir / PROJECT_FILE ) as project :
298309 self .project_dict = yaml .safe_load (project )
299310
300311 def set_connection (self ):
301- with open (self .profiles_dir + PROFILES_FILE ) as profiles :
312+ with open (self .profiles_dir / PROFILES_FILE ) as profiles :
302313 profiles = yaml .safe_load (profiles )
303314
304315 dbt_profile = self .project_dict .get ("profile" )
0 commit comments