Skip to content

Commit ff753e5

Browse files
authored
Fix!: Support snapshots of added models with forward-only parents (#1372)
1 parent 52f53fb commit ff753e5

File tree

5 files changed

+235
-35
lines changed

5 files changed

+235
-35
lines changed

sqlmesh/core/plan/definition.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -545,14 +545,16 @@ def _categorize_snapshots(self) -> None:
545545
returns a list of added and directly modified snapshots as well as the mapping of
546546
indirectly modified snapshots.
547547
"""
548-
for model_name, snapshot in self._snapshot_mapping.items():
549-
if snapshot.change_category:
548+
549+
# Iterating in DAG order since a category for a snapshot may depend on the categories
550+
# assigned to its upstream dependencies.
551+
for model_name in self._dag:
552+
snapshot = self._snapshot_mapping.get(model_name)
553+
if not snapshot or snapshot.change_category:
550554
continue
551-
upstream_model_names = self._dag.upstream(model_name)
552555

553556
if model_name in self.context_diff.modified_snapshots:
554557
is_directly_modified = self.context_diff.directly_modified(model_name)
555-
556558
if self.is_new_snapshot(snapshot):
557559
if self.forward_only:
558560
# In case of the forward only plan any modifications result in reuse of the
@@ -597,13 +599,17 @@ def _categorize_snapshots(self) -> None:
597599
and not any(
598600
self.context_diff.directly_modified(upstream)
599601
and not self._snapshot_mapping[upstream].version
600-
for upstream in upstream_model_names
602+
for upstream in self._dag.upstream(model_name)
601603
)
602604
):
603605
snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_BREAKING)
604606

605607
elif model_name in self.context_diff.added and self.is_new_snapshot(snapshot):
606-
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
608+
snapshot.categorize_as(
609+
SnapshotChangeCategory.FORWARD_ONLY
610+
if self._is_forward_only_model(model_name)
611+
else SnapshotChangeCategory.BREAKING
612+
)
607613

608614
def _ensure_valid_date_range(
609615
self, start: t.Optional[TimeLike], end: t.Optional[TimeLike]
@@ -636,12 +642,6 @@ def _ensure_no_forward_only_revert(self) -> None:
636642
"To proceed with the change, restamp this model's definition to produce a new version."
637643
)
638644

639-
def _ensure_no_forward_only_new_models(self) -> None:
640-
if self.forward_only and self.context_diff.added_materialized_models:
641-
raise PlanError(
642-
"New models that require materialization can't be added as part of the forward-only plan."
643-
)
644-
645645
def _ensure_no_broken_references(self) -> None:
646646
for snapshot in self.context_diff.snapshots.values():
647647
broken_references = self.context_diff.removed & snapshot.model.depends_on
@@ -683,7 +683,6 @@ def _refresh_dag_and_ignored_snapshots(self) -> None:
683683
self._ensure_new_env_with_changes()
684684
self._ensure_valid_date_range(self._start, self._end)
685685
self._ensure_no_forward_only_revert()
686-
self._ensure_no_forward_only_new_models()
687686
self._ensure_no_broken_references()
688687

689688
(
@@ -732,16 +731,29 @@ def _build_snapshots_and_dag(
732731
)
733732

734733
def _is_forward_only_model(self, model_name: str) -> bool:
735-
snapshot = self._snapshot_mapping.get(model_name)
736-
if snapshot and snapshot.model.forward_only:
734+
def _is_forward_only_expected(snapshot: Snapshot) -> bool:
735+
# Returns True if the snapshot is not categorized yet but is expected
736+
# to be categorized as forward-only. Checking the previous versions to make
737+
# sure that the snapshot doesn't represent a newly added model.
738+
return (
739+
snapshot.model.forward_only
740+
and not snapshot.change_category
741+
and bool(snapshot.previous_versions)
742+
)
743+
744+
snapshot = self._snapshot_mapping[model_name]
745+
if _is_forward_only_expected(snapshot):
737746
return True
738747

739748
for upstream in self._dag.upstream(model_name):
740749
upstream_snapshot = self._snapshot_mapping.get(upstream)
741750
if (
742751
upstream_snapshot
743752
and upstream_snapshot.is_paused
744-
and (upstream_snapshot.is_forward_only or upstream_snapshot.model.forward_only)
753+
and (
754+
upstream_snapshot.is_forward_only
755+
or _is_forward_only_expected(upstream_snapshot)
756+
)
745757
):
746758
return True
747759
return False

sqlmesh/core/snapshot/evaluator.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,19 @@ def _create_snapshot(
407407
evaluation_strategy = _evaluation_strategy(snapshot, self.adapter)
408408

409409
with self.adapter.transaction(TransactionType.DDL), self.adapter.session():
410+
if is_dev and not snapshot.previous_versions:
411+
# This is a FORWARD_ONLY snapshot which represents an added model. This means
412+
# that the physical table associated with it doesn't exist yet.
413+
logger.info(
414+
f"Detected a forward-only snapshot without previous versions: {snapshot.snapshot_id}"
415+
)
416+
non_dev_render_kwargs: t.Dict[str, t.Any] = {**render_kwargs, "is_dev": False}
417+
self.adapter.execute(snapshot.model.render_pre_statements(**non_dev_render_kwargs))
418+
evaluation_strategy.create(
419+
snapshot.model, snapshot.table_name(), **non_dev_render_kwargs
420+
)
421+
self.adapter.execute(snapshot.model.render_post_statements(**non_dev_render_kwargs))
422+
410423
self.adapter.execute(snapshot.model.render_pre_statements(**render_kwargs))
411424

412425
if is_dev and snapshot.is_materialized and self.adapter.SUPPORTS_CLONING:
@@ -480,9 +493,13 @@ def _demote_snapshot(
480493

481494
def _cleanup_snapshot(self, snapshot: SnapshotInfoLike) -> None:
482495
snapshot = snapshot.table_info
483-
table_names = [snapshot.table_name()]
484-
if snapshot.version != snapshot.fingerprint.to_version():
485-
table_names.append(snapshot.table_name(is_dev=True))
496+
497+
table_name = snapshot.table_name()
498+
dev_table_name = snapshot.table_name(is_dev=True)
499+
500+
table_names = [table_name]
501+
if table_name != dev_table_name:
502+
table_names.append(dev_table_name)
486503

487504
evaluation_strategy = _evaluation_strategy(snapshot, self.adapter)
488505

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Fix snapshots of added models with forward only parents."""
2+
import json
3+
import typing as t
4+
5+
from sqlglot import exp
6+
7+
from sqlmesh.utils.dag import DAG
8+
9+
10+
def migrate(state_sync: t.Any) -> None:
11+
engine_adapter = state_sync.engine_adapter
12+
schema = state_sync.schema
13+
snapshots_table = "_snapshots"
14+
environments_table = "_environments"
15+
if schema:
16+
snapshots_table = f"{schema}.{snapshots_table}"
17+
environments_table = f"{schema}.{environments_table}"
18+
19+
dag: DAG[t.Tuple[str, str]] = DAG()
20+
snapshot_mapping: t.Dict[t.Tuple[str, str], t.Dict[str, t.Any]] = {}
21+
22+
for identifier, snapshot in engine_adapter.fetchall(
23+
exp.select("identifier", "snapshot").from_(snapshots_table),
24+
quote_identifiers=True,
25+
):
26+
parsed_snapshot = json.loads(snapshot)
27+
28+
snapshot_id = (parsed_snapshot["name"], identifier)
29+
snapshot_mapping[snapshot_id] = parsed_snapshot
30+
31+
parent_ids = [
32+
(parent["name"], parent["identifier"]) for parent in parsed_snapshot["parents"]
33+
]
34+
dag.add(snapshot_id, parent_ids)
35+
36+
snapshots_to_delete = set()
37+
38+
for snapshot_id in dag:
39+
parsed_snapshot = snapshot_mapping[snapshot_id]
40+
is_breaking = parsed_snapshot.get("change_category") == 1
41+
has_previous_versions = bool(parsed_snapshot.get("previous_versions", []))
42+
43+
has_paused_forward_only_parent = False
44+
if is_breaking and not has_previous_versions:
45+
for upstream_id in dag.upstream(snapshot_id):
46+
upstream_snapshot = snapshot_mapping[upstream_id]
47+
upstream_change_category = upstream_snapshot.get("change_category")
48+
is_forward_only_upstream = upstream_change_category == 3
49+
if is_forward_only_upstream and not upstream_snapshot.get("unpaused_ts"):
50+
has_paused_forward_only_parent = True
51+
break
52+
53+
if has_paused_forward_only_parent:
54+
snapshots_to_delete.add(snapshot_id)
55+
56+
if snapshots_to_delete:
57+
where = t.cast(exp.Tuple, exp.convert((exp.column("name"), exp.column("identifier")))).isin(
58+
*snapshots_to_delete
59+
)
60+
engine_adapter.delete_from(snapshots_table, where)

tests/core/test_plan.py

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -90,26 +90,27 @@ def test_forward_only_dev(make_snapshot, mocker: MockerFixture):
9090
now_ds_mock.call_count == 2
9191

9292

93-
def test_forward_only_plan_new_models_not_allowed(make_snapshot, mocker: MockerFixture):
94-
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds")))
95-
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)
93+
def test_forward_only_plan_added_models(make_snapshot, mocker: MockerFixture):
94+
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1 as a, ds")))
95+
96+
snapshot_b = make_snapshot(
97+
SqlModel(name="b", query=parse_one("select a, ds from a")), nodes={"a": snapshot_a.node}
98+
)
9699

