Skip to content

Commit 907dc44

Browse files
committed
Feat: prevent other processes seeing missing intervals during restatement
1 parent 81bac93 commit 907dc44

File tree

6 files changed

+960
-53
lines changed

6 files changed

+960
-53
lines changed

sqlmesh/core/console.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,23 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
551551
def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
552552
"""Display list of models that failed during evaluation to the user."""
553553

554+
@abc.abstractmethod
555+
def log_models_updated_during_restatement(
556+
self,
557+
snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
558+
environment: EnvironmentSummary,
559+
environment_naming_info: EnvironmentNamingInfo,
560+
default_catalog: t.Optional[str],
561+
) -> None:
562+
"""Display a list of models where new versions got deployed to the specified :environment while we were restating data the old versions
563+
564+
Args:
565+
snapshots: a list of (snapshot_we_restated, snapshot_it_got_replaced_with_during_restatement) tuples
566+
environment: which environment got updated while we were restating models
567+
environment_naming_info: how snapshots are named in that :environment (for display name purposes)
568+
default_catalog: the configured default catalog (for display name purposes)
569+
"""
570+
554571
@abc.abstractmethod
555572
def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
556573
"""Starts loading and returns a unique ID that can be used to stop the loading. Optionally can display a message."""
@@ -771,6 +788,15 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
771788
def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
772789
pass
773790

