From f686ad19410f95a20bed0ceb03ed6922ed222479 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Mon, 1 Sep 2025 23:35:25 +0000 Subject: [PATCH 1/4] Feat(state_sync): Add the ability to fetch all versions of a snapshot by name --- sqlmesh/core/state_sync/base.py | 18 ++++++ sqlmesh/core/state_sync/db/facade.py | 10 +++ sqlmesh/core/state_sync/db/snapshot.py | 40 ++++++++++++ tests/core/state_sync/test_state_sync.py | 78 ++++++++++++++++++++++++ 4 files changed, 146 insertions(+) diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 4d3d51a469..f6166d729f 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -97,6 +97,24 @@ def get_snapshots( A dictionary of snapshot ids to snapshots for ones that could be found. """ + @abc.abstractmethod + def get_snapshot_ids_by_names( + self, + snapshot_names: t.Iterable[str], + current_ts: t.Optional[int] = None, + exclude_expired: bool = True, + ) -> t.Set[SnapshotId]: + """Return the snapshot id's for all versions of the specified snapshot names. + + Args: + snapshot_names: Iterable of snapshot names to fetch all snapshot id's for + current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True) + exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result + + Returns: + A dictionary mapping snapshot names to a list of relevant snapshot id's + """ + @abc.abstractmethod def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: """Checks if multiple snapshots exist in the state sync. diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 898ba75651..ecadfeff6f 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -366,6 +366,16 @@ def get_snapshots( Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) return snapshots + def get_snapshot_ids_by_names( + self, + snapshot_names: t.Iterable[str], + current_ts: t.Optional[int] = None, + exclude_expired: bool = True, + ) -> t.Set[SnapshotId]: + return self.snapshot_state.get_snapshot_ids_by_names( + snapshot_names=snapshot_names, current_ts=current_ts, exclude_expired=exclude_expired + ) + @transactional() def add_interval( self, diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index af10f0192e..6fc85782f0 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -308,6 +308,46 @@ def get_snapshots( """ return self._get_snapshots(snapshot_ids) + def get_snapshot_ids_by_names( + self, + snapshot_names: t.Iterable[str], + current_ts: t.Optional[int] = None, + exclude_expired: bool = True, + ) -> t.Set[SnapshotId]: + """Return the snapshot id's for all versions of the specified snapshot names. + + Args: + snapshot_names: Iterable of snapshot names to fetch all snapshot id's for + current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True) + exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result + + Returns: + A dictionary mapping snapshot names to a list of relevant snapshot id's + """ + if not snapshot_names: + return set() + + if exclude_expired: + current_ts = current_ts or now_timestamp() + unexpired_expr = (exp.column("updated_ts") + exp.column("ttl_ms")) > current_ts + else: + unexpired_expr = None + + return { + SnapshotId(name=name, identifier=identifier) + for where in snapshot_name_filter( + snapshot_names=snapshot_names, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ) + for name, identifier in fetchall( + self.engine_adapter, + exp.select("name", "identifier") + .from_(self.snapshots_table) + .where(where) + .and_(unexpired_expr), + ) + } + def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: """Checks if snapshots exist. diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index be8e4ad3e0..01cf0ed641 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -3569,3 +3569,81 @@ def test_update_environment_statements(state_sync: EngineAdapterStateSync): "@grant_schema_usage()", "@grant_select_privileges()", ] + + +def test_get_snapshot_ids_by_names( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot] +): + assert state_sync.get_snapshot_ids_by_names(snapshot_names=[]) == set() + + snap_a_v1, snap_a_v2 = ( + make_snapshot( + SqlModel( + name="a", + query=parse_one(f"select {i}, ds"), + ), + version="a", + ) + for i in range(2) + ) + + snap_b = make_snapshot( + SqlModel( + name="b", + query=parse_one(f"select 'b' as b, ds"), + ), + version="b", + ) + + state_sync.push_snapshots([snap_a_v1, snap_a_v2, snap_b]) + + assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"']) == { + snap_a_v1.snapshot_id, + snap_a_v2.snapshot_id, + } + assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"', '"b"']) == { + snap_a_v1.snapshot_id, + snap_a_v2.snapshot_id, + snap_b.snapshot_id, + } + + +def test_get_snapshot_ids_by_names_include_expired( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot] +): + now_ts = now_timestamp() + + normal_a = make_snapshot( + SqlModel( + name="a", + query=parse_one(f"select 1, ds"), + ), + version="a", + ) + + expired_a = make_snapshot( + SqlModel( + name="a", + query=parse_one(f"select 2, ds"), + ), + version="a", + ttl="in 10 seconds", + ) + expired_a.updated_ts = now_ts - ( + 1000 * 15 + ) # last updated 15 seconds ago, expired 10 seconds from last updated = expired 5 seconds ago + + state_sync.push_snapshots([normal_a, expired_a]) + + assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], current_ts=now_ts) == { + normal_a.snapshot_id + } + assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], exclude_expired=False) == { + normal_a.snapshot_id, + expired_a.snapshot_id, + } + + # wind back time to 10 seconds ago (before the expired snapshot is expired - it expired 5 seconds ago) to test it stil shows in a normal query + assert state_sync.get_snapshot_ids_by_names( + snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000)) + ) == {normal_a.snapshot_id, expired_a.snapshot_id} From 28d684a1b2a1b329d8776d56a00daabcd136debe Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Fri, 5 Sep 2025 04:15:21 +0000 Subject: [PATCH 2/4] PR feedback --- sqlmesh/core/snapshot/__init__.py | 1 + sqlmesh/core/snapshot/definition.py | 24 +++++++++++++ sqlmesh/core/state_sync/base.py | 9 ++--- sqlmesh/core/state_sync/db/facade.py | 7 ++-- sqlmesh/core/state_sync/db/snapshot.py | 45 ++++++++---------------- tests/core/state_sync/test_state_sync.py | 33 ++++++++++------- 6 files changed, 70 insertions(+), 49 deletions(-) diff --git a/sqlmesh/core/snapshot/__init__.py b/sqlmesh/core/snapshot/__init__.py index 8ad574f8ac..e8e8871b6a 100644 --- a/sqlmesh/core/snapshot/__init__.py +++ b/sqlmesh/core/snapshot/__init__.py @@ -4,6 +4,7 @@ Node as Node, QualifiedViewName as QualifiedViewName, Snapshot as Snapshot, + MinimalSnapshot as MinimalSnapshot, SnapshotChangeCategory as SnapshotChangeCategory, SnapshotDataVersion as SnapshotDataVersion, SnapshotFingerprint as SnapshotFingerprint, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index dea4ef64e5..dca8c164ed 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -588,6 +588,30 @@ def name_version(self) -> SnapshotNameVersion: return SnapshotNameVersion(name=self.name, version=self.version) +class MinimalSnapshot(PydanticModel): + """A stripped down version of a snapshot that is used in situations where we want to fetch the main fields of the snapshots table + without the overhead of parsing the full snapshot payload and fetching intervals. + """ + + name: str + version: str + dev_version_: t.Optional[str] = Field(alias="dev_version") + identifier: str + fingerprint: SnapshotFingerprint + + @property + def snapshot_id(self) -> SnapshotId: + return SnapshotId(name=self.name, identifier=self.identifier) + + @property + def name_version(self) -> SnapshotNameVersion: + return SnapshotNameVersion(name=self.name, version=self.version) + + @property + def dev_version(self) -> str: + return self.dev_version_ or self.fingerprint.to_version() + + class Snapshot(PydanticModel, SnapshotInfoMixin): """A snapshot represents a node at a certain point in time. diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index f6166d729f..b12af65415 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -23,6 +23,7 @@ SnapshotTableCleanupTask, SnapshotTableInfo, SnapshotNameVersion, + MinimalSnapshot, ) from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals from sqlmesh.utils import major_minor @@ -98,21 +99,21 @@ def get_snapshots( """ @abc.abstractmethod - def get_snapshot_ids_by_names( + def get_snapshots_by_names( self, snapshot_names: t.Iterable[str], current_ts: t.Optional[int] = None, exclude_expired: bool = True, - ) -> t.Set[SnapshotId]: + ) -> t.Set[MinimalSnapshot]: """Return the snapshot id's for all versions of the specified snapshot names. Args: - snapshot_names: Iterable of snapshot names to fetch all snapshot id's for + snapshot_names: Iterable of snapshot names to fetch all snapshot for current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True) exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result Returns: - A dictionary mapping snapshot names to a list of relevant snapshot id's + A set containing all the matched snapshot records. To fetch full snapshots, pass it into StateSync.get_snapshots() """ @abc.abstractmethod diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index ecadfeff6f..863c8c8a2e 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -29,6 +29,7 @@ from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentSummary from sqlmesh.core.snapshot import ( Snapshot, + MinimalSnapshot, SnapshotId, SnapshotIdLike, SnapshotInfoLike, @@ -366,13 +367,13 @@ def get_snapshots( Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) return snapshots - def get_snapshot_ids_by_names( + def get_snapshots_by_names( self, snapshot_names: t.Iterable[str], current_ts: t.Optional[int] = None, exclude_expired: bool = True, - ) -> t.Set[SnapshotId]: - return self.snapshot_state.get_snapshot_ids_by_names( + ) -> t.Set[MinimalSnapshot]: + return self.snapshot_state.get_snapshots_by_names( snapshot_names=snapshot_names, current_ts=current_ts, exclude_expired=exclude_expired ) diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 6fc85782f0..7933f6b4db 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -6,7 +6,6 @@ from pathlib import Path from collections import defaultdict from sqlglot import exp -from pydantic import Field from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.state_sync.db.utils import ( @@ -27,12 +26,12 @@ SnapshotNameVersion, SnapshotInfoLike, Snapshot, + MinimalSnapshot, SnapshotId, SnapshotFingerprint, ) from sqlmesh.utils.migration import index_text_type, blob_text_type from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp -from sqlmesh.utils.pydantic import PydanticModel from sqlmesh.utils import unique if t.TYPE_CHECKING: @@ -215,7 +214,7 @@ def _get_expired_snapshots( for snapshot in environment.snapshots } - def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: + def _is_snapshot_used(snapshot: MinimalSnapshot) -> bool: return ( snapshot.snapshot_id in promoted_snapshot_ids or snapshot.snapshot_id not in expired_candidates @@ -308,12 +307,12 @@ def get_snapshots( """ return self._get_snapshots(snapshot_ids) - def get_snapshot_ids_by_names( + def get_snapshots_by_names( self, snapshot_names: t.Iterable[str], current_ts: t.Optional[int] = None, exclude_expired: bool = True, - ) -> t.Set[SnapshotId]: + ) -> t.Set[MinimalSnapshot]: """Return the snapshot id's for all versions of the specified snapshot names. Args: @@ -334,14 +333,20 @@ def get_snapshot_ids_by_names( unexpired_expr = None return { - SnapshotId(name=name, identifier=identifier) + MinimalSnapshot( + name=name, + identifier=identifier, + version=version, + dev_version=dev_version, + fingerprint=SnapshotFingerprint.parse_raw(fingerprint), + ) for where in snapshot_name_filter( snapshot_names=snapshot_names, batch_size=self.SNAPSHOT_BATCH_SIZE, ) - for name, identifier in fetchall( + for name, identifier, version, dev_version, fingerprint in fetchall( self.engine_adapter, - exp.select("name", "identifier") + exp.select("name", "identifier", "version", "dev_version", "fingerprint") .from_(self.snapshots_table) .where(where) .and_(unexpired_expr), @@ -631,7 +636,7 @@ def _get_snapshots_with_same_version( self, snapshots: t.Collection[SnapshotNameVersionLike], lock_for_update: bool = False, - ) -> t.List[SharedVersionSnapshot]: + ) -> t.List[MinimalSnapshot]: """Fetches all snapshots that share the same version as the snapshots. The output includes the snapshots with the specified identifiers. @@ -668,7 +673,7 @@ def _get_snapshots_with_same_version( snapshot_rows.extend(fetchall(self.engine_adapter, query)) return [ - SharedVersionSnapshot( + MinimalSnapshot( name=name, identifier=identifier, version=version, @@ -751,23 +756,3 @@ def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int] for name_version, ts in auto_restatements.items() ] ) - - -class SharedVersionSnapshot(PydanticModel): - """A stripped down version of a snapshot that is used for fetching snapshots that share the same version - with a significantly reduced parsing overhead. - """ - - name: str - version: str - dev_version_: t.Optional[str] = Field(alias="dev_version") - identifier: str - fingerprint: SnapshotFingerprint - - @property - def snapshot_id(self) -> SnapshotId: - return SnapshotId(name=self.name, identifier=self.identifier) - - @property - def dev_version(self) -> str: - return self.dev_version_ or self.fingerprint.to_version() diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index 01cf0ed641..327ec82210 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -3571,10 +3571,10 @@ def test_update_environment_statements(state_sync: EngineAdapterStateSync): ] -def test_get_snapshot_ids_by_names( +def test_get_snapshots_by_names( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot] ): - assert state_sync.get_snapshot_ids_by_names(snapshot_names=[]) == set() + assert state_sync.get_snapshots_by_names(snapshot_names=[]) == set() snap_a_v1, snap_a_v2 = ( make_snapshot( @@ -3597,18 +3597,20 @@ def test_get_snapshot_ids_by_names( state_sync.push_snapshots([snap_a_v1, snap_a_v2, snap_b]) - assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"']) == { + assert {s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'])} == { snap_a_v1.snapshot_id, snap_a_v2.snapshot_id, } - assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"', '"b"']) == { + assert { + s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"', '"b"']) + } == { snap_a_v1.snapshot_id, snap_a_v2.snapshot_id, snap_b.snapshot_id, } -def test_get_snapshot_ids_by_names_include_expired( +def test_get_snapshots_by_names_include_expired( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot] ): now_ts = now_timestamp() @@ -3635,15 +3637,22 @@ def test_get_snapshot_ids_by_names_include_expired( state_sync.push_snapshots([normal_a, expired_a]) - assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], current_ts=now_ts) == { - normal_a.snapshot_id - } - assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], exclude_expired=False) == { + assert { + s.snapshot_id + for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], current_ts=now_ts) + } == {normal_a.snapshot_id} + assert { + s.snapshot_id + for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], exclude_expired=False) + } == { normal_a.snapshot_id, expired_a.snapshot_id, } # wind back time to 10 seconds ago (before the expired snapshot is expired - it expired 5 seconds ago) to test it stil shows in a normal query - assert state_sync.get_snapshot_ids_by_names( - snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000)) - ) == {normal_a.snapshot_id, expired_a.snapshot_id} + assert { + s.snapshot_id + for s in state_sync.get_snapshots_by_names( + snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000)) + ) + } == {normal_a.snapshot_id, expired_a.snapshot_id} From 933bd9ed0fe1ae5bb9beda0a41227a88f926a084 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Fri, 5 Sep 2025 04:20:13 +0000 Subject: [PATCH 3/4] cosmetic enhancements --- sqlmesh/core/snapshot/definition.py | 4 ++-- sqlmesh/core/state_sync/base.py | 4 ++-- sqlmesh/core/state_sync/db/snapshot.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index dca8c164ed..75eaf6bac3 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -1487,9 +1487,9 @@ class SnapshotTableCleanupTask(PydanticModel): dev_table_only: bool -SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, Snapshot] +SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, MinimalSnapshot, Snapshot] SnapshotInfoLike = t.Union[SnapshotTableInfo, Snapshot] -SnapshotNameVersionLike = t.Union[SnapshotNameVersion, SnapshotTableInfo, Snapshot] +SnapshotNameVersionLike = t.Union[SnapshotNameVersion, SnapshotTableInfo, MinimalSnapshot, Snapshot] class DeployabilityIndex(PydanticModel, frozen=True): diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index b12af65415..2024a0b578 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -105,10 +105,10 @@ def get_snapshots_by_names( current_ts: t.Optional[int] = None, exclude_expired: bool = True, ) -> t.Set[MinimalSnapshot]: - """Return the snapshot id's for all versions of the specified snapshot names. + """Return the snapshot records for all versions of the specified snapshot names. Args: - snapshot_names: Iterable of snapshot names to fetch all snapshot for + snapshot_names: Iterable of snapshot names to fetch all snapshot records for current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True) exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 7933f6b4db..839deeb76a 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -313,15 +313,15 @@ def get_snapshots_by_names( current_ts: t.Optional[int] = None, exclude_expired: bool = True, ) -> t.Set[MinimalSnapshot]: - """Return the snapshot id's for all versions of the specified snapshot names. + """Return the snapshot records for all versions of the specified snapshot names. Args: - snapshot_names: Iterable of snapshot names to fetch all snapshot id's for + snapshot_names: Iterable of snapshot names to fetch all snapshot records for current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True) exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result Returns: - A dictionary mapping snapshot names to a list of relevant snapshot id's + A set containing all the matched snapshot records. To fetch full snapshots, pass it into StateSync.get_snapshots() """ if not snapshot_names: return set() From f234322af56eb6736881beb92a309b2b61873ea9 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Sun, 7 Sep 2025 21:08:29 +0000 Subject: [PATCH 4/4] PR feedback --- sqlmesh/core/snapshot/__init__.py | 2 +- sqlmesh/core/snapshot/definition.py | 17 +++++++++---- sqlmesh/core/state_sync/base.py | 4 ++-- sqlmesh/core/state_sync/db/facade.py | 4 ++-- sqlmesh/core/state_sync/db/snapshot.py | 14 +++++------ tests/core/test_snapshot.py | 33 ++++++++++++++++++++++++++ 6 files changed, 58 insertions(+), 16 deletions(-) diff --git a/sqlmesh/core/snapshot/__init__.py b/sqlmesh/core/snapshot/__init__.py index e8e8871b6a..32842cc4b2 100644 --- a/sqlmesh/core/snapshot/__init__.py +++ b/sqlmesh/core/snapshot/__init__.py @@ -4,7 +4,7 @@ Node as Node, QualifiedViewName as QualifiedViewName, Snapshot as Snapshot, - MinimalSnapshot as MinimalSnapshot, + SnapshotIdAndVersion as SnapshotIdAndVersion, SnapshotChangeCategory as SnapshotChangeCategory, SnapshotDataVersion as SnapshotDataVersion, SnapshotFingerprint as SnapshotFingerprint, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 75eaf6bac3..c124c2098f 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -588,7 +588,7 @@ def name_version(self) -> SnapshotNameVersion: return SnapshotNameVersion(name=self.name, version=self.version) -class MinimalSnapshot(PydanticModel): +class SnapshotIdAndVersion(PydanticModel): """A stripped down version of a snapshot that is used in situations where we want to fetch the main fields of the snapshots table without the overhead of parsing the full snapshot payload and fetching intervals. """ @@ -597,7 +597,7 @@ class MinimalSnapshot(PydanticModel): version: str dev_version_: t.Optional[str] = Field(alias="dev_version") identifier: str - fingerprint: SnapshotFingerprint + fingerprint_: t.Union[str, SnapshotFingerprint] = Field(alias="fingerprint") @property def snapshot_id(self) -> SnapshotId: @@ -607,6 +607,13 @@ def snapshot_id(self) -> SnapshotId: def name_version(self) -> SnapshotNameVersion: return SnapshotNameVersion(name=self.name, version=self.version) + @property + def fingerprint(self) -> SnapshotFingerprint: + value = self.fingerprint_ + if isinstance(value, str): + self.fingerprint_ = value = SnapshotFingerprint.parse_raw(value) + return value + @property def dev_version(self) -> str: return self.dev_version_ or self.fingerprint.to_version() @@ -1487,9 +1494,11 @@ class SnapshotTableCleanupTask(PydanticModel): dev_table_only: bool -SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, MinimalSnapshot, Snapshot] +SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot] SnapshotInfoLike = t.Union[SnapshotTableInfo, Snapshot] -SnapshotNameVersionLike = t.Union[SnapshotNameVersion, SnapshotTableInfo, MinimalSnapshot, Snapshot] +SnapshotNameVersionLike = t.Union[ + SnapshotNameVersion, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot +] class DeployabilityIndex(PydanticModel, frozen=True): diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 2024a0b578..a8f73b6937 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -23,7 +23,7 @@ SnapshotTableCleanupTask, SnapshotTableInfo, SnapshotNameVersion, - MinimalSnapshot, + SnapshotIdAndVersion, ) from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals from sqlmesh.utils import major_minor @@ -104,7 +104,7 @@ def get_snapshots_by_names( snapshot_names: t.Iterable[str], current_ts: t.Optional[int] = None, exclude_expired: bool = True, - ) -> t.Set[MinimalSnapshot]: + ) -> t.Set[SnapshotIdAndVersion]: """Return the snapshot records for all versions of the specified snapshot names. Args: diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 863c8c8a2e..85bebcc5d6 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -29,7 +29,7 @@ from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentSummary from sqlmesh.core.snapshot import ( Snapshot, - MinimalSnapshot, + SnapshotIdAndVersion, SnapshotId, SnapshotIdLike, SnapshotInfoLike, @@ -372,7 +372,7 @@ def get_snapshots_by_names( snapshot_names: t.Iterable[str], current_ts: t.Optional[int] = None, exclude_expired: bool = True, - ) -> t.Set[MinimalSnapshot]: + ) -> t.Set[SnapshotIdAndVersion]: return self.snapshot_state.get_snapshots_by_names( snapshot_names=snapshot_names, current_ts=current_ts, exclude_expired=exclude_expired ) diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 839deeb76a..1904e51c55 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -26,7 +26,7 @@ SnapshotNameVersion, SnapshotInfoLike, Snapshot, - MinimalSnapshot, + SnapshotIdAndVersion, SnapshotId, SnapshotFingerprint, ) @@ -214,7 +214,7 @@ def _get_expired_snapshots( for snapshot in environment.snapshots } - def _is_snapshot_used(snapshot: MinimalSnapshot) -> bool: + def _is_snapshot_used(snapshot: SnapshotIdAndVersion) -> bool: return ( snapshot.snapshot_id in promoted_snapshot_ids or snapshot.snapshot_id not in expired_candidates @@ -312,7 +312,7 @@ def get_snapshots_by_names( snapshot_names: t.Iterable[str], current_ts: t.Optional[int] = None, exclude_expired: bool = True, - ) -> t.Set[MinimalSnapshot]: + ) -> t.Set[SnapshotIdAndVersion]: """Return the snapshot records for all versions of the specified snapshot names. Args: @@ -333,12 +333,12 @@ def get_snapshots_by_names( unexpired_expr = None return { - MinimalSnapshot( + SnapshotIdAndVersion( name=name, identifier=identifier, version=version, dev_version=dev_version, - fingerprint=SnapshotFingerprint.parse_raw(fingerprint), + fingerprint=fingerprint, ) for where in snapshot_name_filter( snapshot_names=snapshot_names, @@ -636,7 +636,7 @@ def _get_snapshots_with_same_version( self, snapshots: t.Collection[SnapshotNameVersionLike], lock_for_update: bool = False, - ) -> t.List[MinimalSnapshot]: + ) -> t.List[SnapshotIdAndVersion]: """Fetches all snapshots that share the same version as the snapshots. The output includes the snapshots with the specified identifiers. @@ -673,7 +673,7 @@ def _get_snapshots_with_same_version( snapshot_rows.extend(fetchall(self.engine_adapter, query)) return [ - MinimalSnapshot( + SnapshotIdAndVersion( name=name, identifier=identifier, version=version, diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 9e36ecc3ae..c37bd57d2e 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -44,6 +44,7 @@ QualifiedViewName, Snapshot, SnapshotId, + SnapshotIdAndVersion, SnapshotChangeCategory, SnapshotFingerprint, SnapshotIntervals, @@ -3532,3 +3533,35 @@ def test_table_name_virtual_environment_mode( assert table_name_result.endswith(snapshot.version) else: assert table_name_result.endswith(f"{snapshot.dev_version}__dev") + + +def test_snapshot_id_and_version_fingerprint_lazy_init(): + snapshot = SnapshotIdAndVersion( + name="a", + identifier="1234", + version="2345", + dev_version=None, + fingerprint='{"data_hash":"1","metadata_hash":"2","parent_data_hash":"3","parent_metadata_hash":"4"}', + ) + + # starts off as a string in the private property + assert isinstance(snapshot.fingerprint_, str) + + # gets parsed into SnapshotFingerprint on first access of public property + fingerprint = snapshot.fingerprint + assert isinstance(fingerprint, SnapshotFingerprint) + assert isinstance(snapshot.fingerprint_, SnapshotFingerprint) + + assert fingerprint.data_hash == "1" + assert fingerprint.metadata_hash == "2" + assert fingerprint.parent_data_hash == "3" + assert fingerprint.parent_metadata_hash == "4" + assert snapshot.dev_version is not None # dev version uses fingerprint + + # can also be supplied as a SnapshotFingerprint to begin with instead of a str + snapshot = SnapshotIdAndVersion( + name="a", identifier="1234", version="2345", dev_version=None, fingerprint=fingerprint + ) + + assert isinstance(snapshot.fingerprint_, SnapshotFingerprint) + assert snapshot.fingerprint == fingerprint