Skip to content

Commit 042e3a1

Browse files
committed
Fix: When restating prod, clear intervals across all related snapshots, not just promoted ones
1 parent b82036f commit 042e3a1

File tree

3 files changed

+367
-102
lines changed

3 files changed

+367
-102
lines changed

sqlmesh/core/plan/common.py

Lines changed: 180 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from __future__ import annotations
2+
import typing as t
3+
import logging
4+
from dataclasses import dataclass, field
25

3-
from sqlmesh.core.snapshot import Snapshot
6+
from sqlmesh.core.state_sync import StateReader
7+
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo, SnapshotNameVersion
8+
from sqlmesh.core.snapshot.definition import Interval
9+
from sqlmesh.utils.dag import DAG
10+
from sqlmesh.utils.date import now_timestamp
11+
12+
logger = logging.getLogger(__name__)
413

514

615
def should_force_rebuild(old: Snapshot, new: Snapshot) -> bool:
@@ -24,3 +33,173 @@ def is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool:
2433
# If the partitioning hasn't changed, then we don't need to rebuild
2534
return False
2635
return True
36+
37+
38+
@dataclass
39+
class SnapshotIntervalClearRequest:
40+
# affected snapshot
41+
table_info: SnapshotTableInfo
42+
43+
# which interval to clear
44+
interval: Interval
45+
46+
# which environments this snapshot is currently promoted
47+
# note that this can be empty if the snapshot exists because its ttl has not expired
48+
# but it is not part of any particular environment
49+
environment_names: t.Set[str] = field(default_factory=set)
50+
51+
@property
52+
def snapshot_id(self) -> SnapshotId:
53+
return self.table_info.snapshot_id
54+
55+
@property
56+
def sorted_environment_names(self) -> t.List[str]:
57+
return list(sorted(self.environment_names))
58+
59+
60+
def identify_restatement_intervals_across_snapshot_versions(
61+
state_reader: StateReader,
62+
prod_restatements: t.Dict[str, Interval],
63+
disable_restatement_models: t.Set[str],
64+
loaded_snapshots: t.Dict[SnapshotId, Snapshot],
65+
current_ts: t.Optional[int] = None,
66+
) -> t.Dict[SnapshotId, SnapshotIntervalClearRequest]:
67+
"""
68+
Given a map of snapshot names + intervals to restate in prod:
69+
- Look up matching snapshots (match based on name - regardless of version, to get all versions)
70+
- For each match, also match downstream snapshots in each dev environment while filtering out models that have restatement disabled
71+
- Return a list of all snapshots that are affected + the interval that needs to be cleared for each
72+
73+
The goal here is to produce a list of intervals to invalidate across all dev snapshots so that a subsequent plan or
74+
cadence run in those environments causes the intervals to be repopulated.
75+
"""
76+
if not prod_restatements:
77+
return {}
78+
79+
# Although :loaded_snapshots is sourced from RestatementStage.all_snapshots, since the only time we ever need
80+
# to clear intervals across all environments is for prod, the :loaded_snapshots here are always from prod
81+
prod_name_versions: t.Set[SnapshotNameVersion] = {
82+
s.name_version for s in loaded_snapshots.values()
83+
}
84+
85+
snapshot_intervals_to_clear: t.Dict[SnapshotId, SnapshotIntervalClearRequest] = {}
86+
87+
for env_summary in state_reader.get_environments_summary():
88+
# Fetch the full environment object one at a time to avoid loading all environments into memory at once
89+
env = state_reader.get_environment(env_summary.name)
90+
if not env:
91+
logger.warning("Environment %s not found", env_summary.name)
92+
continue
93+
94+
snapshots_by_name = {s.name: s.table_info for s in env.snapshots}
95+
96+
# We dont just restate matching snapshots, we also have to restate anything downstream of them
97+
# so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev
98+
env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots})
99+
100+
for restate_snapshot_name, interval in prod_restatements.items():
101+
if restate_snapshot_name not in snapshots_by_name:
102+
# snapshot is not promoted in this environment
103+
continue
104+
105+
affected_snapshot_names = [
106+
x
107+
for x in ([restate_snapshot_name] + env_dag.downstream(restate_snapshot_name))
108+
if x not in disable_restatement_models
109+
]
110+
111+
for affected_snapshot_name in affected_snapshot_names:
112+
affected_snapshot = snapshots_by_name[affected_snapshot_name]
113+
114+
# Don't clear intervals for a dev snapshot if it shares the same physical version with prod.
115+
# Otherwise, prod will be affected by what should be a dev operation
116+
if affected_snapshot.name_version in prod_name_versions:
117+
continue
118+
119+
clear_request = snapshot_intervals_to_clear.get(affected_snapshot.snapshot_id)
120+
if not clear_request:
121+
clear_request = SnapshotIntervalClearRequest(
122+
table_info=affected_snapshot, interval=interval
123+
)
124+
snapshot_intervals_to_clear[affected_snapshot.snapshot_id] = clear_request
125+
126+
clear_request.environment_names |= set([env.name])
127+
128+
# snapshot_intervals_to_clear now contains the entire hierarchy of affected snapshots based
129+
# on building the DAG for each environment and including downstream snapshots
130+
# but, what if there are affected snapshots that arent part of any environment?
131+
unique_snapshot_names = set(snapshot_id.name for snapshot_id in snapshot_intervals_to_clear)
132+
133+
current_ts = current_ts or now_timestamp()
134+
all_matching_snapshot_ids = state_reader.get_snapshot_ids_by_names(
135+
snapshot_names=unique_snapshot_names, current_ts=current_ts, exclude_expired=True
136+
)
137+
138+
# identify the ones that we havent picked up yet, which are the ones that dont exist in any environment
139+
if remaining_snapshot_ids := all_matching_snapshot_ids.difference(snapshot_intervals_to_clear):
140+
# these snapshot id's exist in isolation and may be related to a downstream dependency of the :prod_restatements,
141+
# rather than directly related, so we can't simply look up the interval to clear based on :prod_restatements.
142+
# To figure out the interval that should be cleared, we can match to the existing list based on name
143+
# and conservatively take the widest interval that shows up
144+
snapshot_name_to_widest_interval: t.Dict[str, Interval] = {}
145+
for s_id, clear_request in snapshot_intervals_to_clear.items():
146+
current_start, current_end = snapshot_name_to_widest_interval.get(
147+
s_id.name, clear_request.interval
148+
)
149+
next_start, next_end = clear_request.interval
150+
151+
next_start = min(current_start, next_start)
152+
next_end = max(current_end, next_end)
153+
154+
snapshot_name_to_widest_interval[s_id.name] = (next_start, next_end)
155+
156+
remaining_snapshots = state_reader.get_snapshots(snapshot_ids=remaining_snapshot_ids)
157+
for remaining_snapshot_id, remaining_snapshot in remaining_snapshots.items():
158+
# Don't clear intervals for a snapshot if it shares the same physical version with prod.
159+
# Otherwise, prod will be affected by what should be a dev operation
160+
if remaining_snapshot.name_version in prod_name_versions:
161+
continue
162+
163+
snapshot_intervals_to_clear[remaining_snapshot_id] = SnapshotIntervalClearRequest(
164+
table_info=remaining_snapshot.table_info,
165+
interval=snapshot_name_to_widest_interval[remaining_snapshot_id.name],
166+
)
167+
168+
# for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
169+
# include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
170+
# so we only do it if necessary
171+
full_history_restatement_snapshot_ids = [
172+
# FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self
173+
# however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state,
174+
# is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present)
175+
# So for now, these are not considered
176+
s_id
177+
for s_id, s in snapshot_intervals_to_clear.items()
178+
if s.table_info.full_history_restatement_only
179+
]
180+
if full_history_restatement_snapshot_ids:
181+
# only load full snapshot records that we havent already loaded
182+
additional_snapshots = state_reader.get_snapshots(
183+
[
184+
s.snapshot_id
185+
for s in full_history_restatement_snapshot_ids
186+
if s.snapshot_id not in loaded_snapshots
187+
]
188+
)
189+
190+
all_snapshots = loaded_snapshots | additional_snapshots
191+
192+
for full_snapshot_id in full_history_restatement_snapshot_ids:
193+
full_snapshot = all_snapshots[full_snapshot_id]
194+
intervals_to_clear = snapshot_intervals_to_clear[full_snapshot_id]
195+
196+
original_start, original_end = intervals_to_clear.interval
197+
198+
# get_removal_interval() widens intervals if necessary
199+
new_interval = full_snapshot.get_removal_interval(
200+
start=original_start, end=original_end
201+
)
202+
203+
intervals_to_clear.interval = new_interval
204+
205+
return snapshot_intervals_to_clear

