Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmesh/core/snapshot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Node as Node,
QualifiedViewName as QualifiedViewName,
Snapshot as Snapshot,
SnapshotIdAndVersion as SnapshotIdAndVersion,
SnapshotChangeCategory as SnapshotChangeCategory,
SnapshotDataVersion as SnapshotDataVersion,
SnapshotFingerprint as SnapshotFingerprint,
Expand Down
37 changes: 35 additions & 2 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,37 @@ def name_version(self) -> SnapshotNameVersion:
return SnapshotNameVersion(name=self.name, version=self.version)


class SnapshotIdAndVersion(PydanticModel):
"""A stripped down version of a snapshot that is used in situations where we want to fetch the main fields of the snapshots table
without the overhead of parsing the full snapshot payload and fetching intervals.
"""

name: str
version: str
dev_version_: t.Optional[str] = Field(alias="dev_version")
identifier: str
fingerprint_: t.Union[str, SnapshotFingerprint] = Field(alias="fingerprint")

@property
def snapshot_id(self) -> SnapshotId:
return SnapshotId(name=self.name, identifier=self.identifier)

@property
def name_version(self) -> SnapshotNameVersion:
return SnapshotNameVersion(name=self.name, version=self.version)

@property
def fingerprint(self) -> SnapshotFingerprint:
value = self.fingerprint_
if isinstance(value, str):
self.fingerprint_ = value = SnapshotFingerprint.parse_raw(value)
return value

@property
def dev_version(self) -> str:
return self.dev_version_ or self.fingerprint.to_version()


class Snapshot(PydanticModel, SnapshotInfoMixin):
"""A snapshot represents a node at a certain point in time.

Expand Down Expand Up @@ -1463,9 +1494,11 @@ class SnapshotTableCleanupTask(PydanticModel):
dev_table_only: bool


SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, Snapshot]
SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot]
SnapshotInfoLike = t.Union[SnapshotTableInfo, Snapshot]
SnapshotNameVersionLike = t.Union[SnapshotNameVersion, SnapshotTableInfo, Snapshot]
SnapshotNameVersionLike = t.Union[
SnapshotNameVersion, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot
]


class DeployabilityIndex(PydanticModel, frozen=True):
Expand Down
19 changes: 19 additions & 0 deletions sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SnapshotTableCleanupTask,
SnapshotTableInfo,
SnapshotNameVersion,
SnapshotIdAndVersion,
)
from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals
from sqlmesh.utils import major_minor
Expand Down Expand Up @@ -97,6 +98,24 @@ def get_snapshots(
A dictionary of snapshot ids to snapshots for ones that could be found.
"""

@abc.abstractmethod
def get_snapshots_by_names(
self,
snapshot_names: t.Iterable[str],
current_ts: t.Optional[int] = None,
exclude_expired: bool = True,
) -> t.Set[SnapshotIdAndVersion]:
"""Return the snapshot records for all versions of the specified snapshot names.

Args:
snapshot_names: Iterable of snapshot names to fetch all snapshot records for
current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True)
exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result

Returns:
A set containing all the matched snapshot records. To fetch full snapshots, pass it into StateSync.get_snapshots()
"""

