diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index af28f75932..3b6cb1ce07 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -551,6 +551,22 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: """Display list of models that failed during evaluation to the user.""" + @abc.abstractmethod + def log_models_updated_during_restatement( + self, + snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + """Display a list of models where new versions got deployed to the specified :environment while we were restating data the old versions + + Args: + snapshots: a list of (snapshot_we_restated, snapshot_it_got_replaced_with_during_restatement) tuples + environment: which environment got updated while we were restating models + environment_naming_info: how snapshots are named in that :environment (for display name purposes) + default_catalog: the configured default catalog (for display name purposes) + """ + @abc.abstractmethod def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID: """Starts loading and returns a unique ID that can be used to stop the loading. Optionally can display a message.""" @@ -771,6 +787,14 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: pass + def log_models_updated_during_restatement( + self, + snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + pass + def log_destructive_change( self, snapshot_name: str, @@ -2225,6 +2249,30 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: for node_name, msg in error_messages.items(): self._print(f" [red]{node_name}[/red]\n\n{msg}") + def log_models_updated_during_restatement( + self, + snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str] = None, + ) -> None: + if snapshots: + tree = Tree( + f"[yellow]The following models had new versions deployed while data was being restated:[/yellow]" + ) + + for restated_snapshot, updated_snapshot in snapshots: + display_name = restated_snapshot.display_name( + environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + current_branch = tree.add(display_name) + current_branch.add(f"restated version: '{restated_snapshot.version}'") + current_branch.add(f"currently active version: '{updated_snapshot.version}'") + + self._print(tree) + self._print("") # newline spacer + def log_destructive_change( self, snapshot_name: str, diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 79053e018b..03ecb770bf 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -22,7 +22,7 @@ from sqlmesh.core.console import Console, get_console from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements from sqlmesh.core.macros import RuntimeStage -from sqlmesh.core.snapshot.definition import to_view_mapping +from sqlmesh.core.snapshot.definition import to_view_mapping, SnapshotTableInfo from sqlmesh.core.plan import stages from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.scheduler import Scheduler @@ -40,7 +40,7 @@ from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions from sqlmesh.utils import CorrelationId from sqlmesh.utils.concurrency import NodeExecutionFailedError -from sqlmesh.utils.errors import PlanError, SQLMeshError +from sqlmesh.utils.errors import PlanError, ConflictingPlanError, SQLMeshError from sqlmesh.utils.date import now, to_timestamp logger = logging.getLogger(__name__) @@ -287,34 +287,78 @@ def visit_audit_only_run_stage( def visit_restatement_stage( self, stage: stages.RestatementStage, plan: EvaluatablePlan ) -> None: - snapshot_intervals_to_restate = { - (s.id_and_version, i) for s, i in stage.snapshot_intervals.items() - } - - # Restating intervals on prod plans should mean that the intervals are cleared across - # all environments, not just the version currently in prod - # This ensures that work done in dev environments can still be promoted to prod - # by forcing dev environments to re-run intervals that changed in prod + # Restating intervals on prod plans means that once the data for the intervals being restated has been backfilled + # (which happens in the backfill stage) then we need to clear those intervals *from state* across all other environments. + # + # This ensures that work done in dev environments can still be promoted to prod by forcing dev environments to + # re-run intervals that changed in prod (because after this stage runs they are cleared from state and thus show as missing) + # + # It also means that any new dev environments created while this restatement plan was running also get the + # correct intervals cleared because we look up matching snapshots as at right now and not as at the time the plan + # was created, which could have been several hours ago if there was a lot of data to restate. # # Without this rule, its possible that promoting a dev table to prod will introduce old data to prod - snapshot_intervals_to_restate.update( - { - (s.snapshot, s.interval) - for s in identify_restatement_intervals_across_snapshot_versions( - state_reader=self.state_sync, - prod_restatements=plan.restatements, - disable_restatement_models=plan.disabled_restatement_models, - loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, - current_ts=to_timestamp(plan.execution_time or now()), - ).values() - } - ) - self.state_sync.remove_intervals( - snapshot_intervals=list(snapshot_intervals_to_restate), - remove_shared_versions=plan.is_prod, + intervals_to_clear = identify_restatement_intervals_across_snapshot_versions( + state_reader=self.state_sync, + prod_restatements=plan.restatements, + disable_restatement_models=plan.disabled_restatement_models, + loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, + current_ts=to_timestamp(plan.execution_time or now()), ) + if not intervals_to_clear: + # Nothing to do + return + + # While the restatements were being processed, did any of the snapshots being restated get new versions deployed? + # If they did, they will not reflect the data that just got restated, so we need to notify the user + deployed_during_restatement: t.Dict[ + str, t.Tuple[SnapshotTableInfo, SnapshotTableInfo] + ] = {} # tuple of (restated_snapshot, current_prod_snapshot) + + if deployed_env := self.state_sync.get_environment(plan.environment.name): + promoted_snapshots_by_name = {s.name: s for s in deployed_env.snapshots} + + for name in plan.restatements: + snapshot = stage.all_snapshots[name] + version = snapshot.table_info.version + if ( + prod_snapshot := promoted_snapshots_by_name.get(name) + ) and prod_snapshot.version != version: + deployed_during_restatement[name] = ( + snapshot.table_info, + prod_snapshot.table_info, + ) + + # we need to *not* clear the intervals on the snapshots where new versions were deployed while the restatement was running in order to prevent + # subsequent plans from having unexpected intervals to backfill. + # we instead list the affected models and abort the plan with an error so the user can decide what to do + # (either re-attempt the restatement plan or leave things as they are) + filtered_intervals_to_clear = [ + (s.snapshot, s.interval) + for s in intervals_to_clear.values() + if s.snapshot.name not in deployed_during_restatement + ] + + if filtered_intervals_to_clear: + # We still clear intervals in other envs for models that were successfully restated without having new versions promoted during restatement + self.state_sync.remove_intervals( + snapshot_intervals=filtered_intervals_to_clear, + remove_shared_versions=plan.is_prod, + ) + + if deployed_env and deployed_during_restatement: + self.console.log_models_updated_during_restatement( + list(deployed_during_restatement.values()), + plan.environment.naming_info, + self.default_catalog, + ) + raise ConflictingPlanError( + f"Another plan ({deployed_env.summary.plan_id}) deployed new versions of {len(deployed_during_restatement)} models in the target environment '{plan.environment.name}' while they were being restated by this plan.\n" + "Please re-apply your plan if these new versions should be restated." + ) + def visit_environment_record_update_stage( self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan ) -> None: diff --git a/sqlmesh/core/plan/explainer.py b/sqlmesh/core/plan/explainer.py index ee829aeac1..b722d00d58 100644 --- a/sqlmesh/core/plan/explainer.py +++ b/sqlmesh/core/plan/explainer.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import abc import typing as t import logging +from dataclasses import dataclass from rich.console import Console as RichConsole from rich.tree import Tree @@ -8,15 +11,17 @@ from sqlmesh.core import constants as c from sqlmesh.core.console import Console, TerminalConsole, get_console from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.plan.common import ( + SnapshotIntervalClearRequest, + identify_restatement_intervals_across_snapshot_versions, +) from sqlmesh.core.plan.definition import EvaluatablePlan, SnapshotIntervals from sqlmesh.core.plan import stages from sqlmesh.core.plan.evaluator import ( PlanEvaluator, ) from sqlmesh.core.state_sync import StateReader -from sqlmesh.core.snapshot.definition import ( - SnapshotInfoMixin, -) +from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotIdAndVersion from sqlmesh.utils import Verbosity, rich as srich, to_snake_case from sqlmesh.utils.date import to_ts from sqlmesh.utils.errors import SQLMeshError @@ -45,6 +50,15 @@ def evaluate( explainer_console = _get_explainer_console( self.console, plan.environment, self.default_catalog ) + + # add extra metadata that's only needed at this point for better --explain output + plan_stages = [ + ExplainableRestatementStage.from_restatement_stage(stage, self.state_reader, plan) + if isinstance(stage, stages.RestatementStage) + else stage + for stage in plan_stages + ] + explainer_console.explain(plan_stages) @@ -54,6 +68,38 @@ def explain(self, stages: t.List[stages.PlanStage]) -> None: pass +@dataclass +class ExplainableRestatementStage(stages.RestatementStage): + """ + This brings forward some calculations that would usually be done in the evaluator so the user can be given a better indication + of what might happen when they ask for the plan to be explained + """ + + snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest] + """Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name""" + + @classmethod + def from_restatement_stage( + cls: t.Type[ExplainableRestatementStage], + stage: stages.RestatementStage, + state_reader: StateReader, + plan: EvaluatablePlan, + ) -> ExplainableRestatementStage: + all_restatement_intervals = identify_restatement_intervals_across_snapshot_versions( + state_reader=state_reader, + prod_restatements=plan.restatements, + disable_restatement_models=plan.disabled_restatement_models, + loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, + ) + + return cls( + snapshot_intervals_to_clear={ + s.snapshot.name: s for s in all_restatement_intervals.values() + }, + all_snapshots=stage.all_snapshots, + ) + + MAX_TREE_LENGTH = 10 @@ -146,11 +192,22 @@ def visit_audit_only_run_stage(self, stage: stages.AuditOnlyRunStage) -> Tree: tree.add(display_name) return tree - def visit_restatement_stage(self, stage: stages.RestatementStage) -> Tree: + def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage) -> Tree: + return self.visit_restatement_stage(stage) + + def visit_restatement_stage( + self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage] + ) -> Tree: tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]") - for snapshot_table_info, interval in stage.snapshot_intervals.items(): - display_name = self._display_name(snapshot_table_info) - tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]") + + if isinstance(stage, ExplainableRestatementStage) and ( + snapshot_intervals := stage.snapshot_intervals_to_clear + ): + for clear_request in snapshot_intervals.values(): + display_name = self._display_name(clear_request.snapshot) + interval = clear_request.interval + tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]") + return tree def visit_backfill_stage(self, stage: stages.BackfillStage) -> Tree: @@ -265,12 +322,14 @@ def visit_finalize_environment_stage( def _display_name( self, - snapshot: SnapshotInfoMixin, + snapshot: t.Union[SnapshotInfoMixin, SnapshotIdAndVersion], environment_naming_info: t.Optional[EnvironmentNamingInfo] = None, ) -> str: return snapshot.display_name( - environment_naming_info or self.environment_naming_info, - self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + environment_naming_info=environment_naming_info or self.environment_naming_info, + default_catalog=self.default_catalog + if self.verbosity < Verbosity.VERY_VERBOSE + else None, dialect=self.dialect, ) diff --git a/sqlmesh/core/plan/stages.py b/sqlmesh/core/plan/stages.py index 91c8c6ff14..0d829a6739 100644 --- a/sqlmesh/core/plan/stages.py +++ b/sqlmesh/core/plan/stages.py @@ -12,7 +12,6 @@ Snapshot, SnapshotTableInfo, SnapshotId, - Interval, ) @@ -98,14 +97,19 @@ class AuditOnlyRunStage: @dataclass class RestatementStage: - """Restate intervals for given snapshots. + """Clear intervals from state for snapshots in *other* environments, when restatements are requested in prod. + + This stage is effectively a "marker" stage to trigger the plan evaluator to perform the "clear intervals" logic after the BackfillStage has completed. + The "clear intervals" logic is executed just-in-time using the latest state available in order to pick up new snapshots that may have + been created while the BackfillStage was running, which is why we do not build a list of snapshots to clear at plan time and defer to evaluation time. + + Note that this stage is only present on `prod` plans because dev plans do not need to worry about clearing intervals in other environments. Args: - snapshot_intervals: Intervals to restate. - all_snapshots: All snapshots in the plan by name. + all_snapshots: All snapshots in the plan by name. Note that this does not include the snapshots from other environments that will get their + intervals cleared, it's included here as an optimization to prevent having to re-fetch the current plan's snapshots """ - snapshot_intervals: t.Dict[SnapshotTableInfo, Interval] all_snapshots: t.Dict[str, Snapshot] @@ -321,10 +325,6 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: if audit_only_snapshots: stages.append(AuditOnlyRunStage(snapshots=list(audit_only_snapshots.values()))) - restatement_stage = self._get_restatement_stage(plan, snapshots_by_name) - if restatement_stage: - stages.append(restatement_stage) - if missing_intervals_before_promote: stages.append( BackfillStage( @@ -349,6 +349,15 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: ) ) + # note: "restatement stage" (which is clearing intervals in state - not actually performing the restatements, that's the backfill stage) + # needs to come *after* the backfill stage so that at no time do other plans / runs see empty prod intervals and compete with this plan to try to fill them. + # in addition, when we update intervals in state, we only clear intervals from dev snapshots to force dev models to be backfilled based on the new prod data. + # we can leave prod intervals alone because by the time this plan finishes, the intervals in state have not actually changed, since restatement replaces + # data for existing intervals and does not produce new ones + restatement_stage = self._get_restatement_stage(plan, snapshots_by_name) + if restatement_stage: + stages.append(restatement_stage) + stages.append( EnvironmentRecordUpdateStage( no_gaps_snapshot_names={s.name for s in before_promote_snapshots} @@ -443,15 +452,12 @@ def _get_after_all_stage( def _get_restatement_stage( self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot] ) -> t.Optional[RestatementStage]: - snapshot_intervals_to_restate = {} - for name, interval in plan.restatements.items(): - restated_snapshot = snapshots_by_name[name] - restated_snapshot.remove_interval(interval) - snapshot_intervals_to_restate[restated_snapshot.table_info] = interval - if not snapshot_intervals_to_restate or plan.is_dev: + if not plan.restatements or plan.is_dev: + # The RestatementStage to clear intervals from state across all environments is not needed for plans against dev, only prod return None + return RestatementStage( - snapshot_intervals=snapshot_intervals_to_restate, all_snapshots=snapshots_by_name + all_snapshots=snapshots_by_name, ) def _get_physical_layer_update_stage( diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index c17e94be10..9522366721 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -638,6 +638,16 @@ def dev_version(self) -> str: def model_kind_name(self) -> t.Optional[ModelKindName]: return self.kind_name_ + def display_name( + self, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + dialect: DialectType = None, + ) -> str: + return model_display_name( + self.name, environment_naming_info, default_catalog, dialect=dialect + ) + class Snapshot(PydanticModel, SnapshotInfoMixin): """A snapshot represents a node at a certain point in time. @@ -1788,7 +1798,19 @@ def display_name( """ if snapshot_info_like.is_audit: return snapshot_info_like.name - view_name = exp.to_table(snapshot_info_like.name) + + return model_display_name( + snapshot_info_like.name, environment_naming_info, default_catalog, dialect + ) + + +def model_display_name( + node_name: str, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + dialect: DialectType = None, +) -> str: + view_name = exp.to_table(node_name) catalog = ( None diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index ef7c59ea7d..0fad472cd5 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -14,7 +14,7 @@ import pytest from pytest import MonkeyPatch from pathlib import Path -from sqlmesh.core.console import set_console, get_console, TerminalConsole +from sqlmesh.core.console import set_console, get_console, TerminalConsole, CaptureTerminalConsole from sqlmesh.core.config.naming import NameInferenceConfig from sqlmesh.core.model.common import ParsableSql from sqlmesh.utils.concurrency import NodeExecutionFailedError @@ -24,7 +24,9 @@ from sqlglot.expressions import DataType import re from IPython.utils.capture import capture_output - +from concurrent.futures import ThreadPoolExecutor, TimeoutError +import time +import queue from sqlmesh import CustomMaterialization from sqlmesh.cli.project_init import init_example_project @@ -72,7 +74,13 @@ SnapshotTableInfo, ) from sqlmesh.utils.date import TimeLike, now, to_date, to_datetime, to_timestamp -from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError +from sqlmesh.utils.errors import ( + NoChangesPlanError, + SQLMeshError, + PlanError, + ConfigError, + ConflictingPlanError, +) from sqlmesh.utils.pydantic import validate_string from tests.conftest import DuckDBMetadata, SushiDataValidator from sqlmesh.utils import CorrelationId @@ -10181,3 +10189,405 @@ def test_incremental_by_time_model_ignore_additive_change_unit_test(tmp_path: Pa assert test_result.testsRun == len(test_result.successes) context.close() + + +def test_restatement_plan_interval_external_visibility(tmp_path: Path): + """ + Scenario: + - `prod` environment exists, models A <- B + - `dev` environment created, models A <- B(dev) <- C (dev) + - Restatement plan is triggered against `prod` for model A + - During restatement, a new dev environment `dev_2` is created with a new version of B(dev_2) + + Outcome: + - At no point are the prod_intervals considered "missing" from state for A + - The intervals for B(dev) and C(dev) are cleared + - The intervals for B(dev_2) are also cleared even though the environment didnt exist at the time the plan was started, + because they are based on the data from a partially restated version of A + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + lock_file_path = tmp_path / "test.lock" # python model blocks while this file is present + + evaluation_lock_file_path = ( + tmp_path / "evaluation.lock" + ) # python model creates this file if it's in the wait loop and deletes it once done + + # Note: to make execution block so we can test stuff, we use a Python model that blocks until it no longer detects the presence of a file + (models_dir / "model_a.py").write_text(f""" +from sqlmesh.core.model import model +from sqlmesh.core.macros import MacroEvaluator + +@model( + "test.model_a", + is_sql=True, + kind="FULL" +) +def entrypoint(evaluator: MacroEvaluator) -> str: + from pathlib import Path + import time + + if evaluator.runtime_stage == 'evaluating': + while True: + if Path("{str(lock_file_path)}").exists(): + Path("{str(evaluation_lock_file_path)}").touch() + print("lock exists; sleeping") + time.sleep(2) + else: + Path("{str(evaluation_lock_file_path)}").unlink(missing_ok=True) + break + + return "select 'model_a' as m" +""") + + (models_dir / "model_b.sql").write_text(""" + MODEL ( + name test.model_b, + kind FULL + ); + + select a.m as m, 'model_b' as mb from test.model_a as a + """) + + config = Config( + gateways={ + "": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "db.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb", start="2024-01-01"), + ) + ctx = Context(paths=[tmp_path], config=config) + + ctx.plan(environment="prod", auto_apply=True) + + assert len(ctx.snapshots) == 2 + assert all(s.intervals for s in ctx.snapshots.values()) + + prod_model_a_snapshot_id = ctx.snapshots['"db"."test"."model_a"'].snapshot_id + prod_model_b_snapshot_id = ctx.snapshots['"db"."test"."model_b"'].snapshot_id + + # dev models + # new version of B + (models_dir / "model_b.sql").write_text(""" + MODEL ( + name test.model_b, + kind FULL + ); + + select a.m as m, 'model_b' as mb, 'dev' as dev_version from test.model_a as a + """) + + # add C + (models_dir / "model_c.sql").write_text(""" + MODEL ( + name test.model_c, + kind FULL + ); + + select b.*, 'model_c' as mc from test.model_b as b + """) + + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + + dev_model_b_snapshot_id = ctx.snapshots['"db"."test"."model_b"'].snapshot_id + dev_model_c_snapshot_id = ctx.snapshots['"db"."test"."model_c"'].snapshot_id + + assert dev_model_b_snapshot_id != prod_model_b_snapshot_id + + # now, we restate A in prod but touch the lockfile so it hangs during evaluation + # we also have to do it in its own thread due to the hang + lock_file_path.touch() + + def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue): + q.put("thread_started") + + # give this thread its own Context object to prevent segfaulting the Python interpreter + restatement_ctx = Context(paths=[tmp_path], config=config) + + # dev2 not present before the restatement plan starts + assert restatement_ctx.state_sync.get_environment("dev2") is None + + q.put("plan_started") + plan = restatement_ctx.plan( + environment="prod", restate_models=['"db"."test"."model_a"'], auto_apply=True + ) + q.put("plan_completed") + + # dev2 was created during the restatement plan + assert restatement_ctx.state_sync.get_environment("dev2") is not None + + return plan + + executor = ThreadPoolExecutor() + q: queue.Queue = queue.Queue() + restatement_plan_future = executor.submit(_run_restatement_plan, tmp_path, config, q) + assert q.get() == "thread_started" + + try: + if e := restatement_plan_future.exception(timeout=1): + # abort early if the plan thread threw an exception + raise e + except TimeoutError: + # that's ok, we dont actually expect the plan to have finished in 1 second + pass + + # while that restatement is running, we can simulate another process and check that it sees no empty intervals + assert q.get() == "plan_started" + + # dont check for potentially missing intervals until the plan is in the evaluation loop + attempts = 0 + while not evaluation_lock_file_path.exists(): + time.sleep(2) + attempts += 1 + if attempts > 10: + raise ValueError("Gave up waiting for evaluation loop") + + ctx.clear_caches() # get rid of the file cache so that data is re-fetched from state + prod_models_from_state = ctx.state_sync.get_snapshots( + snapshot_ids=[prod_model_a_snapshot_id, prod_model_b_snapshot_id] + ) + + # prod intervals should be present still + assert all(m.intervals for m in prod_models_from_state.values()) + + # so should dev intervals since prod restatement is still running + assert all(m.intervals for m in ctx.snapshots.values()) + + # now, lets create a new dev environment "dev2", while the prod restatement plan is still running, + # that changes model_b while still being based on the original version of model_a + (models_dir / "model_b.sql").write_text(""" + MODEL ( + name test.model_b, + kind FULL + ); + + select a.m as m, 'model_b' as mb, 'dev2' as dev_version from test.model_a as a + """) + ctx.load() + ctx.plan(environment="dev2", auto_apply=True) + + dev2_model_b_snapshot_id = ctx.snapshots['"db"."test"."model_b"'].snapshot_id + assert dev2_model_b_snapshot_id != dev_model_b_snapshot_id + assert dev2_model_b_snapshot_id != prod_model_b_snapshot_id + + # as at this point, everything still has intervals + ctx.clear_caches() + assert all( + s.intervals + for s in ctx.state_sync.get_snapshots( + snapshot_ids=[ + prod_model_a_snapshot_id, + prod_model_b_snapshot_id, + dev_model_b_snapshot_id, + dev_model_c_snapshot_id, + dev2_model_b_snapshot_id, + ] + ).values() + ) + + # now, we finally let that restatement plan complete + # first, verify it's still blocked where it should be + assert not restatement_plan_future.done() + + lock_file_path.unlink() # remove lock file, plan should be able to proceed now + + if e := restatement_plan_future.exception(): # blocks until future complete + raise e + + assert restatement_plan_future.result() + assert q.get() == "plan_completed" + + ctx.clear_caches() + + # check that intervals in prod are present + assert all( + s.intervals + for s in ctx.state_sync.get_snapshots( + snapshot_ids=[ + prod_model_a_snapshot_id, + prod_model_b_snapshot_id, + ] + ).values() + ) + + # check that intervals in dev have been cleared, including the dev2 env that + # was created after the restatement plan started + assert all( + not s.intervals + for s in ctx.state_sync.get_snapshots( + snapshot_ids=[ + dev_model_b_snapshot_id, + dev_model_c_snapshot_id, + dev2_model_b_snapshot_id, + ] + ).values() + ) + + executor.shutdown() + + +def test_restatement_plan_detects_prod_deployment_during_restatement(tmp_path: Path): + """ + Scenario: + - `prod` environment exists, model A + - `dev` environment created, model A(dev) + - Restatement plan is triggered against `prod` for model A + - During restatement, someone else deploys A(dev) to prod, replacing the model that is currently being restated. + + Outcome: + - The deployment plan for dev -> prod should succeed in deploying the new version of A + - The prod restatement plan should fail with a ConflictingPlanError and warn about the model that got updated while undergoing restatement + - The new version of A should have no intervals cleared. The user needs to rerun the restatement if the intervals should still be cleared + """ + orig_console = get_console() + console = CaptureTerminalConsole() + set_console(console) + + models_dir = tmp_path / "models" + models_dir.mkdir() + + lock_file_path = tmp_path / "test.lock" # python model blocks while this file is present + + evaluation_lock_file_path = ( + tmp_path / "evaluation.lock" + ) # python model creates this file if it's in the wait loop and deletes it once done + + # Note: to make execution block so we can test stuff, we use a Python model that blocks until it no longer detects the presence of a file + (models_dir / "model_a.py").write_text(f""" +from sqlmesh.core.model import model +from sqlmesh.core.macros import MacroEvaluator + +@model( + "test.model_a", + is_sql=True, + kind="FULL" +) +def entrypoint(evaluator: MacroEvaluator) -> str: + from pathlib import Path + import time + + if evaluator.runtime_stage == 'evaluating': + while True: + if Path("{str(lock_file_path)}").exists(): + Path("{str(evaluation_lock_file_path)}").touch() + print("lock exists; sleeping") + time.sleep(2) + else: + Path("{str(evaluation_lock_file_path)}").unlink(missing_ok=True) + break + + return "select 'model_a' as m" +""") + + config = Config( + gateways={ + "": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "db.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb", start="2024-01-01"), + ) + ctx = Context(paths=[tmp_path], config=config) + + # create prod + ctx.plan(environment="prod", auto_apply=True) + original_prod = ctx.state_sync.get_environment("prod") + assert original_prod + + # update model_a for dev + (models_dir / "model_a.py").unlink() + (models_dir / "model_a.sql").write_text(""" + MODEL ( + name test.model_a, + kind FULL + ); + + select 1 as changed + """) + + # create dev + ctx.load() + plan = ctx.plan(environment="dev", auto_apply=True) + assert len(plan.modified_snapshots) == 1 + new_model_a_snapshot_id = list(plan.modified_snapshots)[0] + + # now, trigger a prod restatement plan in a different thread and block it to simulate a long restatement + def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue): + q.put("thread_started") + + # give this thread its own Context object to prevent segfaulting the Python interpreter + restatement_ctx = Context(paths=[tmp_path], config=config) + + # ensure dev is present before the restatement plan starts + assert restatement_ctx.state_sync.get_environment("dev") is not None + + q.put("plan_started") + expected_error = None + try: + restatement_ctx.plan( + environment="prod", restate_models=['"db"."test"."model_a"'], auto_apply=True + ) + except ConflictingPlanError as e: + expected_error = e + + q.put("plan_completed") + return expected_error + + executor = ThreadPoolExecutor() + q: queue.Queue = queue.Queue() + lock_file_path.touch() + + restatement_plan_future = executor.submit(_run_restatement_plan, tmp_path, config, q) + restatement_plan_future.add_done_callback(lambda _: executor.shutdown()) + + assert q.get() == "thread_started" + + try: + if e := restatement_plan_future.exception(timeout=1): + # abort early if the plan thread threw an exception + raise e + except TimeoutError: + # that's ok, we dont actually expect the plan to have finished in 1 second + pass + + assert q.get() == "plan_started" + + # ok, now the prod restatement plan is running, let's deploy dev to prod + ctx.plan(environment="prod", auto_apply=True) + + new_prod = ctx.state_sync.get_environment("prod") + assert new_prod + assert new_prod.plan_id != original_prod.plan_id + assert new_prod.previous_plan_id == original_prod.plan_id + + # new prod is deployed but restatement plan is still running + assert not restatement_plan_future.done() + + # allow restatement plan to complete + lock_file_path.unlink() + + plan_error = restatement_plan_future.result() + assert isinstance(plan_error, ConflictingPlanError) + assert "please re-apply your plan" in repr(plan_error).lower() + + output = " ".join(re.split("\s+", console.captured_output, flags=re.UNICODE)) + assert ( + f"The following models had new versions deployed while data was being restated: └── test.model_a" + in output + ) + + # check that no intervals have been cleared from the model_a currently in prod + model_a = ctx.state_sync.get_snapshots(snapshot_ids=[new_model_a_snapshot_id])[ + new_model_a_snapshot_id + ] + assert isinstance(model_a.node, SqlModel) + assert model_a.node.render_query_or_raise().sql() == 'SELECT 1 AS "changed"' + assert len(model_a.intervals) + + set_console(orig_console) diff --git a/tests/core/test_plan_stages.py b/tests/core/test_plan_stages.py index 744c7d18bf..4ada7d458d 100644 --- a/tests/core/test_plan_stages.py +++ b/tests/core/test_plan_stages.py @@ -6,6 +6,7 @@ from sqlmesh.core.config import EnvironmentSuffixTarget from sqlmesh.core.config.common import VirtualEnvironmentMode from sqlmesh.core.model import SqlModel, ModelKindName +from sqlmesh.core.plan.common import SnapshotIntervalClearRequest from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.plan.stages import ( build_plan_stages, @@ -23,11 +24,13 @@ FinalizeEnvironmentStage, UnpauseStage, ) +from sqlmesh.core.plan.explainer import ExplainableRestatementStage from sqlmesh.core.snapshot.definition import ( SnapshotChangeCategory, DeployabilityIndex, Snapshot, SnapshotId, + SnapshotIdLike, ) from sqlmesh.core.state_sync import StateReader from sqlmesh.core.environment import Environment, EnvironmentStatements @@ -499,15 +502,29 @@ def test_build_plan_stages_basic_no_backfill( assert isinstance(stages[7], FinalizeEnvironmentStage) -def test_build_plan_stages_restatement( +def test_build_plan_stages_restatement_prod_only( snapshot_a: Snapshot, snapshot_b: Snapshot, mocker: MockerFixture ) -> None: + """ + Scenario: + - Prod restatement triggered in a project with no dev environments + + Expected Outcome: + - Plan still contains a RestatementStage in case a dev environment was + created during restatement + """ + # Mock state reader to return existing snapshots and environment state_reader = mocker.Mock(spec=StateReader) state_reader.get_snapshots.return_value = { snapshot_a.snapshot_id: snapshot_a, snapshot_b.snapshot_id: snapshot_b, } + state_reader.get_snapshots_by_names.return_value = { + snapshot_a.id_and_version, + snapshot_b.id_and_version, + } + existing_environment = Environment( name="prod", snapshots=[snapshot_a.table_info, snapshot_b.table_info], @@ -518,7 +535,9 @@ def test_build_plan_stages_restatement( promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], finalized_ts=to_timestamp("2023-01-02"), ) + state_reader.get_environment.return_value = existing_environment + state_reader.get_environments_summary.return_value = [existing_environment.summary] environment = Environment( name="prod", @@ -577,17 +596,164 @@ def test_build_plan_stages_restatement( snapshot_b.snapshot_id, } - # Verify RestatementStage - restatement_stage = stages[1] + # Verify BackfillStage + backfill_stage = stages[1] + assert isinstance(backfill_stage, BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 2 + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + expected_backfill_interval = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + for intervals in backfill_stage.snapshot_to_intervals.values(): + assert intervals == expected_backfill_interval + + # Verify RestatementStage exists but is empty + restatement_stage = stages[2] assert isinstance(restatement_stage, RestatementStage) - assert len(restatement_stage.snapshot_intervals) == 2 - expected_interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) - for snapshot_info, interval in restatement_stage.snapshot_intervals.items(): - assert interval == expected_interval - assert snapshot_info.name in ('"a"', '"b"') + restatement_stage = ExplainableRestatementStage.from_restatement_stage( + restatement_stage, state_reader, plan + ) + assert not restatement_stage.snapshot_intervals_to_clear + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[3], EnvironmentRecordUpdateStage) + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[4], FinalizeEnvironmentStage) + + +def test_build_plan_stages_restatement_prod_identifies_dev_intervals( + snapshot_a: Snapshot, + snapshot_b: Snapshot, + make_snapshot: t.Callable[..., Snapshot], + mocker: MockerFixture, +) -> None: + """ + Scenario: + - Prod restatement triggered in a project with a dev environment + - The dev environment contains a different physical version of the affected model + + Expected Outcome: + - Plan contains a RestatementStage that highlights the affected dev version + """ + # Dev version of snapshot_a, same name but different version + snapshot_a_dev = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, changed, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + snapshot_a_dev.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot_a_dev.snapshot_id != snapshot_a.snapshot_id + assert snapshot_a_dev.table_info != snapshot_a.table_info + + # Mock state reader to return existing snapshots and environment + state_reader = mocker.Mock(spec=StateReader) + snapshots_in_state = { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_a_dev.snapshot_id: snapshot_a_dev, + } + + def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]): + return { + k: v + for k, v in snapshots_in_state.items() + if k in {s.snapshot_id for s in snapshot_ids} + } + + state_reader.get_snapshots.side_effect = _get_snapshots + state_reader.get_snapshots_by_names.return_value = set() + + existing_prod_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + # dev has new version of snapshot_a but same version of snapshot_b + existing_dev_environment = Environment( + name="dev", + snapshots=[snapshot_a_dev.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a_dev.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + state_reader.get_environment.side_effect = ( + lambda name: existing_dev_environment if name == "dev" else existing_prod_environment + ) + state_reader.get_environments_summary.return_value = [ + existing_prod_environment.summary, + existing_dev_environment.summary, + ] + + environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + + # Create evaluatable plan with restatements + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], # No new snapshots + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={ + '"a"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + '"b"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + }, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], # No changes + indirectly_modified_snapshots={}, # No changes + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 5 + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[0] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } # Verify BackfillStage - backfill_stage = stages[2] + backfill_stage = stages[1] assert isinstance(backfill_stage, BackfillStage) assert len(backfill_stage.snapshot_to_intervals) == 2 assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() @@ -595,6 +761,23 @@ def test_build_plan_stages_restatement( for intervals in backfill_stage.snapshot_to_intervals.values(): assert intervals == expected_backfill_interval + # Verify RestatementStage + restatement_stage = stages[2] + assert isinstance(restatement_stage, RestatementStage) + restatement_stage = ExplainableRestatementStage.from_restatement_stage( + restatement_stage, state_reader, plan + ) + + # note: we only clear the intervals from state for "a" in dev, we leave prod alone + assert restatement_stage.snapshot_intervals_to_clear + assert len(restatement_stage.snapshot_intervals_to_clear) == 1 + snapshot_name, clear_request = list(restatement_stage.snapshot_intervals_to_clear.items())[0] + assert isinstance(clear_request, SnapshotIntervalClearRequest) + assert snapshot_name == '"a"' + assert clear_request.snapshot_id == snapshot_a_dev.snapshot_id + assert clear_request.snapshot == snapshot_a_dev.id_and_version + assert clear_request.interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + # Verify EnvironmentRecordUpdateStage assert isinstance(stages[3], EnvironmentRecordUpdateStage) @@ -602,6 +785,155 @@ def test_build_plan_stages_restatement( assert isinstance(stages[4], FinalizeEnvironmentStage) +def test_build_plan_stages_restatement_dev_does_not_clear_intervals( + snapshot_a: Snapshot, + snapshot_b: Snapshot, + make_snapshot: t.Callable[..., Snapshot], + mocker: MockerFixture, +) -> None: + """ + Scenario: + - Restatement triggered against the dev environment + + Expected Outcome: + - BackfillStage only touches models in that dev environment + - Plan does not contain a RestatementStage because making changes in dev doesnt mean we need + to clear intervals from other environments + """ + # Dev version of snapshot_a, same name but different version + snapshot_a_dev = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, changed, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + snapshot_a_dev.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot_a_dev.snapshot_id != snapshot_a.snapshot_id + assert snapshot_a_dev.table_info != snapshot_a.table_info + + # Mock state reader to return existing snapshots and environment + state_reader = mocker.Mock(spec=StateReader) + snapshots_in_state = { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_a_dev.snapshot_id: snapshot_a_dev, + } + state_reader.get_snapshots.side_effect = lambda snapshot_info_like: { + k: v + for k, v in snapshots_in_state.items() + if k in [sil.snapshot_id for sil in snapshot_info_like] + } + + # prod has snapshot_a, snapshot_b + existing_prod_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_prod_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + # dev has new version of snapshot_a + existing_dev_environment = Environment( + name="dev", + snapshots=[snapshot_a_dev.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_dev_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a_dev.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + state_reader.get_environment.side_effect = ( + lambda name: existing_dev_environment if name == "dev" else existing_prod_environment + ) + state_reader.get_environments_summary.return_value = [ + existing_prod_environment.summary, + existing_dev_environment.summary, + ] + + environment = Environment( + name="dev", + snapshots=[snapshot_a_dev.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_dev_plan", + promoted_snapshot_ids=[snapshot_a_dev.snapshot_id], + ) + + # Create evaluatable plan with restatements + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], # No new snapshots + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={ + '"a"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + }, + is_dev=True, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], # No changes + indirectly_modified_snapshots={}, # No changes + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 5 + + # Verify no RestatementStage + assert not any(s for s in stages if isinstance(s, RestatementStage)) + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[0] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 1 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a_dev.snapshot_id, + } + + # Verify BackfillStage + backfill_stage = stages[1] + assert isinstance(backfill_stage, BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 1 + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + backfill_snapshot, backfill_intervals = list(backfill_stage.snapshot_to_intervals.items())[0] + assert backfill_snapshot.snapshot_id == snapshot_a_dev.snapshot_id + assert backfill_intervals == [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[2], EnvironmentRecordUpdateStage) + + # Verify VirtualLayerUpdateStage (all non-prod plans get this regardless) + assert isinstance(stages[3], VirtualLayerUpdateStage) + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[4], FinalizeEnvironmentStage) + + def test_build_plan_stages_forward_only( snapshot_a: Snapshot, snapshot_b: Snapshot, make_snapshot, mocker: MockerFixture ) -> None: @@ -1686,6 +2018,7 @@ def test_adjust_intervals_restatement_removal( state_reader.refresh_snapshot_intervals = mocker.Mock() state_reader.get_snapshots.return_value = {} state_reader.get_environment.return_value = None + state_reader.get_environments_summary.return_value = [] environment = Environment( snapshots=[snapshot_a.table_info, snapshot_b.table_info], @@ -1738,8 +2071,6 @@ def test_adjust_intervals_restatement_removal( restatement_stages = [stage for stage in stages if isinstance(stage, RestatementStage)] assert len(restatement_stages) == 1 - restatement_stage = restatement_stages[0] - assert len(restatement_stage.snapshot_intervals) == 2 backfill_stages = [stage for stage in stages if isinstance(stage, BackfillStage)] assert len(backfill_stages) == 1