791+
def log_models_updated_during_restatement(
792+
self,
793+
snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
794+
environment: EnvironmentSummary,
795+
environment_naming_info: EnvironmentNamingInfo,
796+
default_catalog: t.Optional[str],
797+
) -> None:
798+
pass
799+
774800
def log_destructive_change(
775801
self,
776802
snapshot_name: str,
@@ -2225,6 +2251,37 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
22252251
for node_name, msg in error_messages.items():
22262252
self._print(f" [red]{node_name}[/red]\n\n{msg}")
22272253

2254+
def log_models_updated_during_restatement(
2255+
self,
2256+
snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
2257+
environment: EnvironmentSummary,
2258+
environment_naming_info: EnvironmentNamingInfo,
2259+
default_catalog: t.Optional[str] = None,
2260+
) -> None:
2261+
if snapshots:
2262+
tree = Tree(
2263+
f"[yellow]The following models had new versions deployed in plan '{environment.plan_id}' while data was being restated:[/yellow]"
2264+
)
2265+
2266+
for restated_snapshot, updated_snapshot in snapshots:
2267+
display_name = restated_snapshot.display_name(
2268+
environment_naming_info,
2269+
default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
2270+
dialect=self.dialect,
2271+
)
2272+
current_branch = tree.add(display_name)
2273+
current_branch.add(f"restated version: '{restated_snapshot.version}'")
2274+
current_branch.add(f"currently active version: '{updated_snapshot.version}'")
2275+
2276+
self._print(tree)
2277+
2278+
self.log_warning(
2279+
f"\nThe '{environment.name}' environment currently points to [bold]different[/bold] versions of these models, not the versions that just got restated."
2280+
)
2281+
self._print(
2282+
"[yellow]If this is undesirable, please re-run this restatement plan which will apply it to the most recent versions of these models.[/yellow]\n"
2283+
)
2284+
22282285
def log_destructive_change(
22292286
self,
22302287
snapshot_name: str,

sqlmesh/core/plan/evaluator.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sqlmesh.core.console import Console, get_console
2323
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
2424
from sqlmesh.core.macros import RuntimeStage
25-
from sqlmesh.core.snapshot.definition import to_view_mapping
25+
from sqlmesh.core.snapshot.definition import to_view_mapping, SnapshotTableInfo
2626
from sqlmesh.core.plan import stages
2727
from sqlmesh.core.plan.definition import EvaluatablePlan
2828
from sqlmesh.core.scheduler import Scheduler
@@ -284,32 +284,62 @@ def visit_audit_only_run_stage(
284284
def visit_restatement_stage(
285285
self, stage: stages.RestatementStage, plan: EvaluatablePlan
286286
) -> None:
287-
snapshot_intervals_to_restate = {(s, i) for s, i in stage.snapshot_intervals.items()}
288-
289-
# Restating intervals on prod plans should mean that the intervals are cleared across
290-
# all environments, not just the version currently in prod
291-
# This ensures that work done in dev environments can still be promoted to prod
292-
# by forcing dev environments to re-run intervals that changed in prod
287+
# Restating intervals on prod plans means that once the data for the intervals being restated has been backfilled
288+
# (which happens in the backfill stage) then we need to clear those intervals *from state* across all other environments.
289+
#
290+
# This ensures that work done in dev environments can still be promoted to prod by forcing dev environments to
291+
# re-run intervals that changed in prod (because after this stage runs they are cleared from state and thus show as missing)
292+
#
293+
# It also means that any new dev environments created while this restatement plan was running also get the
294+
# correct intervals cleared because we look up matching snapshots as at right now and not as at the time the plan
295+
# was created, which could have been several hours ago if there was a lot of data to restate.
293296
#
294297
# Without this rule, it's possible that promoting a dev table to prod will introduce old data to prod
295-
snapshot_intervals_to_restate.update(
296-
{
297-
(s.table_info, s.interval)
298-
for s in identify_restatement_intervals_across_snapshot_versions(
299-
state_reader=self.state_sync,
300-
prod_restatements=plan.restatements,
301-
disable_restatement_models=plan.disabled_restatement_models,
302-
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
303-
current_ts=to_timestamp(plan.execution_time or now()),
304-
).values()
305-
}
298+
299+
intervals_to_clear = identify_restatement_intervals_across_snapshot_versions(
300+
state_reader=self.state_sync,
301+
prod_restatements=plan.restatements,
302+
disable_restatement_models=plan.disabled_restatement_models,
303+
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
304+
current_ts=to_timestamp(plan.execution_time or now()),
306305
)
307306

307+
if not intervals_to_clear:
308+
# Nothing to do
309+
return
310+
308311
self.state_sync.remove_intervals(
309-
snapshot_intervals=list(snapshot_intervals_to_restate),
312+
snapshot_intervals=[(s.table_info, s.interval) for s in intervals_to_clear.values()],
310313
remove_shared_versions=plan.is_prod,
311314
)
312315

316+
# While the restatements were being processed, did any of the snapshots being restated get new versions deployed?
317+
# If they did, they will not reflect the data that just got restated, so we need to notify the user
318+
if deployed_env := self.state_sync.get_environment(plan.environment.name):
319+
promoted_snapshots_by_name = {s.name: s for s in deployed_env.snapshots}
320+
321+
deployed_during_restatement: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]] = []
322+
323+
for name in plan.restatements:
324+
snapshot = stage.all_snapshots[name]
325+
version = snapshot.table_info.version
326+
if (
327+
prod_snapshot := promoted_snapshots_by_name.get(name)
328+
) and prod_snapshot.version != version:
329+
deployed_during_restatement.append(
330+
(snapshot.table_info, prod_snapshot.table_info)
331+
)
332+
333+
if deployed_during_restatement:
334+
self.console.log_models_updated_during_restatement(
335+
deployed_during_restatement,
336+
deployed_env.summary,
337+
plan.environment.naming_info,
338+
self.default_catalog,
339+
)
340+
# note: the plan will automatically fail at the promotion stage with a ConflictingPlanError because the environment was changed by another plan
341+
# so there is no need to explicitly fail the plan here
342+
313343
def visit_environment_record_update_stage(
314344
self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan
315345
) -> None:

sqlmesh/core/plan/explainer.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
1+
from __future__ import annotations
2+
13
import abc
24
import typing as t
35
import logging
6+
from dataclasses import dataclass
47

58
from rich.console import Console as RichConsole
69
from rich.tree import Tree
710
from sqlglot.dialects.dialect import DialectType
811
from sqlmesh.core import constants as c
912
from sqlmesh.core.console import Console, TerminalConsole, get_console
1013
from sqlmesh.core.environment import EnvironmentNamingInfo
14+
from sqlmesh.core.snapshot.definition import DeployabilityIndex
15+
from sqlmesh.core.plan.common import (
16+
SnapshotIntervalClearRequest,
17+
identify_restatement_intervals_across_snapshot_versions,
18+
)
1119
from sqlmesh.core.plan.definition import EvaluatablePlan, SnapshotIntervals
1220
from sqlmesh.core.plan import stages
1321
from sqlmesh.core.plan.evaluator import (
@@ -45,6 +53,15 @@ def evaluate(
4553
explainer_console = _get_explainer_console(
4654
self.console, plan.environment, self.default_catalog
4755
)
56+
57+
# add extra metadata that's only needed at this point for better --explain output
58+
plan_stages = [
59+
ExplainableRestatementStage.from_restatement_stage(stage, self.state_reader, plan)
60+
if isinstance(stage, stages.RestatementStage)
61+
else stage
62+
for stage in plan_stages
63+
]
64+
4865
explainer_console.explain(plan_stages)
4966

5067

@@ -54,6 +71,61 @@ def explain(self, stages: t.List[stages.PlanStage]) -> None:
5471
pass
5572

5673

74+
@dataclass
75+
class ExplainableRestatementStage(stages.RestatementStage):
76+
"""
77+
This brings forward some calculations that would usually be done in the evaluator so the user can be given a better indication
78+
of what might happen when they ask for the plan to be explained
79+
"""
80+
81+
snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest]
82+
"""Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name"""
83+
84+
deployability_index: DeployabilityIndex
85+
"""Deployability of those snapshots (which arent necessarily present in the current plan so we cant use the
86+
plan deployability index), used for outputting physical table names"""
87+
88+
@classmethod
89+
def from_restatement_stage(
90+
cls: t.Type[ExplainableRestatementStage],
91+
stage: stages.RestatementStage,
92+
state_reader: StateReader,
93+
plan: EvaluatablePlan,
94+
) -> ExplainableRestatementStage:
95+
all_restatement_intervals = identify_restatement_intervals_across_snapshot_versions(
96+
state_reader=state_reader,
97+
prod_restatements=plan.restatements,
98+
disable_restatement_models=plan.disabled_restatement_models,
99+
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
100+
)
101+
102+
snapshot_intervals_to_clear = {}
103+
deployability_index = DeployabilityIndex.all_deployable()
104+
105+
if all_restatement_intervals:
106+
snapshot_intervals_to_clear = {
107+
s_id.name: r for s_id, r in all_restatement_intervals.items()
108+
}
109+
110+
# creating a deployability index over the "snapshot intervals to clear"
111+
# allows us to print the physical names of the tables affected in the console output
112+
# note that we can't use the DeployabilityIndex on the plan because it only includes
113+
# snapshots for the current environment, not across all environments
114+
deployability_index = DeployabilityIndex.create(
115+
snapshots=state_reader.get_snapshots(
116+
[s.snapshot_id for s in snapshot_intervals_to_clear.values()]
117+
),
118+
start=plan.start,
119+
start_override_per_model=plan.start_override_per_model,
120+
)
121+
122+
return cls(
123+
snapshot_intervals_to_clear=snapshot_intervals_to_clear,
124+
deployability_index=deployability_index,
125+
all_snapshots=stage.all_snapshots,
126+
)
127+
128+
57129
MAX_TREE_LENGTH = 10
58130

59131

@@ -146,11 +218,22 @@ def visit_audit_only_run_stage(self, stage: stages.AuditOnlyRunStage) -> Tree:
146218
tree.add(display_name)
147219
return tree
148220

149-
def visit_restatement_stage(self, stage: stages.RestatementStage) -> Tree:
221+
def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage) -> Tree:
222+
return self.visit_restatement_stage(stage)
223+
224+
def visit_restatement_stage(
225+
self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage]
226+
) -> Tree:
150227
tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]")
151-
for snapshot_table_info, interval in stage.snapshot_intervals.items():
152-
display_name = self._display_name(snapshot_table_info)
153-
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
228+
229+
if isinstance(stage, ExplainableRestatementStage) and (
230+
snapshot_intervals := stage.snapshot_intervals_to_clear
231+
):
232+
for clear_request in snapshot_intervals.values():
233+
display_name = self._display_name(clear_request.table_info)
234+
interval = clear_request.interval
235+
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
236+
154237
return tree
155238