97100
context_diff_mock = mocker.Mock()
98-
context_diff_mock.snapshots = {"a": snapshot_a}
99-
context_diff_mock.added = {"a"}
101+
context_diff_mock.snapshots = {"a": snapshot_a, "b": snapshot_b}
102+
context_diff_mock.added = {"b"}
100103
context_diff_mock.removed = set()
101-
context_diff_mock.modified_snapshots = {}
102-
context_diff_mock.new_snapshots = {}
103-
context_diff_mock.added_materialized_models = {"a"}
104-
105-
with pytest.raises(
106-
PlanError,
107-
match="New models that require materialization can't be added as part of the forward-only plan.",
108-
):
109-
Plan(context_diff_mock, forward_only=True)
104+
context_diff_mock.modified_snapshots = {"a": (snapshot_a, snapshot_a)}
105+
context_diff_mock.new_snapshots = {
106+
snapshot_a.snapshot_id: snapshot_a,
107+
snapshot_b.snapshot_id: snapshot_b,
108+
}
109+
context_diff_mock.added_materialized_models = {"b"}
110110

111-
context_diff_mock.added_materialized_models = set()
112111
Plan(context_diff_mock, forward_only=True)
112+
assert snapshot_a.change_category == SnapshotChangeCategory.FORWARD_ONLY
113+
assert snapshot_b.change_category == SnapshotChangeCategory.FORWARD_ONLY
113114

