Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlmesh/core/state_sync/db/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def finalize(self, environment: Environment) -> None:
def unpause_snapshots(
self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
) -> None:
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt, self.interval_state)
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt)

def invalidate_environment(self, name: str, protect_prod: bool = True) -> None:
self.environment_state.invalidate_environment(name, protect_prod)
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/state_sync/db/migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def _migrate_environment_rows(
if updated_prod_environment:
try:
self.snapshot_state.unpause_snapshots(
updated_prod_environment.snapshots, now_timestamp(), self.interval_state
updated_prod_environment.snapshots, now_timestamp()
)
except Exception:
logger.warning("Failed to unpause migrated snapshots", exc_info=True)
Expand Down
145 changes: 57 additions & 88 deletions sqlmesh/core/state_sync/db/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.state_sync.db.utils import (
snapshot_name_filter,
snapshot_name_version_filter,
snapshot_id_filter,
fetchone,
Expand All @@ -32,15 +33,13 @@
SnapshotChangeCategory,
)
from sqlmesh.utils.migration import index_text_type, blob_text_type
from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp
from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp
from sqlmesh.utils.pydantic import PydanticModel
from sqlmesh.utils import unique

if t.TYPE_CHECKING:
import pandas as pd

from sqlmesh.core.state_sync.db.interval import IntervalState


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,6 +69,7 @@ def __init__(
"unpaused_ts": exp.DataType.build("bigint"),
"ttl_ms": exp.DataType.build("bigint"),
"unrestorable": exp.DataType.build("boolean"),
"forward_only": exp.DataType.build("boolean"),
}

self._auto_restatement_columns_to_types = {
Expand Down Expand Up @@ -112,84 +112,48 @@ def unpause_snapshots(
self,
snapshots: t.Collection[SnapshotInfoLike],
unpaused_dt: TimeLike,
interval_state: IntervalState,
) -> None:
"""Unpauses given snapshots while pausing all other snapshots that share the same version.

Args:
snapshots: The snapshots to unpause.
unpaused_dt: The timestamp to unpause the snapshots at.
interval_state: The interval state to use to remove intervals when needed.
"""
current_ts = now()

target_snapshot_ids = {s.snapshot_id for s in snapshots}
same_version_snapshots = self._get_snapshots_with_same_version(
snapshots, lock_for_update=True
)
target_snapshots_by_version = {
(s.name, s.version): s
for s in same_version_snapshots
if s.snapshot_id in target_snapshot_ids
}
unrestorable_snapshots_by_forward_only: t.Dict[bool, t.List[str]] = defaultdict(list)

unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list)
paused_snapshots: t.List[SnapshotId] = []
unrestorable_snapshots: t.List[SnapshotId] = []

for snapshot in same_version_snapshots:
is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids
if is_target_snapshot and not snapshot.unpaused_ts:
logger.info("Unpausing snapshot %s", snapshot.snapshot_id)
snapshot.set_unpaused_ts(unpaused_dt)
assert snapshot.unpaused_ts is not None
unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id)
elif not is_target_snapshot:
target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)]
if (
target_snapshot.normalized_effective_from_ts
and not target_snapshot.disable_restatement
):
# Making sure that there are no overlapping intervals.
effective_from_ts = target_snapshot.normalized_effective_from_ts
logger.info(
"Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s",
target_snapshot.effective_from,
snapshot.snapshot_id,
target_snapshot.snapshot_id,
)
full_snapshot = snapshot.full_snapshot
interval_state.remove_intervals(
[
(
full_snapshot,
full_snapshot.get_removal_interval(effective_from_ts, current_ts),
)
]
)

if snapshot.unpaused_ts:
logger.info("Pausing snapshot %s", snapshot.snapshot_id)
snapshot.set_unpaused_ts(None)
paused_snapshots.append(snapshot.snapshot_id)
for snapshot in snapshots:
# We need to mark all other snapshots that have opposite forward only status as unrestorable
Comment thread
izeigerman marked this conversation as resolved.
Outdated
unrestorable_snapshots_by_forward_only[not snapshot.is_forward_only].append(
snapshot.name
)