156239
def visit_backfill_stage(self, stage: stages.BackfillStage) -> Tree:

sqlmesh/core/plan/stages.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
Snapshot,
1313
SnapshotTableInfo,
1414
SnapshotId,
15-
Interval,
1615
)
1716

1817

@@ -98,14 +97,19 @@ class AuditOnlyRunStage:
9897

9998
@dataclass
10099
class RestatementStage:
101-
"""Restate intervals for given snapshots.
100+
"""Clear intervals from state for snapshots in *other* environments, when restatements are requested in prod.
101+
102+
This stage is effectively a "marker" stage to trigger the plan evaluator to perform the "clear intervals" logic after the BackfillStage has completed.
103+
The "clear intervals" logic is executed just-in-time using the latest state available in order to pick up new snapshots that may have
104+
been created while the BackfillStage was running, which is why we do not build a list of snapshots to clear at plan time and defer to evaluation time.
105+
106+
Note that this stage is only present on `prod` plans because dev plans do not need to worry about clearing intervals in other environments.
102107
103108
Args:
104-
snapshot_intervals: Intervals to restate.
105-
all_snapshots: All snapshots in the plan by name.
109+
all_snapshots: All snapshots in the plan by name. Note that this does not include the snapshots from other environments that will get their
110+
intervals cleared, it's included here as an optimization to prevent having to re-fetch the current plan's snapshots
106111
"""
107112

108-
snapshot_intervals: t.Dict[SnapshotTableInfo, Interval]
109113
all_snapshots: t.Dict[str, Snapshot]
110114

111115

@@ -321,10 +325,6 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
321325
if audit_only_snapshots:
322326
stages.append(AuditOnlyRunStage(snapshots=list(audit_only_snapshots.values())))
323327

324-
restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
325-
if restatement_stage:
326-
stages.append(restatement_stage)
327-
328328
if missing_intervals_before_promote:
329329
stages.append(
330330
BackfillStage(
@@ -349,6 +349,15 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
349349
)
350350
)
351351

352+
# note: "restatement stage" (which is clearing intervals in state - not actually performing the restatements, that's the backfill stage)
353+
# needs to come *after* the backfill stage so that at no time do other plans / runs see empty prod intervals and compete with this plan to try to fill them.
354+
# in addition, when we update intervals in state, we only clear intervals from dev snapshots to force dev models to be backfilled based on the new prod data.
355+
# we can leave prod intervals alone because by the time this plan finishes, the intervals in state have not actually changed, since restatement replaces
356+
# data for existing intervals and does not produce new ones
357+
restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
358+
if restatement_stage:
359+
stages.append(restatement_stage)
360+
352361
stages.append(
353362
EnvironmentRecordUpdateStage(
354363
no_gaps_snapshot_names={s.name for s in before_promote_snapshots}
@@ -443,15 +452,12 @@ def _get_after_all_stage(
443452
def _get_restatement_stage(
444453
self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]
445454
) -> t.Optional[RestatementStage]:
446-
snapshot_intervals_to_restate = {}
447-
for name, interval in plan.restatements.items():
448-
restated_snapshot = snapshots_by_name[name]
449-
restated_snapshot.remove_interval(interval)
450-
snapshot_intervals_to_restate[restated_snapshot.table_info] = interval
451-
if not snapshot_intervals_to_restate or plan.is_dev:
455+
if not plan.restatements or plan.is_dev:
456+
# The RestatementStage to clear intervals from state across all environments is not needed for plans against dev, only prod
452457
return None
458+
453459
return RestatementStage(
454-
snapshot_intervals=snapshot_intervals_to_restate, all_snapshots=snapshots_by_name
460+
all_snapshots=snapshots_by_name,
455461
)
456462

457463
def _get_physical_layer_update_stage(

0 commit comments

Comments
 (0)