Skip to content

Commit 87e572e

Browse files
authored
Fix: Make sure that the deployable table is used for pre-/post- statements during the snapshot table creation (#2282)
1 parent fc5a35b commit 87e572e

3 files changed

Lines changed: 64 additions & 12 deletions

File tree

sqlmesh/core/snapshot/definition.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,12 +1140,19 @@ def is_representative(self, snapshot: SnapshotIdLike) -> bool:
11401140

11411141
def with_non_deployable(self, snapshot: SnapshotIdLike) -> DeployabilityIndex:
11421142
"""Creates a new index with the given snapshot marked as non-deployable."""
1143-
snapshot_id = self._snapshot_id_key(snapshot.snapshot_id)
1143+
return self._add_snapshot(snapshot, False)
1144+
1145+
def with_deployable(self, snapshot: SnapshotIdLike) -> DeployabilityIndex:
1146+
"""Creates a new index with the given snapshot marked as deployable."""
1147+
return self._add_snapshot(snapshot, True)
1148+
1149+
def _add_snapshot(self, snapshot: SnapshotIdLike, deployable: bool) -> DeployabilityIndex:
1150+
snapshot_id = {self._snapshot_id_key(snapshot.snapshot_id)}
11441151
indexed_ids = self.indexed_ids
11451152
if self.is_opposite_index:
1146-
indexed_ids = indexed_ids | {snapshot_id}
1153+
indexed_ids = indexed_ids - snapshot_id if deployable else indexed_ids | snapshot_id
11471154
else:
1148-
indexed_ids = indexed_ids - {snapshot_id}
1155+
indexed_ids = indexed_ids | snapshot_id if deployable else indexed_ids - snapshot_id
11491156

11501157
return DeployabilityIndex(
11511158
indexed_ids=indexed_ids,

sqlmesh/core/snapshot/evaluator.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -598,20 +598,25 @@ def _create_snapshot(
598598
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
599599
is_snapshot_deployable = deployability_index.is_deployable(snapshot)
600600

601-
# Refers to self as non-deployable to successfully create self-referential tables / views.
602-
deployability_index = deployability_index.with_non_deployable(snapshot)
603-
604-
render_kwargs: t.Dict[str, t.Any] = dict(
601+
common_render_kwargs: t.Dict[str, t.Any] = dict(
605602
engine_adapter=self.adapter,
606603
snapshots=parent_snapshots_by_name,
607-
deployability_index=deployability_index,
608604
runtime_stage=RuntimeStage.CREATING,
609605
)
606+
pre_post_render_kwargs = dict(
607+
**common_render_kwargs,
608+
deployability_index=deployability_index.with_deployable(snapshot),
609+
)
610+
create_render_kwargs = dict(
611+
**common_render_kwargs,
612+
# Refers to self as non-deployable to successfully create self-referential tables / views.
613+
deployability_index=deployability_index.with_non_deployable(snapshot),
614+
)
610615

611616
evaluation_strategy = _evaluation_strategy(snapshot, self.adapter)
612617

613618
with self.adapter.transaction(), self.adapter.session(snapshot.model.session_properties):
614-
self.adapter.execute(snapshot.model.render_pre_statements(**render_kwargs))
619+
self.adapter.execute(snapshot.model.render_pre_statements(**pre_post_render_kwargs))
615620

616621
if (
617622
snapshot.is_forward_only
@@ -626,7 +631,7 @@ def _create_snapshot(
626631
logger.info(f"Cloning table '{source_table_name}' into '{target_table_name}'")
627632

628633
evaluation_strategy.create(
629-
snapshot, tmp_table_name, False, is_snapshot_deployable, **render_kwargs
634+
snapshot, tmp_table_name, False, is_snapshot_deployable, **create_render_kwargs
630635
)
631636
try:
632637
self.adapter.clone_table(target_table_name, snapshot.table_name(), replace=True)
@@ -643,10 +648,10 @@ def _create_snapshot(
643648
snapshot.table_name(is_deployable=is_table_deployable),
644649
is_table_deployable,
645650
is_snapshot_deployable,
646-
**render_kwargs,
651+
**create_render_kwargs,
647652
)
648653

649-
self.adapter.execute(snapshot.model.render_post_statements(**render_kwargs))
654+
self.adapter.execute(snapshot.model.render_post_statements(**pre_post_render_kwargs))
650655

651656
if on_complete is not None:
652657
on_complete(snapshot)

tests/core/test_snapshot_evaluator.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,3 +1583,43 @@ def test_audit_wap(adapter_mock, make_snapshot):
15831583

15841584
adapter_mock.wap_table_name.assert_called_once_with(snapshot.table_name(), wap_id)
15851585
adapter_mock.wap_publish.assert_called_once_with(snapshot.table_name(), wap_id)
1586+
1587+
1588+
def test_create_post_statements_use_deployable_table(
1589+
mocker: MockerFixture, adapter_mock, make_snapshot
1590+
):
1591+
evaluator = SnapshotEvaluator(adapter_mock)
1592+
1593+
model = load_sql_based_model(
1594+
parse( # type: ignore
1595+
"""
1596+
MODEL (
1597+
name test_schema.test_model,
1598+
kind FULL,
1599+
dialect postgres,
1600+
);
1601+
1602+
CREATE INDEX IF NOT EXISTS test_idx ON test_schema.test_model(a);
1603+
1604+
SELECT a::int FROM tbl;
1605+
1606+
CREATE INDEX IF NOT EXISTS test_idx ON test_schema.test_model(a);
1607+
"""
1608+
),
1609+
)
1610+
1611+
snapshot = make_snapshot(model)
1612+
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
1613+
1614+
expected_call = f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" /* test_schema.test_model */("a" NULLS FIRST)'
1615+
1616+
evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())
1617+
1618+
call_args = adapter_mock.execute.call_args_list
1619+
pre_calls = call_args[0][0][0]
1620+
assert len(pre_calls) == 1
1621+
assert pre_calls[0].sql(dialect="postgres") == expected_call
1622+
1623+
post_calls = call_args[1][0][0]
1624+
assert len(post_calls) == 1
1625+
assert post_calls[0].sql(dialect="postgres") == expected_call

0 commit comments

Comments
 (0)