Skip to content

Commit e4b2817

Browse files
authored
Fix: Make dbt adapter macros available in the local scope (#3219)
1 parent 1fef3e8 commit e4b2817

File tree

4 files changed

+64
-4
lines changed

4 files changed

+64
-4
lines changed

sqlmesh/dbt/manifest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,19 @@ def _load_macros(self) -> None:
167167
path=Path(macro.original_file_path),
168168
)
169169

170+
# This is a workaround for dbt adapter macros (eg. "spark__dateadd") whcih are expected to be
171+
# available in the global scope regardless of the package they came from.
172+
adapter_macro_names = {
173+
name[name.find("__") + 2 :]
174+
for name in self._macros_per_package.get("dbt", {})
175+
if "__" in name
176+
}
177+
for macros in self._macros_per_package.values():
178+
for name, macro_config in macros.items():
179+
pos = name.find("__")
180+
if pos > 0 and name[pos + 2 :] in adapter_macro_names:
181+
macro_config.info.is_top_level = True
182+
170183
def _load_tests(self) -> None:
171184
for node in self._manifest.nodes.values():
172185
if node.resource_type != "test":

sqlmesh/utils/jinja.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class MacroInfo(PydanticModel):
5252

5353
definition: str
5454
depends_on: t.List[MacroReference]
55+
is_top_level: bool = False
5556

5657

5758
class MacroReturnVal(Exception):
@@ -298,10 +299,11 @@ def build_environment(self, **kwargs: t.Any) -> Environment:
298299

299300
package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict)
300301
for package_name, macros in self.packages.items():
301-
for macro_name in macros:
302-
package_macros[package_name][macro_name] = self._MacroWrapper(
303-
macro_name, package_name, self, context
304-
)
302+
for macro_name, macro in macros.items():
303+
macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context)
304+
package_macros[package_name][macro_name] = macro_wrapper
305+
if macro.is_top_level and macro_name not in root_macros:
306+
root_macros[macro_name] = macro_wrapper
305307

306308
if self.root_package_name is not None:
307309
package_macros[self.root_package_name].update(root_macros)

tests/dbt/test_manifest.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,30 @@ def test_source_meta_external_location():
206206
)
207207
assert relation.identifier == "items"
208208
assert relation.render() == "read_parquet('path/to/external/items.parquet')"
209+
210+
211+
@pytest.mark.xdist_group("dbt_manifest")
212+
def test_top_level_dbt_adapter_macros():
213+
project_path = Path("tests/fixtures/dbt/sushi_test")
214+
profile = Profile.load(DbtContext(project_path))
215+
216+
helper = ManifestHelper(
217+
project_path,
218+
project_path,
219+
"sushi",
220+
profile.target,
221+
variable_overrides={"start": "2020-01-01"},
222+
)
223+
224+
# Adapter macros must be marked as top-level
225+
dbt_macros = helper.macros("dbt")
226+
dbt_duckdb_macros = helper.macros("dbt_duckdb")
227+
assert dbt_macros["default__dateadd"].info.is_top_level
228+
assert dbt_macros["default__datediff"].info.is_top_level
229+
assert dbt_duckdb_macros["duckdb__datediff"].info.is_top_level
230+
assert dbt_duckdb_macros["duckdb__dateadd"].info.is_top_level
231+
232+
# Project dispatch macros should not be marked as top-level
233+
customers_macros = helper.macros("customers")
234+
assert not customers_macros["default__current_engine"].info.is_top_level
235+
assert not customers_macros["duckdb__current_engine"].info.is_top_level

tests/utils/test_jinja.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,21 @@ def test_find_call_names():
280280
("package", "package_macro"),
281281
("'stringval'", "function"),
282282
]
283+
284+
285+
def test_dbt_adapter_macro_scope():
286+
package_a = """
287+
{% macro spark__macro_a() %}
288+
macro_a
289+
{% endmacro %}"""
290+
291+
extractor = MacroExtractor()
292+
registry = JinjaMacroRegistry()
293+
294+
macros = extractor.extract(package_a)
295+
macros["spark__macro_a"].is_top_level = True
296+
297+
registry.add_macros(macros, package="package_a")
298+
299+
rendered = registry.build_environment().from_string("{{ spark__macro_a() }}").render()
300+
assert rendered.strip() == "macro_a"

0 commit comments

Comments
 (0)