Skip to content

Commit 203b74e

Browse files
authored
Fix: When restating prod, clear intervals across all related snapshots, not just promoted ones (#5274)
1 parent 68fa7bc commit 203b74e

File tree

11 files changed

+447
-116
lines changed

11 files changed

+447
-116
lines changed

sqlmesh/core/plan/common.py

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from __future__ import annotations
2+
import typing as t
3+
import logging
4+
from dataclasses import dataclass, field
25

3-
from sqlmesh.core.snapshot import Snapshot
6+
from sqlmesh.core.state_sync import StateReader
7+
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotIdAndVersion, SnapshotNameVersion
8+
from sqlmesh.core.snapshot.definition import Interval
9+
from sqlmesh.utils.dag import DAG
10+
from sqlmesh.utils.date import now_timestamp
11+
12+
logger = logging.getLogger(__name__)
413

514

615
def should_force_rebuild(old: Snapshot, new: Snapshot) -> bool:
@@ -27,3 +36,176 @@ def is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool:
2736
# If the partitioning hasn't changed, then we don't need to rebuild
2837
return False
2938
return True
39+
40+
41+
@dataclass
42+
class SnapshotIntervalClearRequest:
43+
# affected snapshot
44+
snapshot: SnapshotIdAndVersion
45+
46+
# which interval to clear
47+
interval: Interval
48+
49+
# which environments this snapshot is currently promoted
50+
# note that this can be empty if the snapshot exists because its ttl has not expired
51+
# but it is not part of any particular environment
52+
environment_names: t.Set[str] = field(default_factory=set)
53+
54+
@property
55+
def snapshot_id(self) -> SnapshotId:
56+
return self.snapshot.snapshot_id
57+
58+
@property
59+
def sorted_environment_names(self) -> t.List[str]:
60+
return list(sorted(self.environment_names))
61+
62+
63+
def identify_restatement_intervals_across_snapshot_versions(
64+
state_reader: StateReader,
65+
prod_restatements: t.Dict[str, Interval],
66+
disable_restatement_models: t.Set[str],
67+
loaded_snapshots: t.Dict[SnapshotId, Snapshot],
68+
current_ts: t.Optional[int] = None,
69+
) -> t.Dict[SnapshotId, SnapshotIntervalClearRequest]:
70+
"""
71+
Given a map of snapshot names + intervals to restate in prod:
72+
- Look up matching snapshots (match based on name - regardless of version, to get all versions)
73+
- For each match, also match downstream snapshots in each dev environment while filtering out models that have restatement disabled
74+
- Return a list of all snapshots that are affected + the interval that needs to be cleared for each
75+
76+
The goal here is to produce a list of intervals to invalidate across all dev snapshots so that a subsequent plan or
77+
cadence run in those environments causes the intervals to be repopulated.
78+
"""
79+
if not prod_restatements:
80+
return {}
81+
82+
# Although :loaded_snapshots is sourced from RestatementStage.all_snapshots, since the only time we ever need
83+
# to clear intervals across all environments is for prod, the :loaded_snapshots here are always from prod
84+
prod_name_versions: t.Set[SnapshotNameVersion] = {
85+
s.name_version for s in loaded_snapshots.values()
86+
}
87+
88+
snapshot_intervals_to_clear: t.Dict[SnapshotId, SnapshotIntervalClearRequest] = {}
89+
90+
for env_summary in state_reader.get_environments_summary():
91+
# Fetch the full environment object one at a time to avoid loading all environments into memory at once
92+
env = state_reader.get_environment(env_summary.name)
93+
if not env:
94+
logger.warning("Environment %s not found", env_summary.name)
95+
continue
96+
97+
snapshots_by_name = {s.name: s.table_info for s in env.snapshots}
98+
99+
# We dont just restate matching snapshots, we also have to restate anything downstream of them
100+
# so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev
101+
env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots})
102+
103+
for restate_snapshot_name, interval in prod_restatements.items():
104+
if restate_snapshot_name not in snapshots_by_name:
105+
# snapshot is not promoted in this environment
106+
continue
107+
108+
affected_snapshot_names = [
109+
x
110+
for x in ([restate_snapshot_name] + env_dag.downstream(restate_snapshot_name))
111+
if x not in disable_restatement_models
112+
]
113+
114+
for affected_snapshot_name in affected_snapshot_names:
115+
affected_snapshot = snapshots_by_name[affected_snapshot_name]
116+
117+
# Don't clear intervals for a dev snapshot if it shares the same physical version with prod.
118+
# Otherwise, prod will be affected by what should be a dev operation
119+
if affected_snapshot.name_version in prod_name_versions:
120+
continue
121+
122+
clear_request = snapshot_intervals_to_clear.get(affected_snapshot.snapshot_id)
123+
if not clear_request:
124+
clear_request = SnapshotIntervalClearRequest(
125+
snapshot=affected_snapshot.id_and_version, interval=interval
126+
)
127+
snapshot_intervals_to_clear[affected_snapshot.snapshot_id] = clear_request
128+
129+
clear_request.environment_names |= set([env.name])
130+
131+
# snapshot_intervals_to_clear now contains the entire hierarchy of affected snapshots based
132+
# on building the DAG for each environment and including downstream snapshots
133+
# but, what if there are affected snapshots that arent part of any environment?
134+
unique_snapshot_names = set(snapshot_id.name for snapshot_id in snapshot_intervals_to_clear)
135+
136+
current_ts = current_ts or now_timestamp()
137+
all_matching_non_prod_snapshots = {
138+
s.snapshot_id: s
139+
for s in state_reader.get_snapshots_by_names(
140+
snapshot_names=unique_snapshot_names, current_ts=current_ts, exclude_expired=True
141+
)
142+
# Don't clear intervals for a snapshot if it shares the same physical version with prod.
143+
# Otherwise, prod will be affected by what should be a dev operation
144+
if s.name_version not in prod_name_versions
145+
}
146+
147+
# identify the ones that we havent picked up yet, which are the ones that dont exist in any environment
148+
if remaining_snapshot_ids := set(all_matching_non_prod_snapshots).difference(
149+
snapshot_intervals_to_clear
150+
):
151+
# these snapshot id's exist in isolation and may be related to a downstream dependency of the :prod_restatements,
152+
# rather than directly related, so we can't simply look up the interval to clear based on :prod_restatements.
153+
# To figure out the interval that should be cleared, we can match to the existing list based on name
154+
# and conservatively take the widest interval that shows up
155+
snapshot_name_to_widest_interval: t.Dict[str, Interval] = {}
156+
for s_id, clear_request in snapshot_intervals_to_clear.items():
157+
current_start, current_end = snapshot_name_to_widest_interval.get(
158+
s_id.name, clear_request.interval
159+
)
160+
next_start, next_end = clear_request.interval
161+
162+
next_start = min(current_start, next_start)
163+
next_end = max(current_end, next_end)
164+
165+
snapshot_name_to_widest_interval[s_id.name] = (next_start, next_end)
166+
167+
for remaining_snapshot_id in remaining_snapshot_ids:
168+
remaining_snapshot = all_matching_non_prod_snapshots[remaining_snapshot_id]
169+
snapshot_intervals_to_clear[remaining_snapshot_id] = SnapshotIntervalClearRequest(
170+
snapshot=remaining_snapshot,
171+
interval=snapshot_name_to_widest_interval[remaining_snapshot_id.name],
172+
)
173+
174+
# for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
175+
# include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
176+
# so we only do it if necessary
177+
full_history_restatement_snapshot_ids = [
178+
# FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self
179+
# however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state,
180+
# is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present)
181+
# So for now, these are not considered
182+
s_id
183+
for s_id, s in snapshot_intervals_to_clear.items()
184+
if s.snapshot.full_history_restatement_only
185+
]
186+
if full_history_restatement_snapshot_ids:
187+
# only load full snapshot records that we havent already loaded
188+
additional_snapshots = state_reader.get_snapshots(
189+
[
190+
s.snapshot_id
191+
for s in full_history_restatement_snapshot_ids
192+
if s.snapshot_id not in loaded_snapshots
193+
]
194+
)
195+
196+
all_snapshots = loaded_snapshots | additional_snapshots
197+
198+
for full_snapshot_id in full_history_restatement_snapshot_ids:
199+
full_snapshot = all_snapshots[full_snapshot_id]
200+
intervals_to_clear = snapshot_intervals_to_clear[full_snapshot_id]
201+
202+
original_start, original_end = intervals_to_clear.interval
203+
204+
# get_removal_interval() widens intervals if necessary
205+
new_interval = full_snapshot.get_removal_interval(
206+
start=original_start, end=original_end
207+
)
208+
209+
intervals_to_clear.interval = new_interval
210+
211+
return snapshot_intervals_to_clear

