Skip to content

Commit bc326d7

Browse files
committed
PR feedback
1 parent 3f95dc0 commit bc326d7

File tree

6 files changed

+74
-67
lines changed

6 files changed

+74
-67
lines changed

sqlmesh/core/console.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2275,13 +2275,6 @@ def log_models_updated_during_restatement(
22752275

22762276
self._print(tree)
22772277

2278-
self.log_warning(
2279-
f"\nThe '{environment.name}' environment currently points to [bold]different[/bold] versions of these models, not the versions that just got restated."
2280-
)
2281-
self._print(
2282-
"[yellow]If this is undesirable, please re-run this restatement plan which will apply it to the most recent versions of these models.[/yellow]\n"
2283-
)
2284-
22852278
def log_destructive_change(
22862279
self,
22872280
snapshot_name: str,

sqlmesh/core/plan/evaluator.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions
4141
from sqlmesh.utils import CorrelationId
4242
from sqlmesh.utils.concurrency import NodeExecutionFailedError
43-
from sqlmesh.utils.errors import PlanError, SQLMeshError
43+
from sqlmesh.utils.errors import PlanError, ConflictingPlanError, SQLMeshError
4444
from sqlmesh.utils.date import now, to_timestamp
4545

4646
logger = logging.getLogger(__name__)
@@ -311,18 +311,15 @@ def visit_restatement_stage(
311311
# Nothing to do
312312
return
313313

314-
self.state_sync.remove_intervals(
315-
snapshot_intervals=[(s.table_info, s.interval) for s in intervals_to_clear.values()],
316-
remove_shared_versions=plan.is_prod,
317-
)
318-
319314
# While the restatements were being processed, did any of the snapshots being restated get new versions deployed?
320315
# If they did, they will not reflect the data that just got restated, so we need to notify the user
316+
deployed_during_restatement: t.List[
317+
t.Tuple[SnapshotTableInfo, SnapshotTableInfo]
318+
] = [] # tuple of (restated_snapshot, current_prod_snapshot)
319+
321320
if deployed_env := self.state_sync.get_environment(plan.environment.name):
322321
promoted_snapshots_by_name = {s.name: s for s in deployed_env.snapshots}
323322

324-
deployed_during_restatement: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]] = []
325-
326323
for name in plan.restatements:
327324
snapshot = stage.all_snapshots[name]
328325
version = snapshot.table_info.version
@@ -333,15 +330,32 @@ def visit_restatement_stage(
333330
(snapshot.table_info, prod_snapshot.table_info)
334331
)
335332

336-
if deployed_during_restatement:
337-
self.console.log_models_updated_during_restatement(
338-
deployed_during_restatement,
339-
deployed_env.summary,
340-
plan.environment.naming_info,
341-
self.default_catalog,
342-
)
343-
# note: the plan will automatically fail at the promotion stage with a ConflictingPlanError because the environment was changed by another plan
344-
# so there is no need to explicitly fail the plan here
333+
# Skip clearing intervals for the new versions that were deployed while the restatement was running, so that
334+
# subsequent plans do not end up with unexpected intervals to backfill.
335+
filtered_intervals_to_clear = [
336+
(s.snapshot, s.interval)
337+
for s in intervals_to_clear.values()
338+
if s.snapshot.name_version
339+
not in {prod_snapshot.name_version for _, prod_snapshot in deployed_during_restatement}
340+
]
341+
342+
if filtered_intervals_to_clear:
343+
self.state_sync.remove_intervals(
344+
snapshot_intervals=filtered_intervals_to_clear,
345+
remove_shared_versions=plan.is_prod,
346+
)
347+
348+
if deployed_env and deployed_during_restatement:
349+
self.console.log_models_updated_during_restatement(
350+
deployed_during_restatement,
351+
deployed_env.summary,
352+
plan.environment.naming_info,
353+
self.default_catalog,
354+
)
355+
raise ConflictingPlanError(
356+
f"Another plan ({deployed_env.summary.plan_id}) deployed new versions of {len(deployed_during_restatement)} models in the target environment '{plan.environment.name}' while they were being restated by this plan.\n"
357+
"These new versions have *not* had data restated. If they should be restated, please re-apply your plan."
358+
)
345359

346360
def visit_environment_record_update_stage(
347361
self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan

sqlmesh/core/plan/explainer.py

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sqlmesh.core import constants as c
1212
from sqlmesh.core.console import Console, TerminalConsole, get_console
1313
from sqlmesh.core.environment import EnvironmentNamingInfo
14-
from sqlmesh.core.snapshot.definition import DeployabilityIndex
14+
from sqlmesh.core.snapshot.definition import model_display_name
1515
from sqlmesh.core.plan.common import (
1616
SnapshotIntervalClearRequest,
1717
identify_restatement_intervals_across_snapshot_versions,
@@ -22,9 +22,7 @@
2222
PlanEvaluator,
2323
)
2424
from sqlmesh.core.state_sync import StateReader
25-
from sqlmesh.core.snapshot.definition import (
26-
SnapshotInfoMixin,
27-
)
25+
from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotNameVersionLike
2826
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
2927
from sqlmesh.utils.date import to_ts
3028
from sqlmesh.utils.errors import SQLMeshError
@@ -81,10 +79,6 @@ class ExplainableRestatementStage(stages.RestatementStage):
8179
snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest]
8280
"""Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name"""
8381

84-
deployability_index: DeployabilityIndex
85-
"""Deployability of those snapshots (which arent necessarily present in the current plan so we cant use the
86-
plan deployability index), used for outputting physical table names"""
87-
8882
@classmethod
8983
def from_restatement_stage(
9084
cls: t.Type[ExplainableRestatementStage],
@@ -99,29 +93,10 @@ def from_restatement_stage(
9993
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
10094
)
10195

102-
snapshot_intervals_to_clear = {}
103-
deployability_index = DeployabilityIndex.all_deployable()
104-
105-
if all_restatement_intervals:
106-
snapshot_intervals_to_clear = {
107-
s_id.name: r for s_id, r in all_restatement_intervals.items()
108-
}
109-
110-
# creating a deployability index over the "snapshot intervals to clear"
111-
# allows us to print the physical names of the tables affected in the console output
112-
# note that we can't use the DeployabilityIndex on the plan because it only includes
113-
# snapshots for the current environment, not across all environments
114-
deployability_index = DeployabilityIndex.create(
115-
snapshots=state_reader.get_snapshots(
116-
[s.snapshot_id for s in snapshot_intervals_to_clear.values()]
117-
),
118-
start=plan.start,
119-
start_override_per_model=plan.start_override_per_model,
120-
)
121-
12296
return cls(
123-
snapshot_intervals_to_clear=snapshot_intervals_to_clear,
124-
deployability_index=deployability_index,
97+
snapshot_intervals_to_clear={
98+
s.snapshot.name: s for s in all_restatement_intervals.values()
99+
},
125100
all_snapshots=stage.all_snapshots,
126101
)
127102

@@ -230,7 +205,7 @@ def visit_restatement_stage(
230205
snapshot_intervals := stage.snapshot_intervals_to_clear
231206
):
232207
for clear_request in snapshot_intervals.values():
233-
display_name = self._display_name(clear_request.table_info)
208+
display_name = self._display_name(clear_request.snapshot)
234209
interval = clear_request.interval
235210
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
236211

@@ -348,15 +323,22 @@ def visit_finalize_environment_stage(
348323

349324
def _display_name(
350325
self,
351-
snapshot: SnapshotInfoMixin,
326+
snapshot: t.Union[SnapshotInfoMixin, SnapshotNameVersionLike],
352327
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
353328
) -> str:
354-
return snapshot.display_name(
355-
environment_naming_info or self.environment_naming_info,
356-
self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
329+
naming_kwargs: t.Any = dict(
330+
environment_naming_info=environment_naming_info or self.environment_naming_info,
331+
default_catalog=self.default_catalog
332+
if self.verbosity < Verbosity.VERY_VERBOSE
333+
else None,
357334
dialect=self.dialect,
358335
)
359336

337+
if isinstance(snapshot, SnapshotInfoMixin):
338+
return snapshot.display_name(**naming_kwargs)
339+
340+
return model_display_name(node_name=snapshot.name, **naming_kwargs)
341+
360342
def _limit_tree(self, tree: Tree) -> Tree:
361343
tree_length = len(tree.children)
362344
if tree_length <= MAX_TREE_LENGTH:

sqlmesh/core/snapshot/definition.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1789,7 +1789,19 @@ def display_name(
17891789
"""
17901790
if snapshot_info_like.is_audit:
17911791
return snapshot_info_like.name
1792-
view_name = exp.to_table(snapshot_info_like.name)
1792+
1793+
return model_display_name(
1794+
snapshot_info_like.name, environment_naming_info, default_catalog, dialect
1795+
)
1796+
1797+
1798+
def model_display_name(
1799+
node_name: str,
1800+
environment_naming_info: EnvironmentNamingInfo,
1801+
default_catalog: t.Optional[str],
1802+
dialect: DialectType = None,
1803+
) -> str:
1804+
view_name = exp.to_table(node_name)
17931805

17941806
catalog = (
17951807
None

tests/core/test_integration.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10400,8 +10400,9 @@ def test_restatement_plan_detects_prod_deployment_during_restatement(tmp_path: P
1040010400
- During restatement, someone else deploys A(dev) to prod, replacing the model that is currently being restated.
1040110401
1040210402
Outcome:
10403-
- The deployment plan for dev -> prod should succeed in deploying the new version
10403+
- The deployment plan for dev -> prod should succeed in deploying the new version of A
1040410404
- The prod restatement plan should fail with a ConflictingPlanError and warn about the model that got updated while undergoing restatement
10405+
- The new version of A should have no intervals cleared. The user needs to rerun the restatement if the intervals should still be cleared
1040510406
"""
1040610407
orig_console = get_console()
1040710408
console = CaptureTerminalConsole()
@@ -10474,6 +10475,7 @@ def entrypoint(evaluator: MacroEvaluator) -> str:
1047410475
ctx.load()
1047510476
plan = ctx.plan(environment="dev", auto_apply=True)
1047610477
assert len(plan.modified_snapshots) == 1
10478+
new_model_a_snapshot_id = list(plan.modified_snapshots)[0]
1047710479

1047810480
# now, trigger a prod restatement plan in a different thread and block it to simulate a long restatement
1047910481
def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue):
@@ -10532,12 +10534,20 @@ def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue):
1053210534

1053310535
plan_error = restatement_plan_future.result()
1053410536
assert isinstance(plan_error, ConflictingPlanError)
10537+
assert "please re-apply your plan" in repr(plan_error)
1053510538

1053610539
output = " ".join(re.split("\s+", console.captured_output, flags=re.UNICODE))
1053710540
assert (
1053810541
f"The following models had new versions deployed in plan '{new_prod.plan_id}' while data was being restated: └── test.model_a"
1053910542
in output
1054010543
)
10541-
assert "please re-run this restatement plan" in output
10544+
10545+
# check that no intervals have been cleared from the model_a currently in prod
10546+
model_a = ctx.state_sync.get_snapshots(snapshot_ids=[new_model_a_snapshot_id])[
10547+
new_model_a_snapshot_id
10548+
]
10549+
assert isinstance(model_a.node, SqlModel)
10550+
assert model_a.node.render_query_or_raise().sql() == 'SELECT 1 AS "changed"'
10551+
assert len(model_a.intervals)
1054210552

1054310553
set_console(orig_console)

tests/core/test_plan_stages.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -612,9 +612,6 @@ def test_build_plan_stages_restatement_prod_only(
612612
restatement_stage, state_reader, plan
613613
)
614614
assert not restatement_stage.snapshot_intervals_to_clear
615-
assert (
616-
restatement_stage.deployability_index == DeployabilityIndex.all_deployable()
617-
) # default index
618615

619616
# Verify EnvironmentRecordUpdateStage
620617
assert isinstance(stages[3], EnvironmentRecordUpdateStage)
@@ -774,12 +771,11 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]):
774771
# note: we only clear the intervals from state for "a" in dev, we leave prod alone
775772
assert restatement_stage.snapshot_intervals_to_clear
776773
assert len(restatement_stage.snapshot_intervals_to_clear) == 1
777-
assert restatement_stage.deployability_index is not None
778774
snapshot_name, clear_request = list(restatement_stage.snapshot_intervals_to_clear.items())[0]
779775
assert isinstance(clear_request, SnapshotIntervalClearRequest)
780776
assert snapshot_name == '"a"'
781777
assert clear_request.snapshot_id == snapshot_a_dev.snapshot_id
782-
assert clear_request.table_info == snapshot_a_dev.table_info
778+
assert clear_request.snapshot == snapshot_a_dev.id_and_version
783779
assert clear_request.interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))
784780

785781
# Verify EnvironmentRecordUpdateStage

0 commit comments

Comments
 (0)