if not snapshot.unrestorable and (
(target_snapshot.is_forward_only and not snapshot.is_forward_only)
or (snapshot.is_forward_only and not target_snapshot.is_forward_only)
):
logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id)
snapshot.unrestorable = True
unrestorable_snapshots.append(snapshot.snapshot_id)
updated_ts = now_timestamp()
unpaused_ts = to_timestamp(unpaused_dt)

if unpaused_snapshots:
for unpaused_ts, snapshot_ids in unpaused_snapshots.items():
self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts)
# Pause all snapshots with target names first
for where in snapshot_name_filter(
[s.name for s in snapshots],
batch_size=self.SNAPSHOT_BATCH_SIZE,
):
self.engine_adapter.update_table(
self.snapshots_table,
{"unpaused_ts": None, "updated_ts": updated_ts},
where=where,
)

if paused_snapshots:
self._update_snapshots(paused_snapshots, unpaused_ts=None)
# Now unpause the target snapshots
self._update_snapshots(
[s.snapshot_id for s in snapshots],
unpaused_ts=unpaused_ts,
updated_ts=updated_ts,
)

if unrestorable_snapshots:
self._update_snapshots(unrestorable_snapshots, unrestorable=True)
# Mark unrestorable snapshots
for forward_only, snapshot_names in unrestorable_snapshots_by_forward_only.items():
forward_only_exp = exp.column("forward_only").is_(exp.convert(forward_only))
for where in snapshot_name_filter(
snapshot_names,
batch_size=self.SNAPSHOT_BATCH_SIZE,
):
self.engine_adapter.update_table(
self.snapshots_table,
{"unrestorable": True, "updated_ts": updated_ts},
where=forward_only_exp.and_(where),
)

def get_expired_snapshots(
self,
Expand Down Expand Up @@ -414,7 +378,8 @@ def _update_snapshots(
**kwargs: t.Any,
) -> None:
properties = kwargs
properties["updated_ts"] = now_timestamp()
if "updated_ts" not in properties:
properties["updated_ts"] = now_timestamp()

for where in snapshot_id_filter(
self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE
Expand Down Expand Up @@ -466,13 +431,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
updated_ts,
unpaused_ts,
unrestorable,
forward_only,
next_auto_restatement_ts,
) in fetchall(self.engine_adapter, query):
snapshot = parse_snapshot(
serialized_snapshot=serialized_snapshot,
updated_ts=updated_ts,
unpaused_ts=unpaused_ts,
unrestorable=unrestorable,
forward_only=forward_only,
next_auto_restatement_ts=next_auto_restatement_ts,
)
snapshot_id = snapshot.snapshot_id
Expand Down Expand Up @@ -502,6 +469,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
"updated_ts",
"unpaused_ts",
"unrestorable",
"forward_only",
"next_auto_restatement_ts",
)
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
Expand All @@ -528,13 +496,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
updated_ts,
unpaused_ts,
unrestorable,
forward_only,
next_auto_restatement_ts,
) in fetchall(self.engine_adapter, query):
snapshot_id = SnapshotId(name=name, identifier=identifier)
snapshot = snapshots[snapshot_id]
snapshot.updated_ts = updated_ts
snapshot.unpaused_ts = unpaused_ts
snapshot.unrestorable = unrestorable
snapshot.forward_only = forward_only
snapshot.next_auto_restatement_ts = next_auto_restatement_ts
cached_snapshots_in_state.add(snapshot_id)

Expand Down Expand Up @@ -568,6 +538,7 @@ def _get_snapshots_expressions(
"snapshots.updated_ts",
"snapshots.unpaused_ts",
"snapshots.unrestorable",
"snapshots.forward_only",
"auto_restatements.next_auto_restatement_ts",
)
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
Expand Down Expand Up @@ -623,6 +594,7 @@ def _get_snapshots_with_same_version(
"updated_ts",
"unpaused_ts",
"unrestorable",
"forward_only",
)
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
.where(where)
Expand All @@ -640,9 +612,10 @@ def _get_snapshots_with_same_version(
updated_ts=updated_ts,
unpaused_ts=unpaused_ts,
unrestorable=unrestorable,
forward_only=forward_only,
snapshot=snapshot,
)
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable, forward_only in snapshot_rows
]