sqlmesh/core/plan/evaluator.py

Lines changed: 16 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sqlmesh.core.console import Console, get_console
2323
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
2424
from sqlmesh.core.macros import RuntimeStage
25-
from sqlmesh.core.snapshot.definition import Interval, to_view_mapping
25+
from sqlmesh.core.snapshot.definition import to_view_mapping
2626
from sqlmesh.core.plan import stages
2727
from sqlmesh.core.plan.definition import EvaluatablePlan
2828
from sqlmesh.core.scheduler import Scheduler
@@ -33,17 +33,15 @@
3333
SnapshotIntervals,
3434
SnapshotId,
3535
SnapshotInfoLike,
36-
SnapshotTableInfo,
3736
SnapshotCreationFailedError,
38-
SnapshotNameVersion,
3937
)
4038
from sqlmesh.utils import to_snake_case
4139
from sqlmesh.core.state_sync import StateSync
40+
from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions
4241
from sqlmesh.utils import CorrelationId
4342
from sqlmesh.utils.concurrency import NodeExecutionFailedError
4443
from sqlmesh.utils.errors import PlanError, SQLMeshError
45-
from sqlmesh.utils.dag import DAG
46-
from sqlmesh.utils.date import now
44+
from sqlmesh.utils.date import now, to_timestamp
4745

