Skip to content

Commit 31bea82

Browse files
authored
fix: render signals properly for variables (#3873)
1 parent a1feec9 commit 31bea82

File tree

5 files changed

+46
-6
lines changed

5 files changed

+46
-6
lines changed

.circleci/test_migration.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ make install-dev
3636

3737
# Migrate and make sure the diff is empty
3838
pushd $SUSHI_DIR
39-
sqlmesh --gateway $GATEWAY_NAME migrate
40-
sqlmesh --gateway $GATEWAY_NAME diff prod
39+
SQLMESH_DEBUG=1 sqlmesh --gateway $GATEWAY_NAME migrate
40+
SQLMESH_DEBUG=1 sqlmesh --gateway $GATEWAY_NAME diff prod
4141
popd

sqlmesh/core/macros.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,7 @@ def call_macro(
12911291

12921292

12931293
def _coerce(
1294-
expr: exp.Expression,
1294+
expr: t.Any,
12951295
typ: t.Any,
12961296
dialect: DialectType,
12971297
path: Path,
@@ -1300,7 +1300,7 @@ def _coerce(
13001300
"""Coerces the given expression to the specified type on a best-effort basis."""
13011301
base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'."
13021302
try:
1303-
if typ is None or typ is t.Any:
1303+
if typ is None or typ is t.Any or not isinstance(expr, exp.Expression):
13041304
return expr
13051305
base = t.get_origin(typ) or typ
13061306

sqlmesh/core/model/definition.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,15 @@ def _render(e: exp.Expression) -> str | int | float | bool:
621621
{k: _render(v) for k, v in signal.items()} for name, signal in self.signals if not name
622622
]
623623

624+
def render_signal_calls(self) -> t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]]:
625+
return {
626+
name: {
627+
k: seq_get(self._create_renderer(v).render() or [], 0) for k, v in kwargs.items()
628+
}
629+
for name, kwargs in self.signals
630+
if name
631+
}
632+
624633
def render_merge_filter(
625634
self,
626635
*,
@@ -2359,6 +2368,9 @@ def _create_model(
23592368

23602369
statements.extend(audit.query for audit in audit_definitions.values())
23612370

2371+
for _, kwargs in model.signals:
2372+
statements.extend(kwargs.values())
2373+
23622374
python_env = python_env or {}
23632375

23642376
make_python_env(

sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -909,15 +909,15 @@ def check_ready_intervals(self, intervals: Intervals) -> Intervals:
909909
Note that this will handle gaps in the provided intervals. The returned intervals
910910
may introduce new gaps.
911911
"""
912-
signals = self.is_model and self.model.signals
912+
signals = self.is_model and self.model.render_signal_calls()
913913

914914
if not signals:
915915
return intervals
916916

917917
python_env = self.model.python_env
918918
env = prepare_env(python_env)
919919

920-
for signal_name, kwargs in signals:
920+
for signal_name, kwargs in signals.items():
921921
try:
922922
intervals = _check_ready_intervals(
923923
env[signal_name],

tests/core/test_snapshot.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytest_mock.plugin import MockerFixture
1212
from sqlglot import exp, to_column
1313

14+
from sqlmesh.core import constants as c
1415
from sqlmesh.core.audit import StandaloneAudit
1516
from sqlmesh.core.config import (
1617
AutoCategorizationMode,
@@ -35,6 +36,7 @@
3536
)
3637
from sqlmesh.core.model.kind import TimeColumn, ModelKindName
3738
from sqlmesh.core.node import IntervalUnit
39+
from sqlmesh.core.signal import signal
3840
from sqlmesh.core.snapshot import (
3941
DeployabilityIndex,
4042
QualifiedViewName,
@@ -2802,3 +2804,29 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot):
28022804
assert snapshot_b.intervals == [
28032805
(to_timestamp("2020-01-01"), to_timestamp("2020-01-06")),
28042806
]
2807+
2808+
2809+
def test_render_signal(make_snapshot):
2810+
@signal()
2811+
def check_types(batch, env: str, default: int = 0):
2812+
if env != "in_memory" or not default == 0:
2813+
raise
2814+
return True
2815+
2816+
sql_model = load_sql_based_model(
2817+
parse(
2818+
"""
2819+
MODEL (
2820+
name test_schema.test_model,
2821+
signals check_types(env := @gateway)
2822+
);
2823+
SELECT a FROM tbl;
2824+
"""
2825+
),
2826+
variables={
2827+
c.GATEWAY: "in_memory",
2828+
},
2829+
signal_definitions=signal.get_registry(),
2830+
)
2831+
snapshot_a = make_snapshot(sql_model)
2832+
assert snapshot_a.check_ready_intervals([(0, 1)]) == [(0, 1)]

0 commit comments

Comments
 (0)