114115

115116
def test_paused_forward_only_parent(make_snapshot, mocker: MockerFixture):
@@ -577,6 +578,7 @@ def test_forward_only_models(make_snapshot, mocker: MockerFixture):
577578
kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True),
578579
)
579580
)
581+
updated_snapshot.previous_versions = snapshot.all_versions
580582

581583
context_diff_mock = mocker.Mock()
582584
context_diff_mock.snapshots = {"a": updated_snapshot}
@@ -607,6 +609,7 @@ def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFix
607609
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1 as a, ds")))
608610
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)
609611
updated_snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 2 as a, ds")))
612+
updated_snapshot_a.previous_versions = snapshot_a.all_versions
610613

611614
snapshot_b = make_snapshot(
612615
SqlModel(
@@ -618,12 +621,14 @@ def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFix
618621
)
619622
snapshot_b.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
620623
updated_snapshot_b = make_snapshot(snapshot_b.model, nodes={"a": updated_snapshot_a.model})
624+
updated_snapshot_b.previous_versions = snapshot_b.all_versions
621625

622626
snapshot_c = make_snapshot(
623627
SqlModel(name="c", query=parse_one("select a, ds from b")), nodes={"b": snapshot_b.model}
624628
)
625629
snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING)
626630
updated_snapshot_c = make_snapshot(snapshot_c.model, nodes={"b": updated_snapshot_b.model})
631+
updated_snapshot_c.previous_versions = snapshot_c.all_versions
627632

628633
context_diff_mock = mocker.Mock()
629634
context_diff_mock.snapshots = {
@@ -660,6 +665,65 @@ def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFix
660665
assert updated_snapshot_c.change_category == SnapshotChangeCategory.FORWARD_ONLY
661666

662667

668+
def test_added_model_with_forward_only_parent(make_snapshot, mocker: MockerFixture):
669+
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1 as a, ds")))
670+
snapshot_a.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
671+
672+
snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("select a, ds from a")))
673+
674+
context_diff_mock = mocker.Mock()
675+
context_diff_mock.snapshots = {
676+
"a": snapshot_a,
677+
"b": snapshot_b,
678+
}
679+
context_diff_mock.removed = set()
680+
context_diff_mock.added = {"b"}
681+
context_diff_mock.added_materialized_models = set()
682+
context_diff_mock.modified_snapshots = {}
683+
context_diff_mock.new_snapshots = {
684+
snapshot_b.snapshot_id: snapshot_b,
685+
}
686+
context_diff_mock.has_snapshot_changes = True
687+
context_diff_mock.environment = "test_dev"
688+
context_diff_mock.previous_plan_id = "previous_plan_id"
689+
690+
Plan(context_diff_mock)
691+
assert snapshot_b.change_category == SnapshotChangeCategory.FORWARD_ONLY
692+
693+
694+
def test_added_forward_only_model(make_snapshot, mocker: MockerFixture):
695+
snapshot_a = make_snapshot(
696+
SqlModel(
697+
name="a",
698+
query=parse_one("select 1 as a, ds"),
699+
kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True),
700+
)
701+
)
702+
703+
snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("select a, ds from a")))
704+
705+
context_diff_mock = mocker.Mock()
706+
context_diff_mock.snapshots = {
707+
"a": snapshot_a,
708+
"b": snapshot_b,
709+
}
710+
context_diff_mock.removed = set()
711+
context_diff_mock.added = {"a", "b"}
712+
context_diff_mock.added_materialized_models = set()
713+
context_diff_mock.modified_snapshots = {}
714+
context_diff_mock.new_snapshots = {
715+
snapshot_a.snapshot_id: snapshot_a,
716+
snapshot_b.snapshot_id: snapshot_b,
717+
}
718+
context_diff_mock.has_snapshot_changes = True
719+
context_diff_mock.environment = "test_dev"
720+
context_diff_mock.previous_plan_id = "previous_plan_id"
721+
722+
Plan(context_diff_mock)
723+
assert snapshot_a.change_category == SnapshotChangeCategory.BREAKING
724+
assert snapshot_b.change_category == SnapshotChangeCategory.BREAKING
725+
726+
663727
def test_disable_restatement(make_snapshot, mocker: MockerFixture):
664728
snapshot = make_snapshot(
665729
SqlModel(

0 commit comments

Comments
 (0)