diff --git a/sqlmesh/core/snapshot/__init__.py b/sqlmesh/core/snapshot/__init__.py index 8ad574f8ac..32842cc4b2 100644 --- a/sqlmesh/core/snapshot/__init__.py +++ b/sqlmesh/core/snapshot/__init__.py @@ -4,6 +4,7 @@ Node as Node, QualifiedViewName as QualifiedViewName, Snapshot as Snapshot, + SnapshotIdAndVersion as SnapshotIdAndVersion, SnapshotChangeCategory as SnapshotChangeCategory, SnapshotDataVersion as SnapshotDataVersion, SnapshotFingerprint as SnapshotFingerprint, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index dea4ef64e5..c124c2098f 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -588,6 +588,37 @@ def name_version(self) -> SnapshotNameVersion: return SnapshotNameVersion(name=self.name, version=self.version) +class SnapshotIdAndVersion(PydanticModel): + """A stripped down version of a snapshot that is used in situations where we want to fetch the main fields of the snapshots table + without the overhead of parsing the full snapshot payload and fetching intervals. + """ + + name: str + version: str + dev_version_: t.Optional[str] = Field(alias="dev_version") + identifier: str + fingerprint_: t.Union[str, SnapshotFingerprint] = Field(alias="fingerprint") + + @property + def snapshot_id(self) -> SnapshotId: + return SnapshotId(name=self.name, identifier=self.identifier) + + @property + def name_version(self) -> SnapshotNameVersion: + return SnapshotNameVersion(name=self.name, version=self.version) + + @property + def fingerprint(self) -> SnapshotFingerprint: + value = self.fingerprint_ + if isinstance(value, str): + self.fingerprint_ = value = SnapshotFingerprint.parse_raw(value) + return value + + @property + def dev_version(self) -> str: + return self.dev_version_ or self.fingerprint.to_version() + + class Snapshot(PydanticModel, SnapshotInfoMixin): """A snapshot represents a node at a certain point in time. 
@@ -1463,9 +1494,11 @@ class SnapshotTableCleanupTask(PydanticModel): dev_table_only: bool -SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, Snapshot] +SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot] SnapshotInfoLike = t.Union[SnapshotTableInfo, Snapshot] -SnapshotNameVersionLike = t.Union[SnapshotNameVersion, SnapshotTableInfo, Snapshot] +SnapshotNameVersionLike = t.Union[ + SnapshotNameVersion, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot +] class DeployabilityIndex(PydanticModel, frozen=True): diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 4d3d51a469..a8f73b6937 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -23,6 +23,7 @@ SnapshotTableCleanupTask, SnapshotTableInfo, SnapshotNameVersion, + SnapshotIdAndVersion, ) from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals from sqlmesh.utils import major_minor @@ -97,6 +98,24 @@ def get_snapshots( A dictionary of snapshot ids to snapshots for ones that could be found. """ + @abc.abstractmethod + def get_snapshots_by_names( + self, + snapshot_names: t.Iterable[str], + current_ts: t.Optional[int] = None, + exclude_expired: bool = True, + ) -> t.Set[SnapshotIdAndVersion]: + """Return the snapshot records for all versions of the specified snapshot names. + + Args: + snapshot_names: Iterable of snapshot names to fetch all snapshot records for + current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True) + exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result + + Returns: + A set containing all the matched snapshot records. 
To fetch full snapshots, pass it into StateSync.get_snapshots() + """ + @abc.abstractmethod def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: """Checks if multiple snapshots exist in the state sync. diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 898ba75651..85bebcc5d6 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -29,6 +29,7 @@ from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentSummary from sqlmesh.core.snapshot import ( Snapshot, + SnapshotIdAndVersion, SnapshotId, SnapshotIdLike, SnapshotInfoLike, @@ -366,6 +367,16 @@ def get_snapshots( Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) return snapshots + def get_snapshots_by_names( + self, + snapshot_names: t.Iterable[str], + current_ts: t.Optional[int] = None, + exclude_expired: bool = True, + ) -> t.Set[SnapshotIdAndVersion]: + return self.snapshot_state.get_snapshots_by_names( + snapshot_names=snapshot_names, current_ts=current_ts, exclude_expired=exclude_expired + ) + @transactional() def add_interval( self, diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index af10f0192e..1904e51c55 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -6,7 +6,6 @@ from pathlib import Path from collections import defaultdict from sqlglot import exp -from pydantic import Field from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.state_sync.db.utils import ( @@ -27,12 +26,12 @@ SnapshotNameVersion, SnapshotInfoLike, Snapshot, + SnapshotIdAndVersion, SnapshotId, SnapshotFingerprint, ) from sqlmesh.utils.migration import index_text_type, blob_text_type from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp -from sqlmesh.utils.pydantic import PydanticModel from sqlmesh.utils import unique if t.TYPE_CHECKING: @@ 
-215,7 +214,7 @@ def _get_expired_snapshots( for snapshot in environment.snapshots } - def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: + def _is_snapshot_used(snapshot: SnapshotIdAndVersion) -> bool: return ( snapshot.snapshot_id in promoted_snapshot_ids or snapshot.snapshot_id not in expired_candidates @@ -308,6 +307,52 @@ def get_snapshots( """ return self._get_snapshots(snapshot_ids) + def get_snapshots_by_names( + self, + snapshot_names: t.Iterable[str], + current_ts: t.Optional[int] = None, + exclude_expired: bool = True, + ) -> t.Set[SnapshotIdAndVersion]: + """Return the snapshot records for all versions of the specified snapshot names. + + Args: + snapshot_names: Iterable of snapshot names to fetch all snapshot records for + current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True) + exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result + + Returns: + A set containing all the matched snapshot records. To fetch full snapshots, pass it into StateSync.get_snapshots() + """ + if not snapshot_names: + return set() + + if exclude_expired: + current_ts = current_ts or now_timestamp() + unexpired_expr = (exp.column("updated_ts") + exp.column("ttl_ms")) > current_ts + else: + unexpired_expr = None + + return { + SnapshotIdAndVersion( + name=name, + identifier=identifier, + version=version, + dev_version=dev_version, + fingerprint=fingerprint, + ) + for where in snapshot_name_filter( + snapshot_names=snapshot_names, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ) + for name, identifier, version, dev_version, fingerprint in fetchall( + self.engine_adapter, + exp.select("name", "identifier", "version", "dev_version", "fingerprint") + .from_(self.snapshots_table) + .where(where) + .and_(unexpired_expr), + ) + } + def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: """Checks if snapshots exist. 
@@ -591,7 +636,7 @@ def _get_snapshots_with_same_version( self, snapshots: t.Collection[SnapshotNameVersionLike], lock_for_update: bool = False, - ) -> t.List[SharedVersionSnapshot]: + ) -> t.List[SnapshotIdAndVersion]: """Fetches all snapshots that share the same version as the snapshots. The output includes the snapshots with the specified identifiers. @@ -628,7 +673,7 @@ def _get_snapshots_with_same_version( snapshot_rows.extend(fetchall(self.engine_adapter, query)) return [ - SharedVersionSnapshot( + SnapshotIdAndVersion( name=name, identifier=identifier, version=version, @@ -711,23 +756,3 @@ def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int] for name_version, ts in auto_restatements.items() ] ) - - -class SharedVersionSnapshot(PydanticModel): - """A stripped down version of a snapshot that is used for fetching snapshots that share the same version - with a significantly reduced parsing overhead. - """ - - name: str - version: str - dev_version_: t.Optional[str] = Field(alias="dev_version") - identifier: str - fingerprint: SnapshotFingerprint - - @property - def snapshot_id(self) -> SnapshotId: - return SnapshotId(name=self.name, identifier=self.identifier) - - @property - def dev_version(self) -> str: - return self.dev_version_ or self.fingerprint.to_version() diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index be8e4ad3e0..327ec82210 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -3569,3 +3569,90 @@ def test_update_environment_statements(state_sync: EngineAdapterStateSync): "@grant_schema_usage()", "@grant_select_privileges()", ] + + +def test_get_snapshots_by_names( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot] +): + assert state_sync.get_snapshots_by_names(snapshot_names=[]) == set() + + snap_a_v1, snap_a_v2 = ( + make_snapshot( + SqlModel( + name="a", + query=parse_one(f"select {i}, ds"), 
+ ), + version="a", + ) + for i in range(2) + ) + + snap_b = make_snapshot( + SqlModel( + name="b", + query=parse_one(f"select 'b' as b, ds"), + ), + version="b", + ) + + state_sync.push_snapshots([snap_a_v1, snap_a_v2, snap_b]) + + assert {s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'])} == { + snap_a_v1.snapshot_id, + snap_a_v2.snapshot_id, + } + assert { + s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"', '"b"']) + } == { + snap_a_v1.snapshot_id, + snap_a_v2.snapshot_id, + snap_b.snapshot_id, + } + + +def test_get_snapshots_by_names_include_expired( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot] +): + now_ts = now_timestamp() + + normal_a = make_snapshot( + SqlModel( + name="a", + query=parse_one(f"select 1, ds"), + ), + version="a", + ) + + expired_a = make_snapshot( + SqlModel( + name="a", + query=parse_one(f"select 2, ds"), + ), + version="a", + ttl="in 10 seconds", + ) + expired_a.updated_ts = now_ts - ( + 1000 * 15 + ) # last updated 15 seconds ago, expired 10 seconds from last updated = expired 5 seconds ago + + state_sync.push_snapshots([normal_a, expired_a]) + + assert { + s.snapshot_id + for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], current_ts=now_ts) + } == {normal_a.snapshot_id} + assert { + s.snapshot_id + for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], exclude_expired=False) + } == { + normal_a.snapshot_id, + expired_a.snapshot_id, + } + + # wind back time to 10 seconds ago (before the expired snapshot is expired - it expired 5 seconds ago) to test it still shows in a normal query + assert { + s.snapshot_id + for s in state_sync.get_snapshots_by_names( + snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000)) + ) + } == {normal_a.snapshot_id, expired_a.snapshot_id} diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 9e36ecc3ae..c37bd57d2e 100644 --- a/tests/core/test_snapshot.py +++
b/tests/core/test_snapshot.py @@ -44,6 +44,7 @@ QualifiedViewName, Snapshot, SnapshotId, + SnapshotIdAndVersion, SnapshotChangeCategory, SnapshotFingerprint, SnapshotIntervals, @@ -3532,3 +3533,35 @@ def test_table_name_virtual_environment_mode( assert table_name_result.endswith(snapshot.version) else: assert table_name_result.endswith(f"{snapshot.dev_version}__dev") + + +def test_snapshot_id_and_version_fingerprint_lazy_init(): + snapshot = SnapshotIdAndVersion( + name="a", + identifier="1234", + version="2345", + dev_version=None, + fingerprint='{"data_hash":"1","metadata_hash":"2","parent_data_hash":"3","parent_metadata_hash":"4"}', + ) + + # starts off as a string in the private property + assert isinstance(snapshot.fingerprint_, str) + + # gets parsed into SnapshotFingerprint on first access of public property + fingerprint = snapshot.fingerprint + assert isinstance(fingerprint, SnapshotFingerprint) + assert isinstance(snapshot.fingerprint_, SnapshotFingerprint) + + assert fingerprint.data_hash == "1" + assert fingerprint.metadata_hash == "2" + assert fingerprint.parent_data_hash == "3" + assert fingerprint.parent_metadata_hash == "4" + assert snapshot.dev_version is not None # dev version uses fingerprint + + # can also be supplied as a SnapshotFingerprint to begin with instead of a str + snapshot = SnapshotIdAndVersion( + name="a", identifier="1234", version="2345", dev_version=None, fingerprint=fingerprint + ) + + assert isinstance(snapshot.fingerprint_, SnapshotFingerprint) + assert snapshot.fingerprint == fingerprint