Skip to content

Commit caa6f83

Browse files
authored
feat: add execution context to signals (#4123)
1 parent 046b43a commit caa6f83

File tree

9 files changed

+95
-41
lines changed

9 files changed

+95
-41
lines changed

docs/guides/signals.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,18 @@ MODEL (
116116

117117
SELECT @start_ds AS ds
118118
```
119+
120+
### Accessing execution context / engine adapter
121+
It is possible to access the execution context in a signal, which provides access to the engine adapter (warehouse connection).
122+
123+
```python
124+
import typing as t
125+
126+
from sqlmesh import signal, DatetimeRanges, ExecutionContext
127+
128+
129+
# add the context argument to your function
130+
@signal()
131+
def one_week_ago(batch: DatetimeRanges, context: ExecutionContext) -> t.Union[bool, DatetimeRanges]:
132+
return len(context.engine_adapter.fetchdf("SELECT 1")) > 1
133+
```

sqlmesh/core/macros.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,9 @@ def send(
214214
raise SQLMeshError(f"Macro '{name}' does not exist.")
215215

216216
try:
217-
return call_macro(func, self.dialect, self._path, self, *args, **kwargs) # type: ignore
217+
return call_macro(
218+
func, self.dialect, self._path, provided_args=(self, *args), provided_kwargs=kwargs
219+
) # type: ignore
218220
except Exception as e:
219221
print_exception(e, self.python_env)
220222
raise MacroEvalError("Error trying to eval macro.") from e
@@ -1286,12 +1288,21 @@ def call_macro(
12861288
func: t.Callable,
12871289
dialect: DialectType,
12881290
path: Path,
1289-
*args: t.Any,
1290-
**kwargs: t.Any,
1291+
provided_args: t.Tuple[t.Any, ...],
1292+
provided_kwargs: t.Dict[str, t.Any],
1293+
**optional_kwargs: t.Any,
12911294
) -> t.Any:
12921295
# Bind the macro's actual parameters to its formal parameters
12931296
sig = inspect.signature(func)
1294-
bound = sig.bind(*args, **kwargs)
1297+
1298+
if optional_kwargs:
1299+
provided_kwargs = provided_kwargs.copy()
1300+
1301+
for k, v in optional_kwargs.items():
1302+
if k in sig.parameters:
1303+
provided_kwargs[k] = v
1304+
1305+
bound = sig.bind(*provided_args, **provided_kwargs)
12951306
bound.apply_defaults()
12961307

12971308
try:

sqlmesh/core/scheduler.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(
9595
):
9696
self.state_sync = state_sync
9797
self.snapshots = {s.snapshot_id: s for s in snapshots}
98+
self.snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()}
9899
self.snapshot_per_version = _resolve_one_snapshot_per_version(self.snapshots.values())
99100
self.default_catalog = default_catalog
100101
self.snapshot_evaluator = snapshot_evaluator
@@ -348,7 +349,11 @@ def run(
348349

349350
return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS
350351

351-
def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snapshot, Intervals]:
352+
def batch_intervals(
353+
self,
354+
merged_intervals: SnapshotToIntervals,
355+
deployability_index: t.Optional[DeployabilityIndex],
356+
) -> t.Dict[Snapshot, Intervals]:
352357
dag = snapshots_to_dag(merged_intervals)
353358

354359
snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = {
@@ -369,7 +374,20 @@ def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snaps
369374
continue
370375
snapshot, intervals = snapshot_intervals[snapshot_id]
371376
unready = set(intervals)
372-
intervals = snapshot.check_ready_intervals(intervals)
377+
378+
from sqlmesh.core.context import ExecutionContext
379+
380+
adapter = self.snapshot_evaluator.get_adapter(snapshot.model_gateway)
381+
382+
context = ExecutionContext(
383+
adapter,
384+
self.snapshots_by_name,
385+
deployability_index,
386+
default_dialect=adapter.dialect,
387+
default_catalog=self.default_catalog,
388+
)
389+
390+
intervals = snapshot.check_ready_intervals(intervals, context)
373391
unready -= set(intervals)
374392

375393
for parent in snapshot.parents:
@@ -424,7 +442,7 @@ def run_merged_intervals(
424442
"""
425443
execution_time = execution_time or now_timestamp()
426444

427-
batched_intervals = self.batch_intervals(merged_intervals)
445+
batched_intervals = self.batch_intervals(merged_intervals, deployability_index)
428446

429447
self.console.start_evaluation_progress(
430448
{snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()},
@@ -434,8 +452,6 @@ def run_merged_intervals(
434452

435453
dag = self._dag(batched_intervals)
436454

437-
snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()}
438-
439455
if run_environment_statements:
440456
environment_statements = self.state_sync.get_environment_statements(
441457
environment_naming_info.name
@@ -446,7 +462,7 @@ def run_merged_intervals(
446462
runtime_stage=RuntimeStage.BEFORE_ALL,
447463
environment_naming_info=environment_naming_info,
448464
default_catalog=self.default_catalog,
449-
snapshots=snapshots_by_name,
465+
snapshots=self.snapshots_by_name,
450466
start=start,
451467
end=end,
452468
execution_time=execution_time,
@@ -459,7 +475,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
459475
snapshot_name, ((start, end), batch_idx) = node
460476
if batch_idx == -1:
461477
return
462-
snapshot = snapshots_by_name[snapshot_name]
478+
snapshot = self.snapshots_by_name[snapshot_name]
463479

464480
self.console.start_snapshot_evaluation_progress(snapshot)
465481

@@ -520,7 +536,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
520536
runtime_stage=RuntimeStage.AFTER_ALL,
521537
environment_naming_info=environment_naming_info,
522538
default_catalog=self.default_catalog,
523-
snapshots=snapshots_by_name,
539+
snapshots=self.snapshots_by_name,
524540
start=start,
525541
end=end,
526542
execution_time=execution_time,

sqlmesh/core/snapshot/definition.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
if t.TYPE_CHECKING:
4747
from sqlglot.dialects.dialect import DialectType
4848
from sqlmesh.core.environment import EnvironmentNamingInfo
49+
from sqlmesh.core.context import ExecutionContext
4950

5051
Interval = t.Tuple[int, int]
5152
Intervals = t.List[Interval]
@@ -940,7 +941,7 @@ def missing_intervals(
940941
model_end_ts,
941942
)
942943

943-
def check_ready_intervals(self, intervals: Intervals) -> Intervals:
944+
def check_ready_intervals(self, intervals: Intervals, context: ExecutionContext) -> Intervals:
944945
"""Returns a list of intervals that are considered ready by the provided signal.
945946
946947
Note that this will handle gaps in the provided intervals. The returned intervals
@@ -959,6 +960,7 @@ def check_ready_intervals(self, intervals: Intervals) -> Intervals:
959960
intervals = _check_ready_intervals(
960961
env[signal_name],
961962
intervals,
963+
context,
962964
dialect=self.model.dialect,
963965
path=self.model._path,
964966
kwargs=kwargs,
@@ -2148,6 +2150,7 @@ def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]:
21482150
def _check_ready_intervals(
21492151
check: t.Callable,
21502152
intervals: Intervals,
2153+
context: ExecutionContext,
21512154
dialect: DialectType = None,
21522155
path: Path = Path(),
21532156
kwargs: t.Optional[t.Dict] = None,
@@ -2158,7 +2161,14 @@ def _check_ready_intervals(
21582161
batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch]
21592162

21602163
try:
2161-
ready_intervals = call_macro(check, dialect, path, batch, **(kwargs or {}))
2164+
ready_intervals = call_macro(
2165+
check,
2166+
dialect,
2167+
path,
2168+
provided_args=(batch,),
2169+
provided_kwargs=(kwargs or {}),
2170+
context=context,
2171+
)
21622172
except Exception:
21632173
raise SQLMeshError("Error evaluating signal")
21642174

sqlmesh/core/snapshot/evaluator.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def create(
321321

322322
def _get_data_objects(schema: exp.Table, gateway: t.Optional[str] = None) -> t.Set[str]:
323323
logger.info("Listing data objects in schema %s", schema.sql())
324-
objs = self._get_adapter(gateway).get_data_objects(schema, tables_by_schema[schema])
324+
objs = self.get_adapter(gateway).get_data_objects(schema, tables_by_schema[schema])
325325
return {obj.name for obj in objs}
326326

327327
with self.concurrent_context():
@@ -409,7 +409,7 @@ def migrate(
409409
s,
410410
snapshots,
411411
allow_destructive_snapshots,
412-
self._get_adapter(s.model_gateway),
412+
self.get_adapter(s.model_gateway),
413413
deployability_index,
414414
),
415415
self.ddl_concurrent_tasks,
@@ -437,7 +437,7 @@ def cleanup(
437437
lambda s: self._cleanup_snapshot(
438438
s,
439439
snapshots_to_dev_table_only[s.snapshot_id],
440-
self._get_adapter(
440+
self.get_adapter(
441441
snapshot_gateways.get(s.snapshot_id.name) if snapshot_gateways else None
442442
),
443443
on_complete,
@@ -471,7 +471,7 @@ def audit(
471471
kwargs: Additional kwargs to pass to the renderer.
472472
"""
473473
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
474-
adapter = self._get_adapter(snapshot.model_gateway)
474+
adapter = self.get_adapter(snapshot.model_gateway)
475475

476476
if not snapshot.version:
477477
raise ConfigError(
@@ -605,7 +605,7 @@ def _evaluate_snapshot(
605605
else snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
606606
)
607607

608-
adapter = self._get_adapter(model.gateway)
608+
adapter = self.get_adapter(model.gateway)
609609
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
610610

611611
# https://github.com/TobikoData/sqlmesh/issues/2609
@@ -764,7 +764,7 @@ def _create_snapshot(
764764

765765
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
766766

767-
adapter = self._get_adapter(snapshot.model.gateway)
767+
adapter = self.get_adapter(snapshot.model.gateway)
768768
create_render_kwargs: t.Dict[str, t.Any] = dict(
769769
engine_adapter=adapter,
770770
snapshots=parent_snapshots_by_name(snapshot, snapshots),
@@ -994,7 +994,7 @@ def _wap_publish_snapshot(
994994
) -> None:
995995
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
996996
table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
997-
adapter = self._get_adapter(snapshot.model_gateway)
997+
adapter = self.get_adapter(snapshot.model_gateway)
998998
adapter.wap_publish(table_name, wap_id)
999999

10001000
def _audit(
@@ -1021,7 +1021,7 @@ def _audit(
10211021
blocking = audit_args.pop("blocking", None)
10221022
blocking = blocking == exp.true() if blocking else audit.blocking
10231023

1024-
adapter = self._get_adapter(snapshot.model_gateway)
1024+
adapter = self.get_adapter(snapshot.model_gateway)
10251025

10261026
kwargs = {
10271027
"start": start,
@@ -1068,10 +1068,10 @@ def _create_schemas(
10681068
for schema_name, catalog in unique_schemas:
10691069
schema = schema_(schema_name, catalog)
10701070
logger.info("Creating schema '%s'", schema)
1071-
adapter = self._get_adapter(gateways.get(schema)) if gateways else self.adapter
1071+
adapter = self.get_adapter(gateways.get(schema)) if gateways else self.adapter
10721072
adapter.create_schema(schema)
10731073

1074-
def _get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
1074+
def get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
10751075
"""Returns the adapter for the specified gateway or the default adapter if none is provided."""
10761076
if gateway:
10771077
if adapter := self.adapters.get(gateway):
@@ -1089,7 +1089,7 @@ def _execute_create(
10891089
rendered_physical_properties: t.Dict[str, exp.Expression],
10901090
dry_run: bool,
10911091
) -> None:
1092-
adapter = self._get_adapter(snapshot.model.gateway)
1092+
adapter = self.get_adapter(snapshot.model.gateway)
10931093
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
10941094

10951095
# It can still be useful for some strategies to know if the snapshot was actually deployable

tests/core/test_scheduler.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlglot import parse_one, parse
66
from sqlglot.helper import first
77

8-
from sqlmesh.core.context import Context
8+
from sqlmesh.core.context import Context, ExecutionContext
99
from sqlmesh.core.environment import EnvironmentNamingInfo
1010
from sqlmesh.core.model import load_sql_based_model
1111
from sqlmesh.core.model.definition import AuditResult, SqlModel
@@ -66,17 +66,17 @@ def test_interval_params(scheduler: Scheduler, sushi_context_fixed_date: Context
6666

6767

6868
@pytest.fixture
69-
def get_batched_missing_intervals() -> (
70-
t.Callable[[Scheduler, TimeLike, TimeLike, t.Optional[TimeLike]], SnapshotToIntervals]
71-
):
69+
def get_batched_missing_intervals(
70+
mocker: MockerFixture,
71+
) -> t.Callable[[Scheduler, TimeLike, TimeLike, t.Optional[TimeLike]], SnapshotToIntervals]:
7272
def _get_batched_missing_intervals(
7373
scheduler: Scheduler,
7474
start: TimeLike,
7575
end: TimeLike,
7676
execution_time: t.Optional[TimeLike] = None,
7777
) -> SnapshotToIntervals:
7878
merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time)
79-
return scheduler.batch_intervals(merged_intervals)
79+
return scheduler.batch_intervals(merged_intervals, mocker.Mock())
8080

8181
return _get_batched_missing_intervals
8282

@@ -622,7 +622,9 @@ def test_interval_diff():
622622

623623
def test_signal_intervals(mocker: MockerFixture, make_snapshot, get_batched_missing_intervals):
624624
@signal()
625-
def signal_a(batch: DatetimeRanges):
625+
def signal_a(batch: DatetimeRanges, context: ExecutionContext):
626+
if not hasattr(context, "engine_adapter"):
627+
raise
626628
return [batch[0], batch[1]]
627629

628630
@signal()

tests/core/test_snapshot.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2499,23 +2499,23 @@ def test_contiguous_intervals():
24992499

25002500
def test_check_ready_intervals(mocker: MockerFixture):
25012501
def assert_always_signal(intervals):
2502-
assert _check_ready_intervals(lambda _: True, intervals) == intervals
2502+
assert _check_ready_intervals(lambda _: True, intervals, mocker.Mock()) == intervals
25032503

25042504
assert_always_signal([])
25052505
assert_always_signal([(0, 1)])
25062506
assert_always_signal([(0, 1), (1, 2)])
25072507
assert_always_signal([(0, 1), (2, 3)])
25082508

25092509
def assert_never_signal(intervals):
2510-
assert _check_ready_intervals(lambda _: False, intervals) == []
2510+
assert _check_ready_intervals(lambda _: False, intervals, mocker.Mock()) == []
25112511

25122512
assert_never_signal([])
25132513
assert_never_signal([(0, 1)])
25142514
assert_never_signal([(0, 1), (1, 2)])
25152515
assert_never_signal([(0, 1), (2, 3)])
25162516

25172517
def assert_empty_signal(intervals):
2518-
assert _check_ready_intervals(lambda _: [], intervals) == []
2518+
assert _check_ready_intervals(lambda _: [], intervals, mocker.Mock()) == []
25192519

25202520
assert_empty_signal([])
25212521
assert_empty_signal([(0, 1)])
@@ -2532,7 +2532,7 @@ def assert_check_intervals(
25322532
):
25332533
mock = mocker.Mock()
25342534
mock.side_effect = [to_intervals(r) for r in ready]
2535-
_check_ready_intervals(mock, intervals) == expected
2535+
_check_ready_intervals(mock, intervals, mocker.Mock()) == expected
25362536

25372537
assert_check_intervals([], [], [])
25382538
assert_check_intervals([(0, 1)], [[]], [])
@@ -2894,7 +2894,7 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot):
28942894
]
28952895

28962896

2897-
def test_render_signal(make_snapshot):
2897+
def test_render_signal(make_snapshot, mocker):
28982898
@signal()
28992899
def check_types(batch, env: str, default: int = 0):
29002900
if env != "in_memory" or not default == 0:
@@ -2917,4 +2917,4 @@ def check_types(batch, env: str, default: int = 0):
29172917
signal_definitions=signal.get_registry(),
29182918
)
29192919
snapshot_a = make_snapshot(sql_model)
2920-
assert snapshot_a.check_ready_intervals([(0, 1)]) == [(0, 1)]
2920+
assert snapshot_a.check_ready_intervals([(0, 1)], mocker.Mock()) == [(0, 1)]

tests/core/test_snapshot_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3751,9 +3751,9 @@ def test_multiple_engine_creation(snapshot: Snapshot, adapters, make_snapshot):
37513751

37523752
assert len(evaluator.adapters) == 3
37533753
assert evaluator.adapter == engine_adapters["default"]
3754-
assert evaluator._get_adapter() == engine_adapters["default"]
3755-
assert evaluator._get_adapter("third") == engine_adapters["third"]
3756-
assert evaluator._get_adapter("secondary") == engine_adapters["secondary"]
3754+
assert evaluator.get_adapter() == engine_adapters["default"]
3755+
assert evaluator.get_adapter("third") == engine_adapters["third"]
3756+
assert evaluator.get_adapter("secondary") == engine_adapters["secondary"]
37573757

37583758
model = load_sql_based_model(
37593759
parse( # type: ignore

0 commit comments

Comments
 (0)