@abc.abstractmethod
def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
"""Checks if multiple snapshots exist in the state sync.
Expand Down
11 changes: 11 additions & 0 deletions sqlmesh/core/state_sync/db/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentSummary
from sqlmesh.core.snapshot import (
Snapshot,
SnapshotIdAndVersion,
SnapshotId,
SnapshotIdLike,
SnapshotInfoLike,
Expand Down Expand Up @@ -366,6 +367,16 @@ def get_snapshots(
Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals)
return snapshots

def get_snapshots_by_names(
self,
snapshot_names: t.Iterable[str],
current_ts: t.Optional[int] = None,
exclude_expired: bool = True,
) -> t.Set[SnapshotIdAndVersion]:
return self.snapshot_state.get_snapshots_by_names(
snapshot_names=snapshot_names, current_ts=current_ts, exclude_expired=exclude_expired
)

@transactional()
def add_interval(
self,
Expand Down
75 changes: 50 additions & 25 deletions sqlmesh/core/state_sync/db/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pathlib import Path
from collections import defaultdict
from sqlglot import exp
from pydantic import Field

from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.state_sync.db.utils import (
Expand All @@ -27,12 +26,12 @@
SnapshotNameVersion,
SnapshotInfoLike,
Snapshot,
SnapshotIdAndVersion,
SnapshotId,
SnapshotFingerprint,
)
from sqlmesh.utils.migration import index_text_type, blob_text_type
from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp
from sqlmesh.utils.pydantic import PydanticModel
from sqlmesh.utils import unique

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -215,7 +214,7 @@ def _get_expired_snapshots(
for snapshot in environment.snapshots
}

def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
def _is_snapshot_used(snapshot: SnapshotIdAndVersion) -> bool:
return (
snapshot.snapshot_id in promoted_snapshot_ids
or snapshot.snapshot_id not in expired_candidates
Expand Down Expand Up @@ -308,6 +307,52 @@ def get_snapshots(
"""
return self._get_snapshots(snapshot_ids)

def get_snapshots_by_names(
self,
snapshot_names: t.Iterable[str],
current_ts: t.Optional[int] = None,
exclude_expired: bool = True,
) -> t.Set[SnapshotIdAndVersion]:
"""Return the snapshot records for all versions of the specified snapshot names.

Args:
snapshot_names: Iterable of snapshot names to fetch all snapshot records for
current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True)
exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result

Returns:
A set containing all the matched snapshot records. To fetch full snapshots, pass it into StateSync.get_snapshots()
"""
if not snapshot_names:
return set()

if exclude_expired:
current_ts = current_ts or now_timestamp()
unexpired_expr = (exp.column("updated_ts") + exp.column("ttl_ms")) > current_ts
else:
unexpired_expr = None

return {
SnapshotIdAndVersion(
name=name,
identifier=identifier,
version=version,
dev_version=dev_version,
fingerprint=fingerprint,
)
for where in snapshot_name_filter(
snapshot_names=snapshot_names,
batch_size=self.SNAPSHOT_BATCH_SIZE,
)
for name, identifier, version, dev_version, fingerprint in fetchall(
self.engine_adapter,
exp.select("name", "identifier", "version", "dev_version", "fingerprint")
.from_(self.snapshots_table)
.where(where)
.and_(unexpired_expr),
)
}

def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
"""Checks if snapshots exist.

Expand Down Expand Up @@ -591,7 +636,7 @@ def _get_snapshots_with_same_version(
self,
snapshots: t.Collection[SnapshotNameVersionLike],
lock_for_update: bool = False,
) -> t.List[SharedVersionSnapshot]:
) -> t.List[SnapshotIdAndVersion]:
"""Fetches all snapshots that share the same version as the snapshots.

The output includes the snapshots with the specified identifiers.
Expand Down Expand Up @@ -628,7 +673,7 @@ def _get_snapshots_with_same_version(
snapshot_rows.extend(fetchall(self.engine_adapter, query))

return [
SharedVersionSnapshot(
SnapshotIdAndVersion(
name=name,
identifier=identifier,
version=version,
Expand Down Expand Up @@ -711,23 +756,3 @@ def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int]
for name_version, ts in auto_restatements.items()
]
)


class SharedVersionSnapshot(PydanticModel):
"""A stripped down version of a snapshot that is used for fetching snapshots that share the same version
with a significantly reduced parsing overhead.
"""

name: str
version: str
dev_version_: t.Optional[str] = Field(alias="dev_version")
identifier: str
fingerprint: SnapshotFingerprint

@property
def snapshot_id(self) -> SnapshotId:
return SnapshotId(name=self.name, identifier=self.identifier)

@property
def dev_version(self) -> str:
return self.dev_version_ or self.fingerprint.to_version()
87 changes: 87 additions & 0 deletions tests/core/state_sync/test_state_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -3569,3 +3569,90 @@ def test_update_environment_statements(state_sync: EngineAdapterStateSync):
"@grant_schema_usage()",
"@grant_select_privileges()",
]


