Skip to content

Commit ca05f9b

Browse files
authored
Fix: Improve the dbt adapter dispatch resolution when the package name is not specified (#2841)
1 parent 437edcc commit ca05f9b

File tree

5 files changed

+48
-9
lines changed

5 files changed

+48
-9
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sqlmesh.core.engine_adapter import EngineAdapter
1212
from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot, to_table_mapping
1313
from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError
14-
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReference
14+
from sqlmesh.utils.jinja import JinjaMacroRegistry
1515

1616
if t.TYPE_CHECKING:
1717
import agate
@@ -98,19 +98,49 @@ def quote(self, identifier: str) -> str:
9898
def dispatch(self, name: str, package: t.Optional[str] = None) -> t.Callable:
9999
"""Returns a dialect-specific version of a macro with the given name."""
100100
target_type = self.jinja_globals["target"]["type"]
101-
references_to_try = [
102-
MacroReference(package=f"{package}_{target_type}", name=f"{target_type}__{name}"),
103-
MacroReference(package=package, name=f"{target_type}__{name}"),
104-
MacroReference(package=package, name=f"default__{name}"),
101+
macro_suffix = f"__{name}"
102+
103+
def _relevance(package_name_pair: t.Tuple[t.Optional[str], str]) -> t.Tuple[int, int]:
104+
"""Lower scores more relevant."""
105+
macro_package, macro_name = package_name_pair
106+
107+
package_score = 0 if macro_package == package else 1
108+
name_score = 1
109+
110+
if macro_name.startswith("default"):
111+
name_score = 2
112+
elif macro_name.startswith(target_type):
113+
name_score = 0
114+
115+
return name_score, package_score
116+
117+
jinja_env = self.jinja_macros.build_environment(**self.jinja_globals).globals
118+
packages_to_check: t.List[t.Optional[str]] = [
119+
package,
120+
*(k for k in jinja_env if k.startswith("dbt")),
105121
]
122+
candidates = {}
123+
for macro_package in packages_to_check:
124+
macros = jinja_env.get(macro_package, {}) if macro_package else jinja_env
125+
if not isinstance(macros, dict):
126+
continue
127+
candidates.update(
128+
{
129+
(macro_package, macro_name): macro_callable
130+
for macro_name, macro_callable in macros.items()
131+
if macro_name.endswith(macro_suffix)
132+
}
133+
)
106134

107-
for reference in references_to_try:
108-
macro_callable = self.jinja_macros.build_macro(reference, **self.jinja_globals)
109-
if macro_callable is not None:
110-
return macro_callable
135+
if candidates:
136+
sorted_candidates = sorted(candidates, key=_relevance)
137+
return candidates[sorted_candidates[0]]
111138

112139
raise ConfigError(f"Macro '{name}', package '{package}' was not found.")
113140

141+
def type(self) -> str:
142+
return self.project_dialect or ""
143+
114144

115145
class ParsetimeAdapter(BaseAdapter):
116146
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:

sqlmesh/dbt/builtin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def raise_compiler_error(self, msg: str) -> None:
3737

3838
raise CompilationException(msg)
3939

40+
def raise_not_implemented(self, msg: str) -> None:
41+
raise NotImplementedError(msg)
42+
4043
def warn(self, msg: str) -> str:
4144
logger.warning(msg)
4245
return ""

tests/dbt/test_adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def test_adapter_dispatch(sushi_test_project: Project, runtime_renderer: t.Calla
164164
context = sushi_test_project.context
165165
renderer = runtime_renderer(context)
166166
assert renderer("{{ adapter.dispatch('current_engine', 'customers')() }}") == "duckdb"
167+
assert renderer("{{ adapter.dispatch('current_timestamp')() }}") == "now()"
168+
assert renderer("{{ adapter.dispatch('current_timestamp', 'dbt')() }}") == "now()"
167169

168170
with pytest.raises(ConfigError, match=r"Macro 'current_engine'.*was not found."):
169171
renderer("{{ adapter.dispatch('current_engine')() }}")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
{% macro current_engine() %}{{ return(adapter.dispatch('current_engine')) }}{% endmacro %}
2+
13
{% macro default__current_engine() %}default{% endmacro %}
24

35
{% macro duckdb__current_engine() %}duckdb{% endmacro %}

tests/fixtures/dbt/sushi_test/packages/customers/models/customer_revenue_by_day.sql

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
)
88
}}
99

10+
{{ log(current_engine()) }}
11+
1012
WITH order_total AS (
1113
SELECT
1214
oi.order_id AS order_id,

0 commit comments

Comments
 (0)