Skip to content

Commit f9c7625

Browse files
authored
chore: split up get snapshots operations (#2615)
1 parent 28f8b42 commit f9c7625

File tree

2 files changed

+67
-29
lines changed

2 files changed

+67
-29
lines changed

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -459,36 +459,24 @@ def _environments_query(
459459
return query.lock(copy=False)
460460
return query
461461

462-
def _get_snapshots(
462+
def _get_snapshots_expressions(
463463
self,
464464
snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]] = None,
465465
lock_for_update: bool = False,
466466
hydrate_seeds: bool = False,
467-
hydrate_intervals: bool = True,
468-
) -> t.Dict[SnapshotId, Snapshot]:
469-
"""Fetches specified snapshots or all snapshots.
470-
471-
Args:
472-
snapshot_ids: The collection of snapshot like objects to fetch.
473-
lock_for_update: Lock the snapshot rows for future update
474-
hydrate_seeds: Whether to hydrate seed snapshots with the content.
475-
hydrate_intervals: Whether to hydrate result snapshots with intervals.
476-
477-
Returns:
478-
A dictionary of snapshot ids to snapshots for ones that could be found.
479-
"""
480-
snapshots: t.Dict[SnapshotId, Snapshot] = {}
481-
duplicates: t.Dict[SnapshotId, Snapshot] = {}
482-
model_cache = ModelCache(self._context_path / c.CACHE)
483-
467+
batch_size: t.Optional[int] = None,
468+
) -> t.Iterator[exp.Expression]:
484469
for where in (
485-
[None] if snapshot_ids is None else self._snapshot_id_filter(snapshot_ids, "snapshots")
470+
[None]
471+
if snapshot_ids is None
472+
else self._snapshot_id_filter(snapshot_ids, alias="snapshots", batch_size=batch_size)
486473
):
487474
query = (
488475
exp.select(
489476
"snapshots.snapshot",
490477
"snapshots.name",
491478
"snapshots.identifier",
479+
"snapshots.version",
492480
)
493481
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
494482
.where(where)
@@ -504,17 +492,43 @@ def _get_snapshots(
504492
),
505493
join_type="left",
506494
)
507-
elif lock_for_update:
495+
else:
496+
query = query.select(exp.Null().as_("content"))
497+
if lock_for_update:
508498
query = query.lock(copy=False)
499+
yield query
509500

510-
for row in self._fetchall(query):
511-
payload = json.loads(row[0])
501+
def _get_snapshots(
502+
self,
503+
snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]] = None,
504+
lock_for_update: bool = False,
505+
hydrate_seeds: bool = False,
506+
hydrate_intervals: bool = True,
507+
) -> t.Dict[SnapshotId, Snapshot]:
508+
"""Fetches specified snapshots or all snapshots.
509+
510+
Args:
511+
snapshot_ids: The collection of snapshot like objects to fetch.
512+
lock_for_update: Lock the snapshot rows for future update
513+
hydrate_seeds: Whether to hydrate seed snapshots with the content.
514+
hydrate_intervals: Whether to hydrate result snapshots with intervals.
512515
513-
def loader() -> Node:
514-
return parse_obj_as(Node, payload["node"]) # type: ignore
516+
Returns:
517+
A dictionary of snapshot ids to snapshots for ones that could be found.
518+
"""
519+
snapshots: t.Dict[SnapshotId, Snapshot] = {}
520+
duplicates: t.Dict[SnapshotId, Snapshot] = {}
521+
model_cache = ModelCache(self._context_path / c.CACHE)
515522

516-
payload["node"] = model_cache.get_or_load(f"{row[1]}_{row[2]}", loader=loader) # type: ignore
517-
snapshot = Snapshot(**payload)
523+
for query in self._get_snapshots_expressions(snapshot_ids, lock_for_update, hydrate_seeds):
524+
for serialized_snapshot, name, identifier, _, seed_content in self._fetchall(query):
525+
snapshot = parse_snapshot(
526+
model_cache,
527+
serialized_snapshot=serialized_snapshot,
528+
name=name,
529+
identifier=identifier,
530+
seed_content=seed_content,
531+
)
518532
snapshot_id = snapshot.snapshot_id
519533
if snapshot_id in snapshots:
520534
other = duplicates.get(snapshot_id, snapshots[snapshot_id])
@@ -525,9 +539,6 @@ def loader() -> Node:
525539
else:
526540
snapshots[snapshot_id] = snapshot
527541

528-
if hydrate_seeds and isinstance(snapshot.node, SeedModel) and row[-1]:
529-
snapshot.node = t.cast(SeedModel, snapshot.node).to_hydrated(row[-1])
530-
531542
if snapshots and hydrate_intervals:
532543
_, intervals = self._get_snapshot_intervals(snapshots.values())
533544
Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals)
@@ -1452,6 +1463,27 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
14521463
return snapshot.json(exclude={"intervals", "dev_intervals"})
14531464

14541465

1466+
def parse_snapshot(
    model_cache: ModelCache,
    serialized_snapshot: str,
    name: str,
    identifier: str,
    seed_content: t.Optional[str],
) -> Snapshot:
    """Builds a Snapshot from its serialized JSON representation.

    Args:
        model_cache: The cache whose `get_or_load` is used to obtain the snapshot's node.
        serialized_snapshot: The JSON string holding the snapshot payload.
        name: The snapshot name; first half of the cache key.
        identifier: The snapshot identifier; second half of the cache key.
        seed_content: Optional seed content. When truthy and the node is a SeedModel,
            the node is replaced with its hydrated version.

    Returns:
        The parsed snapshot.
    """
    raw = json.loads(serialized_snapshot)
    cache_key = f"{name}_{identifier}"

    def _load_node() -> Node:
        # Invoked by the cache only if it decides to load; reads the still-raw node payload.
        return parse_obj_as(Node, raw["node"])  # type: ignore

    raw["node"] = model_cache.get_or_load(cache_key, loader=_load_node)  # type: ignore
    parsed = Snapshot(**raw)

    # Nothing to hydrate unless content was supplied and the node is a seed model.
    if not seed_content or not isinstance(parsed.node, SeedModel):
        return parsed

    parsed.node = parsed.node.to_hydrated(seed_content)
    return parsed
1485+
1486+
14551487
class LazilyParsedSnapshots:
14561488
def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]):
14571489
self._raw_snapshots = raw_snapshots

tests/core/test_state_sync.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2193,13 +2193,17 @@ def test_snapshot_batching(state_sync, mocker, make_snapshot):
21932193
).json(),
21942194
"a",
21952195
"1",
2196+
"1",
2197+
None,
21962198
],
21972199
[
21982200
make_snapshot(
21992201
SqlModel(name="a", query=parse_one("select 2")),
22002202
).json(),
22012203
"a",
22022204
"2",
2205+
"2",
2206+
None,
22032207
],
22042208
],
22052209
[
@@ -2209,6 +2213,8 @@ def test_snapshot_batching(state_sync, mocker, make_snapshot):
22092213
).json(),
22102214
"a",
22112215
"3",
2216+
"3",
2217+
None,
22122218
],
22132219
],
22142220
]

0 commit comments

Comments
 (0)