Skip to content

Commit ccd1cc6

Browse files
authored
Feature: add the if macro. (#796)
1 parent e3aa4eb commit ccd1cc6

File tree

3 files changed

+51
-2
lines changed

3 files changed

+51
-2
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ it-test: core-it-test airflow-it-test-with-env
2828

2929
it-test-docker: core-it-test airflow-it-test-docker-with-env
3030

31-
test: unit-test it-test doc-test
31+
test: unit-test doc-test it-test
3232

3333
package:
3434
pip3 install wheel && python3 setup.py sdist bdist_wheel

sqlmesh/core/macros.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,18 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
188188
return [exp.convert(item) for item in result if item is not None]
189189
return exp.convert(result)
190190

191-
def eval_expression(self, node: exp.Expression) -> t.Any:
191+
def eval_expression(self, node: t.Any) -> t.Any:
192192
"""Converts a SQLGlot expression into executable Python code and evals it.
193193
194+
If the node is not an expression, it will simply be returned.
195+
194196
Args:
195197
node: expression
196198
Returns:
197199
The return value of the evaled Python Code.
198200
"""
201+
if not isinstance(node, exp.Expression):
202+
return node
199203
code = node.sql()
200204
try:
201205
code = self.generator.generate(node)
@@ -333,6 +337,31 @@ def each(
333337
return [item for item in map(func, ensure_collection(items)) if item is not None]
334338

335339

340+
@macro("IF")
341+
def if_(
342+
evaluator: MacroEvaluator,
343+
condition: t.Any,
344+
true: t.Any,
345+
false: t.Any = None,
346+
) -> t.Any:
347+
"""Evaluates a given condition and returns the second argument if true or else the third argument.
348+
349+
If false is not passed in, the default return value will be None.
350+
351+
Example:
352+
>>> from sqlglot import parse_one
353+
>>> from sqlmesh.core.macros import MacroEvaluator
354+
>>> MacroEvaluator().transform(parse_one("@IF('a' = 1, a, b)")).sql()
355+
'b'
356+
357+
>>> MacroEvaluator().transform(parse_one("@IF('a' = 1, a)"))
358+
"""
359+
360+
if evaluator.eval_expression(condition):
361+
return true
362+
return false
363+
364+
336365
@macro("REDUCE")
337366
def reduce_(evaluator: MacroEvaluator, *args: t.Any) -> t.Any:
338367
"""Iterates through items applying provided function that takes two arguments

tests/core/test_macros.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,26 @@ def test_ast_correctness(macro_evaluator):
168168
"SELECT * FROM city",
169169
{"do_order": False},
170170
),
171+
(
172+
"""select @if(TRUE, 1, 0)""",
173+
"SELECT 1",
174+
{},
175+
),
176+
(
177+
"""select @if(FALSE, 1, 0)""",
178+
"SELECT 0",
179+
{},
180+
),
181+
(
182+
"""select @if(1 > 0, 1, 0)""",
183+
"SELECT 1",
184+
{},
185+
),
186+
(
187+
"""select @if('a' = 'b', c), d""",
188+
"SELECT d",
189+
{},
190+
),
171191
],
172192
)
173193
def test_macro_functions(macro_evaluator, assert_exp_eq, sql, expected, args):

0 commit comments

Comments
 (0)