Skip to content

Commit a17877b

Browse files
committed
Feat: Improve CLI and --explain output for restatements
1 parent 6eaa6eb commit a17877b

File tree

5 files changed

+106
-14
lines changed

5 files changed

+106
-14
lines changed

sqlmesh/core/console.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2022,7 +2022,34 @@ def _prompt_categorize(
20222022
plan = plan_builder.build()
20232023

20242024
if plan.restatements:
2025-
self._print("\n[bold]Restating models\n")
2025+
# A plan can have restatements for the following reasons:
2026+
# - The user specifically called `sqlmesh plan` with --restate-model.
2027+
# This creates a "restatement plan" which disallows all other changes and simply force-backfills
2028+
# the selected models and their downstream dependencies using the versions of the models stored in state.
2029+
# - There are no specific restatements (so changes are allowed) AND dev previews need to be computed.
2030+
# The "restatements" feature is currently reused for dev previews.
2031+
if plan.selected_models_to_restate:
2032+
# There were legitimate restatements, no dev previews
2033+
tree = Tree(
2034+
"[bold]Models selected for restatement[/bold]\n"
2035+
"This causes backfill of the model itself as well as affected downstream models"
2036+
)
2037+
model_fqn_to_snapshot = {s.name: s for s in plan.snapshots.values()}
2038+
for model_fqn in plan.selected_models_to_restate:
2039+
snapshot = model_fqn_to_snapshot[model_fqn]
2040+
display_name = snapshot.display_name(
2041+
plan.environment_naming_info,
2042+
default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
2043+
dialect=self.dialect,
2044+
)
2045+
tree.add(
2046+
display_name
2047+
) # note: we deliberately dont show any intervals here; they get shown in the backfill section
2048+
self._print(tree)
2049+
else:
2050+
# We are computing dev previews, do not confuse the user by printing out something to do
2051+
# with restatements. Dev previews are already highlighted in the backfill step
2052+
pass
20262053
else:
20272054
self.show_environment_difference_summary(
20282055
plan.context_diff,

sqlmesh/core/plan/builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def build(self) -> Plan:
338338
directly_modified=directly_modified,
339339
indirectly_modified=indirectly_modified,
340340
deployability_index=deployability_index,
341+
selected_models_to_restate=self._restate_models,
341342
restatements=restatements,
342343
start_override_per_model=self._start_override_per_model,
343344
end_override_per_model=end_override_per_model,

sqlmesh/core/plan/definition.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,16 @@ class Plan(PydanticModel, frozen=True):
5858
indirectly_modified: t.Dict[SnapshotId, t.Set[SnapshotId]]
5959

6060
deployability_index: DeployabilityIndex
61+
selected_models_to_restate: t.Optional[t.Set[str]] = None
62+
"""Models that have been explicitly selected for restatement by a user"""
6163
restatements: t.Dict[SnapshotId, Interval]
64+
"""
65+
All models being restated, which are typically the explicitly selected ones + their downstream dependencies.
66+
67+
Note that dev previews are also considered restatements, so :selected_models_to_restate can be empty
68+
while :restatements is still populated with dev previews
69+
"""
70+
6271
start_override_per_model: t.Optional[t.Dict[str, datetime]]
6372
end_override_per_model: t.Optional[t.Dict[str, datetime]]
6473

sqlmesh/core/plan/explainer.py

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import typing as t
55
import logging
66
from dataclasses import dataclass
7+
from collections import defaultdict
78

89
from rich.console import Console as RichConsole
910
from rich.tree import Tree
@@ -22,7 +23,11 @@
2223
PlanEvaluator,
2324
)
2425
from sqlmesh.core.state_sync import StateReader
25-
from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotNameVersionLike
26+
from sqlmesh.core.snapshot.definition import (
27+
Snapshot,
28+
SnapshotInfoMixin,
29+
SnapshotNameVersionLike,
30+
)
2631
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
2732
from sqlmesh.utils.date import to_ts
2833
from sqlmesh.utils.errors import SQLMeshError
@@ -76,8 +81,10 @@ class ExplainableRestatementStage(stages.RestatementStage):
7681
of what might happen when they ask for the plan to be explained
7782
"""
7883

79-
snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest]
80-
"""Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name"""
84+
snapshot_intervals_to_clear: t.Dict[
85+
str, t.List[t.Tuple[Snapshot, SnapshotIntervalClearRequest]]
86+
]
87+
"""Which snapshots from other environments would have intervals cleared as part of restatement, grouped by name."""
8188

8289
@classmethod
8390
def from_restatement_stage(
@@ -86,17 +93,30 @@ def from_restatement_stage(
8693
state_reader: StateReader,
8794
plan: EvaluatablePlan,
8895
) -> ExplainableRestatementStage:
96+
loaded_snapshots = {s.snapshot_id: s for s in stage.all_snapshots.values()}
97+
8998
all_restatement_intervals = identify_restatement_intervals_across_snapshot_versions(
9099
state_reader=state_reader,
91100
prod_restatements=plan.restatements,
92101
disable_restatement_models=plan.disabled_restatement_models,
93-
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
102+
loaded_snapshots=loaded_snapshots,
94103
)
95104

105+
# extend loaded_snapshots with the remaining full Snapshot objects from all_restatement_intervals
106+
# so that we can generate physical table names for them while explaining what's going on
107+
remaining_snapshot_ids_to_load = set(all_restatement_intervals).difference(loaded_snapshots)
108+
loaded_snapshots.update(
109+
state_reader.get_snapshots(snapshot_ids=remaining_snapshot_ids_to_load)
110+
)
111+
112+
snapshot_intervals_to_clear = defaultdict(list)
113+
for snapshot_id, clear_request in all_restatement_intervals.items():
114+
snapshot_intervals_to_clear[clear_request.snapshot.name].append(
115+
(loaded_snapshots[snapshot_id], clear_request)
116+
)
117+
96118
return cls(
97-
snapshot_intervals_to_clear={
98-
s.snapshot.name: s for s in all_restatement_intervals.values()
99-
},
119+
snapshot_intervals_to_clear=snapshot_intervals_to_clear,
100120
all_snapshots=stage.all_snapshots,
101121
)
102122

@@ -199,15 +219,50 @@ def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage
199219
def visit_restatement_stage(
200220
self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage]
201221
) -> Tree:
202-
tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]")
222+
tree = Tree(
223+
"[bold]Invalidate data intervals in state for development environments to prevent old data from being promoted[/bold]\n"
224+
"This only affects state and will not clear physical data from the tables until the next plan for each environment"
225+
)
203226

204227
if isinstance(stage, ExplainableRestatementStage) and (
205228
snapshot_intervals := stage.snapshot_intervals_to_clear
206229
):
207-
for clear_request in snapshot_intervals.values():
208-
display_name = self._display_name(clear_request.snapshot)
209-
interval = clear_request.interval
210-
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
230+
for name, requests in snapshot_intervals.items():
231+
display_name = model_display_name(
232+
name, self.environment_naming_info, self.default_catalog, self.dialect
233+
)
234+
235+
# group by environment for the console output
236+
by_environment: t.Dict[t.Optional[str], t.List[Snapshot]] = defaultdict(list)
237+
238+
interval_start = None
239+
interval_end = None
240+
241+
for snapshot, clear_request in requests:
242+
# used for the top level tree node
243+
interval_start, interval_end = clear_request.interval
244+
245+
if clear_request.sorted_environment_names:
246+
# snapshot is promoted in these environments
247+
for env in clear_request.sorted_environment_names:
248+
by_environment[env].append(snapshot)
249+
else:
250+
# snapshot is not currently promoted in any environment
251+
by_environment[None].append(snapshot)
252+
253+
if not interval_start or not interval_end:
254+
continue
255+
256+
node = tree.add(f"{display_name} [{to_ts(interval_start)} - {to_ts(interval_end)}]")
257+
258+
for env_name, snapshots_to_clear in by_environment.items():
259+
env_name = env_name or "(no env)"
260+
for snapshot in snapshots_to_clear:
261+
# note: we dont need a DeployabilityIndex and can just hardcode is_deployable=True.
262+
# The reason is that non-deployable data can never be restated so we only need to
263+
# bother clearing intervals for the deployable version of the table
264+
physical_table_name = snapshot.table_name(True)
265+
node.add(f"{env_name} -> {physical_table_name}")
211266

212267
return tree
213268

tests/cli/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_plan_restate_model(runner, tmp_path):
247247
)
248248
assert result.exit_code == 0
249249
assert_duckdb_test(result)
250-
assert "Restating models" in result.output
250+
assert "Models selected for restatement" in result.output
251251
assert "sqlmesh_example.full_model [full refresh" in result.output
252252
assert_model_batches_executed(result)
253253
assert "Virtual layer updated" not in result.output

0 commit comments

Comments
 (0)