-
Notifications
You must be signed in to change notification settings - Fork 377
Expand file tree
/
Copy pathloader.py
More file actions
406 lines (335 loc) · 15.1 KB
/
loader.py
File metadata and controls
406 lines (335 loc) · 15.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
from __future__ import annotations
import logging
import sys
import typing as t
import sqlmesh.core.dialect as d
from pathlib import Path
from sqlmesh.core.config import (
Config,
ConnectionConfig,
GatewayConfig,
ModelDefaultsConfig,
)
from sqlmesh.core.environment import EnvironmentStatements
from sqlmesh.core.loader import CacheBase, LoadedProject, Loader
from sqlmesh.core.macros import MacroRegistry, macro
from sqlmesh.core.model import Model, ModelCache
from sqlmesh.core.signal import signal
from sqlmesh.dbt.basemodel import BMC, BaseModelConfig
from sqlmesh.dbt.common import Dependencies
from sqlmesh.dbt.context import DbtContext
from sqlmesh.dbt.profile import Profile
from sqlmesh.dbt.project import Project
from sqlmesh.dbt.target import TargetConfig
from sqlmesh.utils import UniqueKeyDict
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.jinja import (
JinjaMacroRegistry,
make_jinja_registry,
)
if sys.version_info >= (3, 12):
from importlib import metadata
else:
import importlib_metadata as metadata # type: ignore
if t.TYPE_CHECKING:
from sqlmesh.core.audit import Audit, ModelAudit
from sqlmesh.core.context import GenericContext
logger = logging.getLogger(__name__)
def sqlmesh_config(
    project_root: t.Optional[Path] = None,
    state_connection: t.Optional[ConnectionConfig] = None,
    dbt_profile_name: t.Optional[str] = None,
    dbt_target_name: t.Optional[str] = None,
    variables: t.Optional[t.Dict[str, t.Any]] = None,
    register_comments: t.Optional[bool] = None,
    **kwargs: t.Any,
) -> Config:
    """Build a SQLMesh ``Config`` from a dbt project's profile and target.

    Args:
        project_root: Root of the dbt project; defaults to the current directory.
        state_connection: Optional connection to use for SQLMesh state.
        dbt_profile_name: dbt profile to load; defaults to the project's profile.
        dbt_target_name: dbt target within the profile; defaults to the profile's default.
        variables: SQLMesh variables forwarded to the config.
        register_comments: When set, overrides comment registration on the connection.
        kwargs: Extra ``Config`` fields; ``model_defaults``, ``loader``, and
            ``gateways`` are honored if provided.

    Raises:
        ConfigError: If a custom ``loader`` is supplied that is not a ``DbtLoader`` subclass.
    """
    root = project_root or Path()
    dbt_context = DbtContext(project_root=root, profile_name=dbt_profile_name)
    profile = Profile.load(dbt_context, target_name=dbt_target_name)

    # Fall back to the dbt target's dialect when the caller didn't pick one.
    model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig())
    if model_defaults.dialect is None:
        model_defaults.dialect = profile.target.dialect

    connection_overrides: t.Dict[str, t.Any] = {}
    if register_comments is not None:
        connection_overrides["register_comments"] = register_comments

    loader = kwargs.pop("loader", DbtLoader)
    if not issubclass(loader, DbtLoader):
        raise ConfigError("The loader must be a DbtLoader.")

    # Derive gateway defaults from the dbt target, then let explicit kwargs
    # override them (an explicit "gateways" also clears the default gateway).
    gateway_fields: t.Dict[str, t.Any] = {
        "default_gateway": "" if "gateways" in kwargs else profile.target_name,
        "gateways": {
            profile.target_name: GatewayConfig(
                connection=profile.target.to_sqlmesh(**connection_overrides),
                state_connection=state_connection,
            )
        },
    }
    gateway_fields.update(kwargs)

    return Config(
        loader=loader,
        model_defaults=model_defaults,
        variables=variables or {},
        **gateway_fields,  # type: ignore
    )