4846
logger = logging.getLogger(__name__)
4947

@@ -289,7 +287,9 @@ def visit_audit_only_run_stage(
289287
def visit_restatement_stage(
290288
self, stage: stages.RestatementStage, plan: EvaluatablePlan
291289
) -> None:
292-
snapshot_intervals_to_restate = {(s, i) for s, i in stage.snapshot_intervals.items()}
290+
snapshot_intervals_to_restate = {
291+
(s.id_and_version, i) for s, i in stage.snapshot_intervals.items()
292+
}
293293

294294
# Restating intervals on prod plans should mean that the intervals are cleared across
295295
# all environments, not just the version currently in prod
@@ -298,11 +298,16 @@ def visit_restatement_stage(
298298
#
299299
# Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
300300
snapshot_intervals_to_restate.update(
301-
self._restatement_intervals_across_all_environments(
302-
prod_restatements=plan.restatements,
303-
disable_restatement_models=plan.disabled_restatement_models,
304-
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
305-
)
301+
{
302+
(s.snapshot, s.interval)
303+
for s in identify_restatement_intervals_across_snapshot_versions(
304+
state_reader=self.state_sync,
305+
prod_restatements=plan.restatements,
306+
disable_restatement_models=plan.disabled_restatement_models,
307+
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
308+
current_ts=to_timestamp(plan.execution_time or now()),
309+
).values()
310+
}
306311
)
307312

308313
self.state_sync.remove_intervals(
@@ -422,97 +427,6 @@ def _demote_snapshots(
422427
on_complete=on_complete,
423428
)
424429

425-
def _restatement_intervals_across_all_environments(
426-
self,
427-
prod_restatements: t.Dict[str, Interval],
428-
disable_restatement_models: t.Set[str],
429-
loaded_snapshots: t.Dict[SnapshotId, Snapshot],
430-
) -> t.Set[t.Tuple[SnapshotTableInfo, Interval]]:
431-
"""
432-
Given a map of snapshot names + intervals to restate in prod:
433-
- Look up matching snapshots across all environments (match based on name - regardless of version)
434-
- For each match, also match downstream snapshots while filtering out models that have restatement disabled
435-
- Return all matches mapped to the intervals of the prod snapshot being restated
436-
437-
The goal here is to produce a list of intervals to invalidate across all environments so that a cadence
438-
run in those environments causes the intervals to be repopulated
439-
"""
440-
if not prod_restatements:
441-
return set()
442-
443-
prod_name_versions: t.Set[SnapshotNameVersion] = {
444-
s.name_version for s in loaded_snapshots.values()
445-
}
446-
447-
snapshots_to_restate: t.Dict[SnapshotId, t.Tuple[SnapshotTableInfo, Interval]] = {}
448-
449-
for env_summary in self.state_sync.get_environments_summary():
450-
# Fetch the full environment object one at a time to avoid loading all environments into memory at once
451-
env = self.state_sync.get_environment(env_summary.name)
452-
if not env:
453-
logger.warning("Environment %s not found", env_summary.name)
454-
continue
455-
456-
keyed_snapshots = {s.name: s.table_info for s in env.snapshots}
457-
458-
# We dont just restate matching snapshots, we also have to restate anything downstream of them
459-
# so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev
460-
env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots})
461-
462-
for restatement, intervals in prod_restatements.items():
463-
if restatement not in keyed_snapshots:
464-
continue
465-
affected_snapshot_names = [
466-
x
467-
for x in ([restatement] + env_dag.downstream(restatement))
468-
if x not in disable_restatement_models
469-
]
470-
snapshots_to_restate.update(
471-
{
472-
keyed_snapshots[a].snapshot_id: (keyed_snapshots[a], intervals)
473-
for a in affected_snapshot_names
474-
# Don't restate a snapshot if it shares the version with a snapshot in prod
475-
if keyed_snapshots[a].name_version not in prod_name_versions
476-
}
477-
)
478-
479-
# for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
480-
# include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
481-
# so we only do it if necessary
482-
full_history_restatement_snapshot_ids = [
483-
# FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self
484-
# however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state,
485-
# is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present)
486-
# So for now, these are not considered
487-
s_id
488-
for s_id, s in snapshots_to_restate.items()
489-
if s[0].full_history_restatement_only
490-
]
491-
if full_history_restatement_snapshot_ids:
492-
# only load full snapshot records that we havent already loaded
493-
additional_snapshots = self.state_sync.get_snapshots(
494-
[
495-
s.snapshot_id
496-
for s in full_history_restatement_snapshot_ids
497-
if s.snapshot_id not in loaded_snapshots
498-
]
499-
)
500-
501-
all_snapshots = loaded_snapshots | additional_snapshots
502-
503-
for full_snapshot_id in full_history_restatement_snapshot_ids:
504-
full_snapshot = all_snapshots[full_snapshot_id]
505-
_, original_intervals = snapshots_to_restate[full_snapshot_id]
506-
original_start, original_end = original_intervals
507-
508-
# get_removal_interval() widens intervals if necessary
509-
new_intervals = full_snapshot.get_removal_interval(
510-
start=original_start, end=original_end
511-
)
512-
snapshots_to_restate[full_snapshot_id] = (full_snapshot.table_info, new_intervals)
513-
514-
return set(snapshots_to_restate.values())
515-
516430
def _update_intervals_for_new_snapshots(self, snapshots: t.Collection[Snapshot]) -> None:
517431
snapshots_intervals: t.List[SnapshotIntervals] = []
518432
for snapshot in snapshots:

sqlmesh/core/snapshot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
SnapshotId as SnapshotId,
1212
SnapshotIdBatch as SnapshotIdBatch,
1313
SnapshotIdLike as SnapshotIdLike,
14+
SnapshotIdAndVersionLike as SnapshotIdAndVersionLike,
1415
SnapshotInfoLike as SnapshotInfoLike,
1516
SnapshotIntervals as SnapshotIntervals,
1617
SnapshotNameVersion as SnapshotNameVersion,

0 commit comments

Comments
 (0)