Skip to content

Commit 2c95184

Browse files
authored
Fix: make macro function argument resolution more robust (#2604)
* Fix: refactor macro function argument resolution so it's more robust * Fix mypy issue * Throw MacroEvalError instead of TypeError, add tests * Update comment * Fix typo in macro docs (except -> except_) * Incorporate PR feedback
1 parent 404c68e commit 2c95184

File tree

4 files changed

+80
-30
lines changed

4 files changed

+80
-30
lines changed

docs/concepts/macros/sqlmesh_macros.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -600,12 +600,12 @@ If the column data types are known, the resulting query `CAST`s columns to their
600600

601601
- `relation`: The relation/table whose columns are being selected
602602
- `alias` (optional): The alias of the relation (if it has one)
603-
- `except` (optional): A list of columns to exclude
603+
- `except_` (optional): A list of columns to exclude
604604
- `prefix` (optional): A string to use as a prefix for all selected column names
605605
- `suffix` (optional): A string to use as a suffix for all selected column names
606606
- `quote_identifiers` (optional): Whether to quote the resulting identifiers, defaults to true
607607

608-
Like all SQLMesh macro functions, omitting an argument when calling `@STAR` requires passing all subsequent arguments with their name and the special `:=` keyword operator. For example, we might omit the `alias` argument with `@STAR(foo, except := [c])`. Learn more about macro function arguments [below](#positional-and-keyword-arguments).
608+
Like all SQLMesh macro functions, omitting an argument when calling `@STAR` requires passing all subsequent arguments with their name and the special `:=` keyword operator. For example, we might omit the `alias` argument with `@STAR(foo, except_ := [c])`. Learn more about macro function arguments [below](#positional-and-keyword-arguments).
609609

610610
As a `@STAR` example, consider the following query:
611611

@@ -635,15 +635,15 @@ FROM foo AS bar
635635
Note these aspects of the rendered query:
636636
- Each column is `CAST` to its data type in the table `foo` (e.g., `a` to `TEXT`)
637637
- Each column selection uses the alias `bar` (e.g., `"bar"."a"`)
638-
- Column `c` is not present because it was passed to `@STAR`'s `except` argument
638+
- Column `c` is not present because it was passed to `@STAR`'s `except_` argument
639639
- Each column alias is prefixed with `baz_` and suffixed with `_qux` (e.g., `"baz_a_qux"`)
640640

641641
Now consider a more complex example that provides different prefixes to `a` and `b` than to `d` and includes an explicit column `my_column`:
642642

643643
```sql linenums="1"
644644
SELECT
645-
@STAR(foo, bar, except=[c, d], 'ab_pre_'),
646-
@STAR(foo, bar, except=[a, b, c], 'd_pre_'),
645+
@STAR(foo, bar, except_=[c, d], 'ab_pre_'),
646+
@STAR(foo, bar, except_=[a, b, c], 'd_pre_'),
647647
my_column
648648
FROM foo AS bar
649649
```
@@ -661,7 +661,7 @@ FROM foo AS bar
661661

662662
Note these aspects of the rendered query:
663663
- Columns `a` and `b` have the prefix `"ab_pre_"` , while column `d` has the prefix `"d_pre_"`
664-
- Column `c` is not present because it was passed to the `except` argument in both `@STAR` calls
664+
- Column `c` is not present because it was passed to the `except_` argument in both `@STAR` calls
665665
- `my_column` is present in the query
666666

667667
### @GENERATE_SURROGATE_KEY

sqlmesh/core/macros.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -183,35 +183,39 @@ def send(
183183
if not callable(func):
184184
raise SQLMeshError(f"Macro '{name}' does not exist.")
185185

186+
try:
187+
# Bind the macro's actual parameters to its formal parameters
188+
sig = inspect.signature(func)
189+
bound = sig.bind(self, *args, **kwargs)
190+
bound.apply_defaults()
191+
except Exception as e:
192+
print_exception(e, self.python_env)
193+
raise MacroEvalError("Error trying to eval macro.") from e
194+
186195
try:
187196
annotations = t.get_type_hints(func)
188197
except NameError: # forward references aren't handled
189198
annotations = {}
190199

200+
# If the macro is annotated, we try coerce the actual parameters to the corresponding types
191201
if annotations:
192-
spec = inspect.getfullargspec(func)
193-
callargs = inspect.getcallargs(func, self, *args, **kwargs)
194-
new_args: t.List[t.Any] = []
195-
196-
for arg, value in callargs.items():
202+
for arg, value in bound.arguments.items():
197203
typ = annotations.get(arg)
198-
199-
if value is self:
204+
if not typ:
200205
continue
201-
if arg == spec.varargs:
202-
new_args.extend(self._coerce(v, typ) for v in value)
203-
elif arg == spec.varkw:
204-
for k, v in value.items():
205-
kwargs[k] = self._coerce(v, typ)
206-
elif arg in kwargs:
207-
kwargs[arg] = self._coerce(value, typ)
208-
else:
209-
new_args.append(self._coerce(value, typ))
210206

211-
args = new_args # type: ignore
207+
# Changes to bound.arguments will reflect in bound.args and bound.kwargs
208+
# https://docs.python.org/3/library/inspect.html#inspect.BoundArguments.arguments
209+
param = sig.parameters[arg]
210+
if param.kind is inspect.Parameter.VAR_POSITIONAL:
211+
bound.arguments[arg] = tuple(self._coerce(v, typ) for v in value)
212+
elif param.kind is inspect.Parameter.VAR_KEYWORD:
213+
bound.arguments[arg] = {k: self._coerce(v, typ) for k, v in value.items()}
214+
else:
215+
bound.arguments[arg] = self._coerce(value, typ)
212216

213217
try:
214-
return func(self, *args, **kwargs)
218+
return func(*bound.args, **bound.kwargs)
215219
except Exception as e:
216220
print_exception(e, self.python_env)
217221
raise MacroEvalError("Error trying to eval macro.") from e
@@ -337,7 +341,8 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
337341
else:
338342
if kwargs:
339343
raise MacroEvalError(
340-
f"Positional argument cannot follow keyword argument.\n {func.sql(dialect=self.dialect)} at '{self._path}'"
344+
"Positional argument cannot follow keyword argument.\n "
345+
f"{func.sql(dialect=self.dialect)} at '{self._path}'"
341346
)
342347

343348
args.append(e)
@@ -479,7 +484,7 @@ def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.A
479484
return expr
480485
if issubclass(base, exp.Expression):
481486
d = Dialect.get_or_raise(self.dialect)
482-
into = base if base in d.parser().EXPRESSION_PARSERS else None
487+
into = base if base in d.parser_class.EXPRESSION_PARSERS else None
483488
if into is None:
484489
if isinstance(expr, exp.Literal):
485490
coerced = parse_one(expr.this)
@@ -809,7 +814,7 @@ def star(
809814
>>> from sqlglot import parse_one, exp
810815
>>> from sqlglot.schema import MappingSchema
811816
>>> from sqlmesh.core.macros import MacroEvaluator
812-
>>> sql = "SELECT @STAR(foo, bar, [c], 'baz_') FROM foo AS bar"
817+
>>> sql = "SELECT @STAR(foo, bar, except_ := [c], prefix := 'baz_') FROM foo AS bar"
813818
>>> MacroEvaluator(schema=MappingSchema({"foo": {"a": exp.DataType.build("string"), "b": exp.DataType.build("string"), "c": exp.DataType.build("string"), "d": exp.DataType.build("int")}})).transform(parse_one(sql)).sql()
814819
'SELECT CAST("bar"."a" AS TEXT) AS "baz_a", CAST("bar"."b" AS TEXT) AS "baz_b", CAST("bar"."d" AS INT) AS "baz_d" FROM foo AS bar'
815820
"""

sqlmesh/dbt/relation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
if DBT_VERSION < (1, 8):
55
from dbt.contracts.relation import * # type: ignore # noqa: F403
66
else:
7-
from dbt.adapters.contracts.relation import * # noqa: F403
7+
from dbt.adapters.contracts.relation import * # type: ignore # noqa: F403

tests/core/test_macros.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@ def suffix_idents_2(evaluator: MacroEvaluator, items: t.Tuple[str, ...], suffix:
7676
def stamped(evaluator, query: exp.Select) -> exp.Subquery:
7777
return query.select(exp.Literal.string("2024-01-01").as_("stamp")).subquery()
7878

79+
@macro()
80+
def test_arg_resolution(evaluator, pos_only, /, a1, *, a2=1, **rest):
81+
return 1
82+
83+
@macro()
84+
def test_default_arg_coercion(
85+
evaluator: MacroEvaluator,
86+
a1: int = 1,
87+
a2: int = exp.Literal.number(2), # type: ignore
88+
):
89+
return sum([a1, a2])
90+
7991
return MacroEvaluator(
8092
"hive",
8193
{"test": Executable(name="test", payload="def test(_):\n return 'test'")},
@@ -440,6 +452,11 @@ def test_ast_correctness(macro_evaluator):
440452
"SELECT 'a1-b1-c2-d:d1-e:e2'",
441453
{},
442454
),
455+
(
456+
"""select @TEST_DEFAULT_ARG_COERCION()""",
457+
"SELECT 3",
458+
{},
459+
),
443460
],
444461
)
445462
def test_macro_functions(macro_evaluator: MacroEvaluator, assert_exp_eq, sql, expected, args):
@@ -537,6 +554,34 @@ def test_macro_coercion(macro_evaluator: MacroEvaluator, assert_exp_eq):
537554
)
538555

539556

540-
def test_positional_follows_kwargs(macro_evaluator: MacroEvaluator):
557+
def test_positional_follows_kwargs(macro_evaluator):
541558
with pytest.raises(MacroEvalError, match="Positional argument cannot follow"):
542-
macro_evaluator.evaluate(parse_one("@repeated(x, multi := True, 3)")) # type: ignore
559+
macro_evaluator.evaluate(parse_one("@repeated(x, multi := True, 3)"))
560+
561+
562+
def test_macro_parameter_resolution(macro_evaluator):
563+
with pytest.raises(MacroEvalError) as e:
564+
macro_evaluator.evaluate(parse_one("@test_arg_resolution()"))
565+
assert str(e.value.__cause__) == "missing a required argument: 'pos_only'"
566+
567+
with pytest.raises(MacroEvalError) as e:
568+
macro_evaluator.evaluate(parse_one("@test_arg_resolution(a1 := 1)"))
569+
assert str(e.value.__cause__) == "missing a required argument: 'pos_only'"
570+
571+
with pytest.raises(MacroEvalError) as e:
572+
macro_evaluator.evaluate(parse_one("@test_arg_resolution(1)"))
573+
assert str(e.value.__cause__) == "missing a required argument: 'a1'"
574+
575+
with pytest.raises(MacroEvalError) as e:
576+
macro_evaluator.evaluate(parse_one("@test_arg_resolution(1, a2 := 2)"))
577+
assert str(e.value.__cause__) == "missing a required argument: 'a1'"
578+
579+
with pytest.raises(MacroEvalError) as e:
580+
macro_evaluator.evaluate(parse_one("@test_arg_resolution(pos_only := 1)"))
581+
assert str(e.value.__cause__) == (
582+
"'pos_only' parameter is positional only, but was passed as a keyword"
583+
)
584+
585+
with pytest.raises(MacroEvalError) as e:
586+
macro_evaluator.evaluate(parse_one("@test_arg_resolution(1, 2, 3)"))
587+
assert str(e.value.__cause__) == "too many positional arguments"

0 commit comments

Comments
 (0)