|
11 | 11 | from sqlmesh.core.engine_adapter import EngineAdapter |
12 | 12 | from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot, to_table_mapping |
13 | 13 | from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError |
14 | | -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReference |
| 14 | +from sqlmesh.utils.jinja import JinjaMacroRegistry |
15 | 15 |
|
16 | 16 | if t.TYPE_CHECKING: |
17 | 17 | import agate |
@@ -98,19 +98,49 @@ def quote(self, identifier: str) -> str: |
98 | 98 | def dispatch(self, name: str, package: t.Optional[str] = None) -> t.Callable: |
99 | 99 | """Returns a dialect-specific version of a macro with the given name.""" |
100 | 100 | target_type = self.jinja_globals["target"]["type"] |
101 | | - references_to_try = [ |
102 | | - MacroReference(package=f"{package}_{target_type}", name=f"{target_type}__{name}"), |
103 | | - MacroReference(package=package, name=f"{target_type}__{name}"), |
104 | | - MacroReference(package=package, name=f"default__{name}"), |
| 101 | + macro_suffix = f"__{name}" |
| 102 | + |
| 103 | + def _relevance(package_name_pair: t.Tuple[t.Optional[str], str]) -> t.Tuple[int, int]: |
| 104 | + """Lower scores more relevant.""" |
| 105 | + macro_package, macro_name = package_name_pair |
| 106 | + |
| 107 | + package_score = 0 if macro_package == package else 1 |
| 108 | + name_score = 1 |
| 109 | + |
| 110 | + if macro_name.startswith("default"): |
| 111 | + name_score = 2 |
| 112 | + elif macro_name.startswith(target_type): |
| 113 | + name_score = 0 |
| 114 | + |
| 115 | + return name_score, package_score |
| 116 | + |
| 117 | + jinja_env = self.jinja_macros.build_environment(**self.jinja_globals).globals |
| 118 | + packages_to_check: t.List[t.Optional[str]] = [ |
| 119 | + package, |
| 120 | + *(k for k in jinja_env if k.startswith("dbt")), |
105 | 121 | ] |
| 122 | + candidates = {} |
| 123 | + for macro_package in packages_to_check: |
| 124 | + macros = jinja_env.get(macro_package, {}) if macro_package else jinja_env |
| 125 | + if not isinstance(macros, dict): |
| 126 | + continue |
| 127 | + candidates.update( |
| 128 | + { |
| 129 | + (macro_package, macro_name): macro_callable |
| 130 | + for macro_name, macro_callable in macros.items() |
| 131 | + if macro_name.endswith(macro_suffix) |
| 132 | + } |
| 133 | + ) |
106 | 134 |
|
107 | | - for reference in references_to_try: |
108 | | - macro_callable = self.jinja_macros.build_macro(reference, **self.jinja_globals) |
109 | | - if macro_callable is not None: |
110 | | - return macro_callable |
| 135 | + if candidates: |
| 136 | + sorted_candidates = sorted(candidates, key=_relevance) |
| 137 | + return candidates[sorted_candidates[0]] |
111 | 138 |
|
112 | 139 | raise ConfigError(f"Macro '{name}', package '{package}' was not found.") |
113 | 140 |
|
| 141 | + def type(self) -> str: |
| 142 | + return self.project_dialect or "" |
| 143 | + |
114 | 144 |
|
115 | 145 | class ParsetimeAdapter(BaseAdapter): |
116 | 146 | def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]: |
|
0 commit comments