Expand All @@ -651,6 +624,7 @@ def parse_snapshot(
updated_ts: int,
unpaused_ts: t.Optional[int],
unrestorable: bool,
forward_only: bool,
next_auto_restatement_ts: t.Optional[int],
) -> Snapshot:
return Snapshot(
Expand All @@ -659,6 +633,7 @@ def parse_snapshot(
"updated_ts": updated_ts,
"unpaused_ts": unpaused_ts,
"unrestorable": unrestorable,
"forward_only": forward_only,
"next_auto_restatement_ts": next_auto_restatement_ts,
}
)
Expand All @@ -673,6 +648,7 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
"updated_ts",
"unpaused_ts",
"unrestorable",
"forward_only",
"next_auto_restatement_ts",
}
)
Expand All @@ -693,6 +669,7 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
"unpaused_ts": snapshot.unpaused_ts,
"ttl_ms": snapshot.ttl_ms,
"unrestorable": snapshot.unrestorable,
"forward_only": snapshot.forward_only,
}
for snapshot in snapshots
]
Expand Down Expand Up @@ -762,19 +739,10 @@ def full_snapshot(self) -> Snapshot:
"updated_ts": self.updated_ts,
"unpaused_ts": self.unpaused_ts,
"unrestorable": self.unrestorable,
"forward_only": self.forward_only,
}
)

def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None:
    """Store the unpause timestamp for this snapshot, or clear it.

    Args:
        unpaused_dt: When the snapshot was unpaused. A falsy value
            (such as None) resets the stored timestamp to None.
    """
    if not unpaused_dt:
        self.unpaused_ts = None
        return

    # Floor the value with the interval unit's cron_floor before converting
    # it to a timestamp.
    floored_dt = self.interval_unit.cron_floor(unpaused_dt)
    self.unpaused_ts = to_timestamp(floored_dt)

@classmethod
def from_snapshot_record(
cls,
Expand All @@ -785,6 +753,7 @@ def from_snapshot_record(
updated_ts: int,
unpaused_ts: t.Optional[int],
unrestorable: bool,
forward_only: bool,
snapshot: str,
) -> SharedVersionSnapshot:
raw_snapshot = json.loads(snapshot)
Expand All @@ -803,5 +772,5 @@ def from_snapshot_record(
disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False),
effective_from=raw_snapshot.get("effective_from"),
raw_snapshot=raw_snapshot,
forward_only=raw_snapshot.get("forward_only", False),
forward_only=forward_only,
)
15 changes: 15 additions & 0 deletions sqlmesh/core/state_sync/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@
T = t.TypeVar("T")


def snapshot_name_filter(
    snapshot_names: t.Iterable[str],
    batch_size: int,
    alias: t.Optional[str] = None,
) -> t.Iterator[exp.Condition]:
    """Yield conditions that match the `name` column against snapshot names.

    Args:
        snapshot_names: The snapshot names to match. They are sorted before
            being split into batches.
        batch_size: The batch size passed to `create_batches`.
        alias: Optional table alias used to qualify the `name` column.

    Yields:
        One `name IN (...)` condition per batch of names, or a single
        `exp.false()` condition when no names are provided.
    """
    ordered_names = sorted(snapshot_names)

    if not ordered_names:
        # No names to match: emit one false condition so callers still
        # receive a filter.
        yield exp.false()
        return

    for name_batch in create_batches(ordered_names, batch_size=batch_size):
        yield exp.column("name", table=alias).isin(*name_batch)


def snapshot_id_filter(
engine_adapter: EngineAdapter,
snapshot_ids: t.Iterable[SnapshotIdLike],
Expand Down
Loading