Skip to content

Commit a36ad90

Browse files
authored
Feat: Support macros in the MODEL statement (#2499)
1 parent 95dfc02 commit a36ad90

File tree

3 files changed

+131
-8
lines changed

3 files changed

+131
-8
lines changed

sqlmesh/core/macros.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
235235
"""
236236
mapping = {}
237237

238-
for k, v in chain(self.locals.items(), local_variables.items()):
238+
variables = self.locals.get(c.SQLMESH_VARS, {})
239+
240+
for k, v in chain(variables.items(), self.locals.items(), local_variables.items()):
239241
# try to convert all variables into sqlglot expressions
240242
# because they're going to be converted into strings in sql
241243
# we use bare Exception instead of ValueError because there's

sqlmesh/core/model/definition.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from sqlmesh.core import constants as c
2424
from sqlmesh.core import dialect as d
25-
from sqlmesh.core.macros import MacroRegistry, macro
25+
from sqlmesh.core.macros import MacroRegistry, MacroStrTemplate, macro
2626
from sqlmesh.core.model.common import expression_validator
2727
from sqlmesh.core.model.kind import ModelKindName, SeedKind
2828
from sqlmesh.core.model.meta import ModelMeta
@@ -1435,6 +1435,38 @@ def load_sql_based_model(
14351435
path,
14361436
)
14371437

1438+
unrendered_signals = None
1439+
for prop in meta.expressions:
1440+
if prop.name.lower() == "signals":
1441+
unrendered_signals = prop.args.get("value")
1442+
1443+
meta_python_env = _python_env(
1444+
expressions=meta,
1445+
jinja_macro_references=None,
1446+
module_path=module_path,
1447+
macros=macros or macro.get_registry(),
1448+
variables=variables,
1449+
path=path,
1450+
)
1451+
meta_renderer = ExpressionRenderer(
1452+
meta,
1453+
dialect,
1454+
[],
1455+
path=path,
1456+
jinja_macro_registry=jinja_macros,
1457+
python_env=meta_python_env,
1458+
default_catalog=default_catalog,
1459+
quote_identifiers=False,
1460+
)
1461+
rendered_meta_exprs = meta_renderer.render()
1462+
if rendered_meta_exprs is None or len(rendered_meta_exprs) != 1:
1463+
raise_config_error(
1464+
f"Invalid MODEL statement:\n{meta.sql(dialect=dialect, pretty=True)}",
1465+
path,
1466+
)
1467+
raise
1468+
rendered_meta = rendered_meta_exprs[0]
1469+
14381470
# Extract the query and any pre/post statements
14391471
query_or_seed_insert, pre_statements, post_statements = _split_sql_model_statements(
14401472
expressions[1:], path
@@ -1443,11 +1475,16 @@ def load_sql_based_model(
14431475
meta_fields: t.Dict[str, t.Any] = {
14441476
"dialect": dialect,
14451477
"description": (
1446-
"\n".join(comment.strip() for comment in meta.comments) if meta.comments else None
1478+
"\n".join(comment.strip() for comment in rendered_meta.comments)
1479+
if rendered_meta.comments
1480+
else None
14471481
),
1448-
**{prop.name.lower(): prop.args.get("value") for prop in meta.expressions},
1482+
**{prop.name.lower(): prop.args.get("value") for prop in rendered_meta.expressions},
14491483
**kwargs,
14501484
}
1485+
if unrendered_signals:
1486+
# Signals must remain unrendered, so that they can be rendered later at evaluation runtime.
1487+
meta_fields["signals"] = unrendered_signals
14511488

14521489
name = meta_fields.pop("name", "")
14531490
if not name:
@@ -1465,8 +1502,8 @@ def load_sql_based_model(
14651502
)
14661503

14671504
jinja_macros = (jinja_macros or JinjaMacroRegistry()).trim(jinja_macro_references)
1468-
for macro in jinja_macros.root_macros.values():
1469-
used_variables.update(extract_macro_references_and_variables(macro.definition)[1])
1505+
for jinja_macro in jinja_macros.root_macros.values():
1506+
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
14701507

14711508
common_kwargs = dict(
14721509
pre_statements=pre_statements,
@@ -1867,7 +1904,7 @@ def _python_env(
18671904
expressions = ensure_list(expressions)
18681905
for expression in expressions:
18691906
if not isinstance(expression, d.Jinja):
1870-
for macro_func_or_var in expression.find_all(d.MacroFunc, d.MacroVar):
1907+
for macro_func_or_var in expression.find_all(d.MacroFunc, d.MacroVar, exp.Identifier):
18711908
if macro_func_or_var.__class__ is d.MacroFunc:
18721909
name = macro_func_or_var.this.name.lower()
18731910
if name in macros:
@@ -1888,6 +1925,15 @@ def _python_env(
18881925
used_macros[name] = macros[name]
18891926
elif name in variables:
18901927
used_variables.add(name)
1928+
elif (
1929+
isinstance(macro_func_or_var, exp.Identifier) and "@" in macro_func_or_var.this
1930+
):
1931+
for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(
1932+
macro_func_or_var.this
1933+
):
1934+
var_name = braced_identifier or identifier
1935+
if var_name in variables:
1936+
used_variables.add(var_name)
18911937

18921938
for macro_ref in jinja_macro_references or set():
18931939
if macro_ref.package is None and macro_ref.name in macros:

tests/core/test_model.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,7 @@ def test_render_definition():
938938
owner owner_name,
939939
dialect spark,
940940
kind INCREMENTAL_BY_TIME_RANGE (
941-
time_column (a, 'yyyymmdd')
941+
time_column (`a`, 'yyyymmdd')
942942
),
943943
storage_format iceberg,
944944
partitioned_by `a`,
@@ -3516,6 +3516,56 @@ def test_named_variable_macros() -> None:
35163516
)
35173517

35183518

3519+
def test_variables_in_templates() -> None:
3520+
model = load_sql_based_model(
3521+
parse(
3522+
"""
3523+
MODEL(name sushi.test_gateway_macro);
3524+
@DEF(overridden_var, overridden_value);
3525+
SELECT 'gateway' AS col_@gateway, 'test_var_a' AS @{test_var_a}_col, 'overridden_var' AS col_@{overridden_var}_col
3526+
"""
3527+
),
3528+
variables={
3529+
c.GATEWAY: "in_memory",
3530+
"test_var_a": "test_value",
3531+
"test_var_unused": "unused",
3532+
"overridden_var": "initial_value",
3533+
},
3534+
)
3535+
3536+
assert model.python_env[c.SQLMESH_VARS] == Executable.value(
3537+
{c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"}
3538+
)
3539+
assert (
3540+
model.render_query_or_raise().sql()
3541+
== "SELECT 'gateway' AS \"col_in_memory\", 'test_var_a' AS \"test_value_col\", 'overridden_var' AS \"col_overridden_value_col\""
3542+
)
3543+
3544+
model = load_sql_based_model(
3545+
parse(
3546+
"""
3547+
MODEL(name sushi.test_gateway_macro);
3548+
@DEF(overridden_var, overridden_value);
3549+
SELECT 'combo' AS col_@{test_var_a}_@{overridden_var}_col_@gateway
3550+
"""
3551+
),
3552+
variables={
3553+
c.GATEWAY: "in_memory",
3554+
"test_var_a": "test_value",
3555+
"test_var_unused": "unused",
3556+
"overridden_var": "initial_value",
3557+
},
3558+
)
3559+
3560+
assert model.python_env[c.SQLMESH_VARS] == Executable.value(
3561+
{c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"}
3562+
)
3563+
assert (
3564+
model.render_query_or_raise().sql()
3565+
== "SELECT 'combo' AS \"col_test_value_overridden_value_col_in_memory\""
3566+
)
3567+
3568+
35193569
def test_variables_jinja():
35203570
expressions = parse(
35213571
"""
@@ -3759,3 +3809,28 @@ def test_this_model() -> None:
37593809

37603810
assert model.render_pre_statements()[0].sql() == """COPY "db"."table" TO 'a'"""
37613811
assert model.render_post_statements()[0].sql() == """COPY "db"."table" TO 'b'"""
3812+
3813+
3814+
def test_macros_in_model_statement(sushi_context, assert_exp_eq):
3815+
expressions = d.parse(
3816+
"""
3817+
MODEL (
3818+
name @{gateway}__@{gateway}.test_model,
3819+
kind INCREMENTAL_BY_TIME_RANGE (
3820+
time_column @{time_column}
3821+
3822+
),
3823+
start @IF(@gateway = 'test_gateway', '2023-01-01', '2024-01-02')
3824+
);
3825+
3826+
SELECT a, b UNION SELECT c, c
3827+
"""
3828+
)
3829+
3830+
model = load_sql_based_model(
3831+
expressions, variables={"gateway": "test_gateway", "time_column": "a"}
3832+
)
3833+
assert model.name == "test_gateway__test_gateway.test_model"
3834+
assert model.time_column
3835+
assert model.time_column.column == exp.column("a", quoted=True)
3836+
assert model.start == "2023-01-01"

0 commit comments

Comments
 (0)