class DbtLoader(Loader):
    """SQLMesh loader that sources models, audits, and hooks from a dbt project.

    Converts dbt assets (models, seeds, tests, on-run hooks, macros) into their
    SQLMesh equivalents, caching converted models on disk per dbt target.
    """

    def __init__(self, context: GenericContext, path: Path) -> None:
        # Loaded dbt projects; populated lazily by _load_projects().
        self._projects: t.List[Project] = []
        # Max mtime across tracked macro files, used for cache invalidation;
        # None when no macro mtimes are known.
        self._macros_max_mtime: t.Optional[float] = None
        super().__init__(context, path)

    def load(self) -> LoadedProject:
        """Load the project from scratch, discarding any previously loaded dbt projects."""
        self._projects = []
        return super().load()

    def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
        """Return SQLMesh's macro registry plus the merged jinja registries of all projects."""
        # Track macro files so change detection / cache invalidation sees them.
        macro_files = list(Path(self.config_path, "macros").glob("**/*.sql"))
        for file in macro_files:
            self._track_file(file)

        jinja_macros = JinjaMacroRegistry()
        for project in self._load_projects():
            jinja_macros = jinja_macros.merge(project.context.jinja_macros)
            jinja_macros.add_globals(project.context.jinja_globals)
        return (macro.get_registry(), jinja_macros)

    def _load_models(
        self,
        macros: MacroRegistry,
        jinja_macros: JinjaMacroRegistry,
        gateway: t.Optional[str],
        audits: UniqueKeyDict[str, ModelAudit],
        signals: UniqueKeyDict[str, signal],
    ) -> UniqueKeyDict[str, Model]:
        """Convert all dbt models and seeds into SQLMesh models, using the on-disk cache."""
        models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")

        def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
            logger.debug("Converting '%s' to sqlmesh format", config.canonical_name(context))
            return config.to_sqlmesh(
                context,
                audit_definitions=audits,
                virtual_environment_mode=self.config.virtual_environment_mode,
            )

        for project in self._load_projects():
            context = project.context.copy()

            macros_max_mtime = self._macros_max_mtime
            yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder(
                project.context.project_root
            )
            cache = DbtLoader._Cache(self, project, macros_max_mtime, yaml_max_mtimes)

            logger.debug("Converting models to sqlmesh")
            # Now that config is rendered, create the sqlmesh models
            for package in project.packages.values():
                context.set_and_render_variables(package.variables, package.name)
                package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds}
                for model in package_models.values():
                    # Bind loop variables as lambda defaults so the closure stays
                    # correct even if the cache ever defers invoking the loader.
                    sqlmesh_model = cache.get_or_load_models(
                        model.path,
                        loader=lambda model=model, context=context: [_to_sqlmesh(model, context)],
                    )[0]
                    models[sqlmesh_model.fqn] = sqlmesh_model

        # NOTE(review): relies on `cache` from the last (typically only) project
        # iterated above — confirm a project always exists when this runs.
        models.update(self._load_external_models(audits, cache))

        return models

    def _load_audits(
        self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry
    ) -> UniqueKeyDict[str, Audit]:
        """Convert dbt tests into SQLMesh audits."""
        audits: UniqueKeyDict = UniqueKeyDict("audits")

        for project in self._load_projects():
            context = project.context

            logger.debug("Converting audits to sqlmesh")
            for package in project.packages.values():
                context.set_and_render_variables(package.variables, package.name)
                for test in package.tests.values():
                    logger.debug("Converting '%s' to sqlmesh format", test.name)
                    audits[test.name] = test.to_sqlmesh(context)

        return audits

    def _load_projects(self) -> t.List[Project]:
        """Load (once) and return the dbt project for the selected gateway's target."""
        if not self._projects:
            target_name = self.context.selected_gateway

            self._projects = []

            project = Project.load(
                DbtContext(
                    project_root=self.config_path,
                    target_name=target_name,
                    sqlmesh_config=self.config,
                ),
                variables=self.config.variables,
            )
            self._projects.append(project)

            if project.context.target.database != (self.context.default_catalog or ""):
                raise ConfigError("Project default catalog does not match context default catalog")

            # Track every project file so reloads notice changes.
            for path in project.project_files:
                self._track_file(path)

            context = project.context

            macros_mtimes: t.List[float] = []

            for package_name, package in project.packages.items():
                context.add_sources(package.sources)
                context.add_seeds(package.seeds)
                context.add_models(package.models)
                macros_mtimes.extend(
                    [
                        self._path_mtimes[m.path]
                        for m in package.macros.values()
                        if m.path in self._path_mtimes
                    ]
                )

            for package_name, macro_infos in context.manifest.all_macros.items():
                context.add_macros(macro_infos, package=package_name)

            self._macros_max_mtime = max(macros_mtimes) if macros_mtimes else None
        return self._projects

    def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
        """Add dbt-core and each project's adapter package to the pinned requirements."""
        requirements, excluded_requirements = super()._load_requirements()

        target_packages = ["dbt-core"]
        for project in self._load_projects():
            target_packages.append(f"dbt-{project.context.target.type}")

        for target_package in target_packages:
            # Respect explicit pins/exclusions already present.
            if target_package in requirements or target_package in excluded_requirements:
                continue
            try:
                requirements[target_package] = metadata.version(target_package)
            except metadata.PackageNotFoundError:
                from sqlmesh.core.console import get_console

                get_console().log_warning(f"dbt package {target_package} is not installed.")

        return requirements, excluded_requirements

    def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
        """Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""
        hooks_by_package_name: t.Dict[str, EnvironmentStatements] = {}
        project_names: t.Set[str] = set()
        dialect = self.config.dialect

        for project in self._load_projects():
            context = project.context
            for package_name, package in project.packages.items():
                context.set_and_render_variables(package.variables, package_name)
                # Hooks execute in their declared index order.
                on_run_start: t.List[str] = [
                    on_run_hook.sql
                    for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index)
                ]
                on_run_end: t.List[str] = [
                    on_run_hook.sql
                    for on_run_hook in sorted(package.on_run_end.values(), key=lambda h: h.index)
                ]
                if on_run_start or on_run_end:
                    # Narrow the jinja context to what the hooks actually reference.
                    dependencies = Dependencies()
                    for hook in [*package.on_run_start.values(), *package.on_run_end.values()]:
                        dependencies = dependencies.union(hook.dependencies)

                    statements_context = context.context_for_dependencies(dependencies)
                    jinja_registry = make_jinja_registry(
                        statements_context.jinja_macros, package_name, set(dependencies.macros)
                    )
                    jinja_registry.add_globals(statements_context.jinja_globals)

                    hooks_by_package_name[package_name] = EnvironmentStatements(
                        before_all=[
                            d.jinja_statement(stmt).sql(dialect=dialect)
                            for stmt in on_run_start or []
                        ],
                        after_all=[
                            d.jinja_statement(stmt).sql(dialect=dialect)
                            for stmt in on_run_end or []
                        ],
                        python_env={},
                        jinja_macros=jinja_registry,
                        project=package_name,
                    )
                    project_names.add(package_name)

        # Root-project statements first, dependency packages after.
        return [
            statements
            for _, statements in sorted(
                hooks_by_package_name.items(),
                key=lambda item: 0 if item[0] in project_names else 1,
            )
        ]

    def _compute_yaml_max_mtime_per_subfolder(
        self, root: Path, visited: t.Optional[t.Set[Path]] = None
    ) -> t.Dict[Path, float]:
        """Recursively map each directory under ``root`` to the max mtime of its YAML files.

        ``visited`` guards against revisiting directories (e.g. via symlink cycles).
        Directories without tracked YAML files are omitted from the result.
        """
        root = root.resolve()
        visited = visited or set()
        if not root.is_dir() or root in visited:
            return {}
        visited.add(root)

        result = {}
        max_mtime: t.Optional[float] = None

        for nested in root.iterdir():
            try:
                if nested.is_dir():
                    result.update(
                        self._compute_yaml_max_mtime_per_subfolder(nested, visited=visited)
                    )
                elif nested.suffix.lower() in (".yaml", ".yml"):
                    # Only files already tracked in _path_mtimes contribute.
                    yaml_mtime = self._path_mtimes.get(nested)
                    if yaml_mtime:
                        max_mtime = (
                            max(max_mtime, yaml_mtime) if max_mtime is not None else yaml_mtime
                        )
            except PermissionError:
                # Best-effort: unreadable entries are simply skipped.
                pass

        if max_mtime is not None:
            result[root] = max_mtime
        return result

    class _Cache(CacheBase):
        """On-disk model cache keyed by project-relative path and content mtimes."""

        MAX_ENTRY_NAME_LENGTH = 200

        def __init__(
            self,
            loader: DbtLoader,
            project: Project,
            macros_max_mtime: t.Optional[float],
            yaml_max_mtimes: t.Dict[Path, float],
        ):
            self._loader = loader
            self._project = project
            self._macros_max_mtime = macros_max_mtime
            self._yaml_max_mtimes = yaml_max_mtimes

            target = t.cast(TargetConfig, project.context.target)
            # Separate cache directory per dbt target name.
            cache_dir = loader.context.cache_dir / target.name
            self._model_cache = ModelCache(cache_dir)

        def get_or_load_models(
            self, target_path: Path, loader: t.Callable[[], t.List[Model]]
        ) -> t.List[Model]:
            """Return cached models for ``target_path``, invoking ``loader`` on a miss."""
            models = self._model_cache.get_or_load(
                self._cache_entry_name(target_path),
                self._cache_entry_id(target_path),
                loader=loader,
            )
            # Cached models may carry stale paths; rebind to the requested one.
            for model in models:
                model._path = target_path

            return models

        def put(self, models: t.List[Model], path: Path) -> bool:
            """Store ``models`` in the cache under ``path``'s entry."""
            return self._model_cache.put(
                models,
                self._cache_entry_name(path),
                self._cache_entry_id(path),
            )

        def get(self, path: Path) -> t.List[Model]:
            """Fetch cached models for ``path`` (empty on miss)."""
            return self._model_cache.get(
                self._cache_entry_name(path),
                self._cache_entry_id(path),
            )

        def _cache_entry_name(self, target_path: Path) -> str:
            """Derive a filesystem-safe entry name from the project-relative path."""
            try:
                path_for_name = target_path.absolute().relative_to(
                    self._project.context.project_root.absolute()
                )
            except ValueError:
                # Path is outside the project root; fall back to the raw path.
                path_for_name = target_path
            # NOTE(review): str.replace removes the suffix text anywhere it occurs,
            # not just at the end — confirm removesuffix() wasn't intended.
            name = "__".join(path_for_name.parts).replace(path_for_name.suffix, "")
            if len(name) > self.MAX_ENTRY_NAME_LENGTH:
                # Keep the tail, which carries the most distinguishing parts.
                return name[len(name) - self.MAX_ENTRY_NAME_LENGTH :]
            return name

        def _cache_entry_id(self, target_path: Path) -> str:
            """Build a cache-busting id from the max relevant mtime and config fingerprint."""
            max_mtime = self._max_mtime_for_path(target_path)
            return "__".join(
                [
                    str(int(max_mtime)) if max_mtime is not None else "na",
                    self._loader.config.fingerprint,
                ]
            )

        def _max_mtime_for_path(self, target_path: Path) -> t.Optional[float]:
            """Max mtime across the file, the profile, macros, and ancestor YAML files.

            Returns None when ``target_path`` lies outside the project root or no
            mtimes are known.
            """
            project_root = self._project.context.project_root

            try:
                target_path.absolute().relative_to(project_root.absolute())
            except ValueError:
                return None

            mtimes = [
                self._loader._path_mtimes.get(target_path),
                self._loader._path_mtimes.get(self._project.profile.path),
                # FIXME: take into account which macros are actually referenced in the target model.
                self._macros_max_mtime,
            ]

            # YAML config in any ancestor folder can affect this model.
            cursor = target_path
            while cursor != project_root:
                cursor = cursor.parent
                mtimes.append(self._yaml_max_mtimes.get(cursor))

            # Don't shadow the module-wide `typing as t` alias here.
            non_null_mtimes = [mtime for mtime in mtimes if mtime is not None]
            return max(non_null_mtimes) if non_null_mtimes else None