sqlmesh/core/plan/evaluator.py

Lines changed: 13 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sqlmesh.core.console import Console, get_console
2323
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
2424
from sqlmesh.core.macros import RuntimeStage
25-
from sqlmesh.core.snapshot.definition import Interval, to_view_mapping
25+
from sqlmesh.core.snapshot.definition import to_view_mapping
2626
from sqlmesh.core.plan import stages
2727
from sqlmesh.core.plan.definition import EvaluatablePlan
2828
from sqlmesh.core.scheduler import Scheduler
@@ -33,17 +33,15 @@
3333
SnapshotIntervals,
3434
SnapshotId,
3535
SnapshotInfoLike,
36-
SnapshotTableInfo,
3736
SnapshotCreationFailedError,
38-
SnapshotNameVersion,
3937
)
4038
from sqlmesh.utils import to_snake_case
4139
from sqlmesh.core.state_sync import StateSync
40+
from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions
4241
from sqlmesh.utils import CorrelationId
4342
from sqlmesh.utils.concurrency import NodeExecutionFailedError
4443
from sqlmesh.utils.errors import PlanError, SQLMeshError
45-
from sqlmesh.utils.dag import DAG
46-
from sqlmesh.utils.date import now
44+
from sqlmesh.utils.date import now, to_timestamp
4745

