Skip to content

Commit b82036f

Browse files
committed
Feat(state_sync): Add the ability to fetch all versions of a snapshot by name
1 parent 2eb39a4 commit b82036f

File tree

4 files changed

+146
-0
lines changed

4 files changed

+146
-0
lines changed

sqlmesh/core/state_sync/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,24 @@ def get_snapshots(
9797
A dictionary of snapshot ids to snapshots for ones that could be found.
9898
"""
9999

100+
@abc.abstractmethod
101+
def get_snapshot_ids_by_names(
102+
self,
103+
snapshot_names: t.Iterable[str],
104+
current_ts: t.Optional[int] = None,
105+
exclude_expired: bool = True,
106+
) -> t.Set[SnapshotId]:
107+
"""Return the snapshot id's for all versions of the specified snapshot names.
108+
109+
Args:
110+
snapshot_names: Iterable of snapshot names to fetch all snapshot id's for
111+
current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True)
112+
exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result
113+
114+
Returns:
115+
A dictionary mapping snapshot names to a list of relevant snapshot id's
116+
"""
117+
100118
@abc.abstractmethod
101119
def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
102120
"""Checks if multiple snapshots exist in the state sync.

sqlmesh/core/state_sync/db/facade.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,16 @@ def get_snapshots(
368368
Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals)
369369
return snapshots
370370

371+
def get_snapshot_ids_by_names(
372+
self,
373+
snapshot_names: t.Iterable[str],
374+
current_ts: t.Optional[int] = None,
375+
exclude_expired: bool = True,
376+
) -> t.Set[SnapshotId]:
377+
return self.snapshot_state.get_snapshot_ids_by_names(
378+
snapshot_names=snapshot_names, current_ts=current_ts, exclude_expired=exclude_expired
379+
)
380+
371381
@transactional()
372382
def add_interval(
373383
self,

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,46 @@ def get_snapshots(
288288
"""
289289
return self._get_snapshots(snapshot_ids)
290290

291+
def get_snapshot_ids_by_names(
292+
self,
293+
snapshot_names: t.Iterable[str],
294+
current_ts: t.Optional[int] = None,
295+
exclude_expired: bool = True,
296+
) -> t.Set[SnapshotId]:
297+
"""Return the snapshot id's for all versions of the specified snapshot names.
298+
299+
Args:
300+
snapshot_names: Iterable of snapshot names to fetch all snapshot id's for
301+
current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True)
302+
exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result
303+
304+
Returns:
305+
A dictionary mapping snapshot names to a list of relevant snapshot id's
306+
"""
307+
if not snapshot_names:
308+
return set()
309+
310+
if exclude_expired:
311+
current_ts = current_ts or now_timestamp()
312+
unexpired_expr = (exp.column("updated_ts") + exp.column("ttl_ms")) > current_ts
313+
else:
314+
unexpired_expr = None
315+
316+
return {
317+
SnapshotId(name=name, identifier=identifier)
318+
for where in snapshot_name_filter(
319+
snapshot_names=snapshot_names,
320+
batch_size=self.SNAPSHOT_BATCH_SIZE,
321+
)
322+
for name, identifier in fetchall(
323+
self.engine_adapter,
324+
exp.select("name", "identifier")
325+
.from_(self.snapshots_table)
326+
.where(where)
327+
.and_(unexpired_expr),
328+
)
329+
}
330+
291331
def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
292332
"""Checks if snapshots exist.
293333

tests/core/state_sync/test_state_sync.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3554,3 +3554,81 @@ def test_update_environment_statements(state_sync: EngineAdapterStateSync):
35543554
"@grant_schema_usage()",
35553555
"@grant_select_privileges()",
35563556
]
3557+
3558+
3559+
def test_get_snapshot_ids_by_names(
3560+
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot]
3561+
):
3562+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=[]) == set()
3563+
3564+
snap_a_v1, snap_a_v2 = (
3565+
make_snapshot(
3566+
SqlModel(
3567+
name="a",
3568+
query=parse_one(f"select {i}, ds"),
3569+
),
3570+
version="a",
3571+
)
3572+
for i in range(2)
3573+
)
3574+
3575+
snap_b = make_snapshot(
3576+
SqlModel(
3577+
name="b",
3578+
query=parse_one(f"select 'b' as b, ds"),
3579+
),
3580+
version="b",
3581+
)
3582+
3583+
state_sync.push_snapshots([snap_a_v1, snap_a_v2, snap_b])
3584+
3585+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"']) == {
3586+
snap_a_v1.snapshot_id,
3587+
snap_a_v2.snapshot_id,
3588+
}
3589+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"', '"b"']) == {
3590+
snap_a_v1.snapshot_id,
3591+
snap_a_v2.snapshot_id,
3592+
snap_b.snapshot_id,
3593+
}
3594+
3595+
3596+
def test_get_snapshot_ids_by_names_include_expired(
3597+
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot]
3598+
):
3599+
now_ts = now_timestamp()
3600+
3601+
normal_a = make_snapshot(
3602+
SqlModel(
3603+
name="a",
3604+
query=parse_one(f"select 1, ds"),
3605+
),
3606+
version="a",
3607+
)
3608+
3609+
expired_a = make_snapshot(
3610+
SqlModel(
3611+
name="a",
3612+
query=parse_one(f"select 2, ds"),
3613+
),
3614+
version="a",
3615+
ttl="in 10 seconds",
3616+
)
3617+
expired_a.updated_ts = now_ts - (
3618+
1000 * 15
3619+
) # last updated 15 seconds ago, expired 10 seconds from last updated = expired 5 seconds ago
3620+
3621+
state_sync.push_snapshots([normal_a, expired_a])
3622+
3623+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], current_ts=now_ts) == {
3624+
normal_a.snapshot_id
3625+
}
3626+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], exclude_expired=False) == {
3627+
normal_a.snapshot_id,
3628+
expired_a.snapshot_id,
3629+
}
3630+
3631+
# wind back time to 10 seconds ago (before the expired snapshot is expired - it expired 5 seconds ago) to test it stil shows in a normal query
3632+
assert state_sync.get_snapshot_ids_by_names(
3633+
snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000))
3634+
) == {normal_a.snapshot_id, expired_a.snapshot_id}

0 commit comments

Comments
 (0)