def test_get_snapshots_by_names(
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot]
):
assert state_sync.get_snapshots_by_names(snapshot_names=[]) == set()

snap_a_v1, snap_a_v2 = (
make_snapshot(
SqlModel(
name="a",
query=parse_one(f"select {i}, ds"),
),
version="a",
)
for i in range(2)
)

snap_b = make_snapshot(
SqlModel(
name="b",
query=parse_one(f"select 'b' as b, ds"),
),
version="b",
)

state_sync.push_snapshots([snap_a_v1, snap_a_v2, snap_b])

assert {s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'])} == {
snap_a_v1.snapshot_id,
snap_a_v2.snapshot_id,
}
assert {
s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"', '"b"'])
} == {
snap_a_v1.snapshot_id,
snap_a_v2.snapshot_id,
snap_b.snapshot_id,
}


def test_get_snapshots_by_names_include_expired(
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot]
):
now_ts = now_timestamp()

normal_a = make_snapshot(
SqlModel(
name="a",
query=parse_one(f"select 1, ds"),
),
version="a",
)

expired_a = make_snapshot(
SqlModel(
name="a",
query=parse_one(f"select 2, ds"),
),
version="a",
ttl="in 10 seconds",
)
expired_a.updated_ts = now_ts - (
1000 * 15
) # last updated 15 seconds ago, expired 10 seconds from last updated = expired 5 seconds ago

state_sync.push_snapshots([normal_a, expired_a])

assert {
s.snapshot_id
for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], current_ts=now_ts)
} == {normal_a.snapshot_id}
assert {
s.snapshot_id
for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], exclude_expired=False)
} == {
normal_a.snapshot_id,
expired_a.snapshot_id,
}

# wind back time to 10 seconds ago (before the expired snapshot is expired - it expired 5 seconds ago) to test it stil shows in a normal query
assert {
s.snapshot_id
for s in state_sync.get_snapshots_by_names(
snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000))
)
} == {normal_a.snapshot_id, expired_a.snapshot_id}
33 changes: 33 additions & 0 deletions tests/core/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
QualifiedViewName,
Snapshot,
SnapshotId,
SnapshotIdAndVersion,
SnapshotChangeCategory,
SnapshotFingerprint,
SnapshotIntervals,
Expand Down Expand Up @@ -3532,3 +3533,35 @@ def test_table_name_virtual_environment_mode(
assert table_name_result.endswith(snapshot.version)
else:
assert table_name_result.endswith(f"{snapshot.dev_version}__dev")


def test_snapshot_id_and_version_fingerprint_lazy_init():
snapshot = SnapshotIdAndVersion(
name="a",
identifier="1234",
version="2345",
dev_version=None,
fingerprint='{"data_hash":"1","metadata_hash":"2","parent_data_hash":"3","parent_metadata_hash":"4"}',
)

# starts off as a string in the private property
assert isinstance(snapshot.fingerprint_, str)

# gets parsed into SnapshotFingerprint on first access of public property
fingerprint = snapshot.fingerprint
assert isinstance(fingerprint, SnapshotFingerprint)
assert isinstance(snapshot.fingerprint_, SnapshotFingerprint)

assert fingerprint.data_hash == "1"
assert fingerprint.metadata_hash == "2"
assert fingerprint.parent_data_hash == "3"
assert fingerprint.parent_metadata_hash == "4"
assert snapshot.dev_version is not None # dev version uses fingerprint

# can also be supplied as a SnapshotFingerprint to begin with instead of a str
snapshot = SnapshotIdAndVersion(
name="a", identifier="1234", version="2345", dev_version=None, fingerprint=fingerprint
)

assert isinstance(snapshot.fingerprint_, SnapshotFingerprint)
assert snapshot.fingerprint == fingerprint