4846
logger = logging.getLogger(__name__)
4947

@@ -295,11 +293,16 @@ def visit_restatement_stage(
295293
#
296294
# Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
297295
snapshot_intervals_to_restate.update(
298-
self._restatement_intervals_across_all_environments(
299-
prod_restatements=plan.restatements,
300-
disable_restatement_models=plan.disabled_restatement_models,
301-
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
302-
)
296+
{
297+
(s.table_info, s.interval)
298+
for s in identify_restatement_intervals_across_snapshot_versions(
299+
state_reader=self.state_sync,
300+
prod_restatements=plan.restatements,
301+
disable_restatement_models=plan.disabled_restatement_models,
302+
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
303+
current_ts=to_timestamp(plan.execution_time or now()),
304+
).values()
305+
}
303306
)
304307

305308
self.state_sync.remove_intervals(
@@ -419,97 +422,6 @@ def _demote_snapshots(
419422
on_complete=on_complete,
420423
)
421424

422-
def _restatement_intervals_across_all_environments(
423-
self,
424-
prod_restatements: t.Dict[str, Interval],
425-
disable_restatement_models: t.Set[str],
426-
loaded_snapshots: t.Dict[SnapshotId, Snapshot],
427-
) -> t.Set[t.Tuple[SnapshotTableInfo, Interval]]:
428-
"""
429-
Given a map of snapshot names + intervals to restate in prod:
430-
- Look up matching snapshots across all environments (match based on name - regardless of version)
431-
- For each match, also match downstream snapshots while filtering out models that have restatement disabled
432-
- Return all matches mapped to the intervals of the prod snapshot being restated
433-
434-
The goal here is to produce a list of intervals to invalidate across all environments so that a cadence
435-
run in those environments causes the intervals to be repopulated
436-
"""
437-
if not prod_restatements:
438-
return set()
439-
440-
prod_name_versions: t.Set[SnapshotNameVersion] = {
441-
s.name_version for s in loaded_snapshots.values()
442-
}
443-
444-
snapshots_to_restate: t.Dict[SnapshotId, t.Tuple[SnapshotTableInfo, Interval]] = {}
445-
446-
for env_summary in self.state_sync.get_environments_summary():
447-
# Fetch the full environment object one at a time to avoid loading all environments into memory at once
448-
env = self.state_sync.get_environment(env_summary.name)
449-
if not env:
450-
logger.warning("Environment %s not found", env_summary.name)
451-
continue
452-
453-
keyed_snapshots = {s.name: s.table_info for s in env.snapshots}
454-
455-
# We dont just restate matching snapshots, we also have to restate anything downstream of them
456-
# so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev
457-
env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots})
458-
459-
for restatement, intervals in prod_restatements.items():
460-
if restatement not in keyed_snapshots:
461-
continue
462-
affected_snapshot_names = [
463-
x
464-
for x in ([restatement] + env_dag.downstream(restatement))
465-
if x not in disable_restatement_models
466-
]
467-
snapshots_to_restate.update(
468-
{
469-
keyed_snapshots[a].snapshot_id: (keyed_snapshots[a], intervals)
470-
for a in affected_snapshot_names
471-
# Don't restate a snapshot if it shares the version with a snapshot in prod
472-
if keyed_snapshots[a].name_version not in prod_name_versions
473-
}
474-
)
475-
476-
# for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
477-
# include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
478-
# so we only do it if necessary
479-
full_history_restatement_snapshot_ids = [
480-
# FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self
481-
# however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state,
482-
# is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present)
483-
# So for now, these are not considered
484-
s_id
485-
for s_id, s in snapshots_to_restate.items()
486-
if s[0].full_history_restatement_only
487-
]
488-
if full_history_restatement_snapshot_ids:
489-
# only load full snapshot records that we havent already loaded
490-
additional_snapshots = self.state_sync.get_snapshots(
491-
[
492-
s.snapshot_id
493-
for s in full_history_restatement_snapshot_ids
494-
if s.snapshot_id not in loaded_snapshots
495-
]
496-
)
497-
498-
all_snapshots = loaded_snapshots | additional_snapshots
499-
500-
for full_snapshot_id in full_history_restatement_snapshot_ids:
501-
full_snapshot = all_snapshots[full_snapshot_id]
502-
_, original_intervals = snapshots_to_restate[full_snapshot_id]
503-
original_start, original_end = original_intervals
504-
505-
# get_removal_interval() widens intervals if necessary
506-
new_intervals = full_snapshot.get_removal_interval(
507-
start=original_start, end=original_end
508-
)
509-
snapshots_to_restate[full_snapshot_id] = (full_snapshot.table_info, new_intervals)
510-
511-
return set(snapshots_to_restate.values())
512-
513425
def _update_intervals_for_new_snapshots(self, snapshots: t.Collection[Snapshot]) -> None:
514426
snapshots_intervals: t.List[SnapshotIntervals] = []
515427
for snapshot in snapshots:

0 commit comments

Comments
 (0)