Skip to content

Commit 532d1c5

Browse files
committed
PR feedback
1 parent 3751c4b commit 532d1c5

File tree

3 files changed

+171
-45
lines changed

3 files changed

+171
-45
lines changed

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
SnapshotId,
3131
SnapshotFingerprint,
3232
SnapshotChangeCategory,
33+
snapshots_to_dag,
3334
)
3435
from sqlmesh.utils.migration import index_text_type, blob_text_type
3536
from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp
@@ -224,11 +225,20 @@ def _get_expired_snapshots(
224225

225226
if not ignore_ttl:
226227
expired_query = expired_query.where(
227-
((exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts)
228-
# we need to include views even if they havent expired in case one depends on a table that /has/ expired
229-
.or_(exp.column("kind_name").eq(ModelKindName.VIEW))
228+
(exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts
230229
)
231230

231+
# we need to include views even if they havent expired in case one depends on a table or view that /has/ expired
232+
# but we only need to do this if there are expired objects to begin with
233+
234+
expired_record_count = fetchone(
235+
self.engine_adapter, exp.select("count(*)").from_(expired_query.subquery())
236+
)
237+
if expired_record_count and expired_record_count[0]:
238+
expired_query = t.cast(
239+
exp.Select, expired_query.or_(exp.column("kind_name").eq(ModelKindName.VIEW))
240+
)
241+
232242
candidates = {
233243
SnapshotId(name=name, identifier=identifier): SnapshotNameVersion(
234244
name=name, version=version
@@ -247,29 +257,23 @@ def _get_expired_snapshots(
247257
# Fetch full snapshots because we need the dependency relationship
248258
full_candidates = self.get_snapshots(candidates.keys())
249259

260+
dag = snapshots_to_dag(full_candidates.values())
261+
250262
# Include any non-expired views that depend on expired tables
251-
for snapshot_id in full_candidates:
263+
for snapshot_id in dag:
252264
snapshot = full_candidates.get(snapshot_id, None)
265+
253266
if not snapshot:
254267
continue
255268

256269
if snapshot.expiration_ts <= current_ts:
257-
# All expired snapshots should be included
270+
# All expired snapshots should be included regardless
258271
expired_candidates[snapshot.snapshot_id] = snapshot.name_version
259272
elif snapshot.model_kind_name == ModelKindName.VIEW:
260-
# With non-expired views, check if they point to an expired parent table
261-
immediate_parents = [
262-
full_candidates[parent_id]
263-
for parent_id in snapshot.parents
264-
if parent_id in full_candidates
265-
]
266-
267-
if any(
268-
parent.expiration_ts <= current_ts
269-
for parent in immediate_parents
270-
if parent.model_kind_name != ModelKindName.VIEW
271-
):
272-
# an immediate upstream table has expired, therefore this view is no longer valid and needs to be cleaned up as well
273+
# Check if any of our parents are in the expired list
274+
# This works because we traverse the dag in topological order, so if our parent either directly
275+
# expired or indirectly expired because /its/ parent expired, it will still be in the expired_snapshots list
276+
if any(parent.snapshot_id in expired_candidates for parent in snapshot.parents):
273277
expired_candidates[snapshot.snapshot_id] = snapshot.name_version
274278

275279
promoted_snapshot_ids = {

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import sqlmesh.core.dialect as d
2525
from sqlmesh.core.dialect import select_from_values
2626
from sqlmesh.core.model import Model, load_sql_based_model
27-
from sqlmesh.core.engine_adapter import EngineAdapter
2827
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
2928
from sqlmesh.core.engine_adapter.mixins import RowDiffMixin, LogicalMergeMixin
3029
from sqlmesh.core.model.definition import create_sql_model
@@ -2816,19 +2815,28 @@ def test_janitor_drops_downstream_unexpired_hard_dependencies(
28162815
- In dev, we modify A - triggers new version of A and a dev preview of B that both expire in 7 days
28172816
- We advance time by 3 days
28182817
- In dev, we modify B - triggers a new version of B that depends on A but expires 3 days after A
2819-
- We advance time by 5 days so that A has reached its expiry but B has not
2818+
- In dev, we create B(view) <- C(view) and B(view) <- D(table)
2819+
- We advance time by 5 days so that A has reached its expiry but B, C and D have not
28202820
- We expire dev so that none of these snapshots are promoted and are thus targets for cleanup
28212821
- We run the janitor
28222822
28232823
Expected outcome:
28242824
- All the dev versions of A and B should be dropped
2825+
- C should be dropped as well because it's a view that depends on B which was dropped
2826+
- D should not be dropped because while it depends on B which was dropped, it's a table so is still valid after B is dropped
28252827
- We should not get a 'ERROR: cannot drop table x because other objects depend on it' on engines that do schema binding
28262828
"""
28272829

2828-
def _state_sync_engine_adapter(context: Context) -> EngineAdapter:
2830+
def _all_snapshot_ids(context: Context) -> t.List[SnapshotId]:
28292831
assert isinstance(context.state_sync, CachingStateSync)
28302832
assert isinstance(context.state_sync.state_sync, EngineAdapterStateSync)
2831-
return context.state_sync.state_sync.engine_adapter
2833+
2834+
return [
2835+
SnapshotId(name=name, identifier=identifier)
2836+
for name, identifier in context.state_sync.state_sync.engine_adapter.fetchall(
2837+
"select name, identifier from sqlmesh._snapshots"
2838+
)
2839+
]
28322840

28332841
models_dir = tmp_path / "models"
28342842
models_dir.mkdir()
@@ -2898,10 +2906,10 @@ def _mutate_config(gateway: str, config: Config):
28982906

28992907
# should now have 4 snapshots in state - 2x model a and 2x model b
29002908
# the new model b is a dev preview because its upstream model changed
2901-
assert (
2902-
len(_state_sync_engine_adapter(sqlmesh).fetchall(f"select * from sqlmesh._snapshots"))
2903-
== 4
2904-
)
2909+
all_snapshot_ids = _all_snapshot_ids(sqlmesh)
2910+
assert len(all_snapshot_ids) == 4
2911+
assert len([s for s in all_snapshot_ids if "model_a" in s.name]) == 2
2912+
assert len([s for s in all_snapshot_ids if "model_b" in s.name]) == 2
29052913

29062914
# context just has the two latest
29072915
assert len(sqlmesh.snapshots) == 2
@@ -2930,19 +2938,39 @@ def _mutate_config(gateway: str, config: Config):
29302938
SELECT a, 'b' as b from {schema}.model_a;
29312939
""")
29322940

2941+
(models_dir / "model_c.sql").write_text(f"""
2942+
MODEL (
2943+
name {schema}.model_c,
2944+
kind VIEW
2945+
);
2946+
2947+
SELECT a, 'c' as c from {schema}.model_b;
2948+
""")
2949+
2950+
(models_dir / "model_d.sql").write_text(f"""
2951+
MODEL (
2952+
name {schema}.model_d,
2953+
kind FULL
2954+
);
2955+
2956+
SELECT a, 'd' as d from {schema}.model_b;
2957+
""")
2958+
29332959
sqlmesh = ctx.create_context(
29342960
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
29352961
)
29362962
sqlmesh.plan(environment="dev", auto_apply=True)
29372963

2938-
# should now have 5 snapshots in state - 2x model a and 3x model b
2939-
assert (
2940-
len(_state_sync_engine_adapter(sqlmesh).fetchall(f"select * from sqlmesh._snapshots"))
2941-
== 5
2942-
)
2964+
# should now have 7 snapshots in state - 2x model a, 3x model b, 1x model c and 1x model d
2965+
all_snapshot_ids = _all_snapshot_ids(sqlmesh)
2966+
assert len(all_snapshot_ids) == 7
2967+
assert len([s for s in all_snapshot_ids if "model_a" in s.name]) == 2
2968+
assert len([s for s in all_snapshot_ids if "model_b" in s.name]) == 3
2969+
assert len([s for s in all_snapshot_ids if "model_c" in s.name]) == 1
2970+
assert len([s for s in all_snapshot_ids if "model_d" in s.name]) == 1
29432971

2944-
# context just has the two latest
2945-
assert len(sqlmesh.snapshots) == 2
2972+
# context just has the 4 latest
2973+
assert len(sqlmesh.snapshots) == 4
29462974

29472975
# model a expiry should not have changed
29482976
model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n)
@@ -2956,26 +2984,41 @@ def _mutate_config(gateway: str, config: Config):
29562984
assert to_ds(model_b_snapshot.updated_ts) == "2020-01-05"
29572985
assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-12"
29582986

2987+
# model c should expire at the same time as model b
2988+
model_c_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_c" in n)
2989+
assert to_ds(model_c_snapshot.updated_ts) == to_ds(model_b_snapshot.updated_ts)
2990+
assert to_ds(model_c_snapshot.expiration_ts) == to_ds(model_b_snapshot.expiration_ts)
2991+
2992+
# model d should expire at the same time as model b
2993+
model_d_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_d" in n)
2994+
assert to_ds(model_d_snapshot.updated_ts) == to_ds(model_b_snapshot.updated_ts)
2995+
assert to_ds(model_d_snapshot.expiration_ts) == to_ds(model_b_snapshot.expiration_ts)
2996+
29592997
# move forward to date where after model a has expired but before model b has expired
29602998
# invalidate dev to trigger cleanups
2961-
# run janitor. model a is expired so will be cleaned up and this will cascade to model b.
2999+
# run janitor
3000+
# - table model a is expired so will be cleaned up and this will cascade to view model b
3001+
# - view model b is not expired, but because it got cascaded to, this will cascade again to view model c
3002+
# - table model d is a not a view, so even though its parent view model b got dropped, it doesnt need to be dropped
29623003
with time_machine.travel("2020-01-10 00:00:00"):
29633004
sqlmesh = ctx.create_context(
29643005
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
29653006
)
29663007

2967-
before_snapshots = _state_sync_engine_adapter(sqlmesh).fetchall(
2968-
f"select name, identifier from sqlmesh._snapshots"
2969-
)
3008+
before_snapshot_ids = _all_snapshot_ids(sqlmesh)
3009+
29703010
sqlmesh.invalidate_environment("dev")
29713011
sqlmesh.run_janitor(ignore_ttl=False)
2972-
after_snapshots = _state_sync_engine_adapter(sqlmesh).fetchall(
2973-
f"select name, identifier from sqlmesh._snapshots"
2974-
)
29753012

2976-
assert len(before_snapshots) != len(after_snapshots)
3013+
after_snapshot_ids = _all_snapshot_ids(sqlmesh)
3014+
3015+
assert len(before_snapshot_ids) != len(after_snapshot_ids)
29773016

2978-
# all that's left should be the two snapshots that were in prod
2979-
assert set(
2980-
[SnapshotId(name=name, identifier=identifier) for name, identifier in after_snapshots]
2981-
) == set([model_a_prod_snapshot.snapshot_id, model_b_prod_snapshot.snapshot_id])
3017+
# all that's left should be the two original snapshots that were in prod and model d
3018+
assert set(after_snapshot_ids) == set(
3019+
[
3020+
model_a_prod_snapshot.snapshot_id,
3021+
model_b_prod_snapshot.snapshot_id,
3022+
model_d_snapshot.snapshot_id,
3023+
]
3024+
)

tests/core/state_sync/test_state_sync.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,85 @@ def test_get_expired_snapshots_includes_downstream_view_snapshots(
12081208
assert snapshot_b_cleanup
12091209

12101210

1211+
def test_get_expired_snapshots_includes_downstream_transitive_view_snapshots(
1212+
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot]
1213+
):
1214+
now_ts = now_timestamp()
1215+
1216+
# model_a: table snapshot
1217+
snapshot_a = make_snapshot(
1218+
SqlModel(
1219+
name="a",
1220+
kind="FULL",
1221+
query=parse_one("select a, ds"),
1222+
),
1223+
ttl="in 10 seconds",
1224+
)
1225+
snapshot_a.updated_ts = now_ts - 15000 # now - 15 seconds = expired
1226+
assert not snapshot_a.parents
1227+
1228+
# model_b: view snapshot that depends on model_a table snapshot
1229+
snapshot_b = make_snapshot(
1230+
SqlModel(
1231+
name="b",
1232+
kind="VIEW",
1233+
depends_on=["a"],
1234+
query=parse_one("select *, 'model_b' as m from a"),
1235+
),
1236+
nodes={'"a"': snapshot_a.model},
1237+
ttl="in 10 seconds",
1238+
)
1239+
assert snapshot_b.parents == (snapshot_a.snapshot_id,)
1240+
1241+
# model_c: view snapshot that depends on model_b view snapshot (i.e no direct dependency on the model_a table ansphot).
1242+
# if the model_b view is dropped, model_c should also be dropped because it's a view that depends on a view being dropped
1243+
snapshot_c = make_snapshot(
1244+
SqlModel(
1245+
name="c",
1246+
kind="VIEW",
1247+
depends_on=["b"],
1248+
query=parse_one("select *, 'model_c' as m from b"),
1249+
),
1250+
nodes={'"a"': snapshot_a.model, '"b"': snapshot_b.model},
1251+
ttl="in 10 seconds",
1252+
)
1253+
assert snapshot_c.parents == (snapshot_b.snapshot_id,)
1254+
1255+
# model_d: table snapshot that depends on model_b view snapshot (i.e no direct dependency on the model_a table snapshot).
1256+
# if the model_b view is dropped, model_d should NOT be dropped because it can still be queried afterwards (and has not expired)
1257+
snapshot_d = make_snapshot(
1258+
SqlModel(
1259+
name="d",
1260+
kind="FULL",
1261+
depends_on=["b"],
1262+
query=parse_one("select *, 'model_d' as m from b"),
1263+
),
1264+
nodes={'"a"': snapshot_a.model, '"b"': snapshot_b.model},
1265+
ttl="in 10 seconds",
1266+
)
1267+
assert snapshot_d.parents == (snapshot_b.snapshot_id,)
1268+
1269+
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)
1270+
snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING)
1271+
snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING)
1272+
snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING)
1273+
1274+
state_sync.push_snapshots([snapshot_a, snapshot_b, snapshot_c, snapshot_d])
1275+
1276+
cleanup_targets = state_sync.get_expired_snapshots(now_ts)
1277+
1278+
# snapshot_a should be cleaned up because it's expired
1279+
# snapshot_b is unexpired but should be cleaned up because snapshot_a is being cleaned up
1280+
# snapshot_c is unexpired and should be cleaned up because snapshot_b is being cleaned up
1281+
# snapshot_d is unexpired and should not be cleaned up. Although it depends on snapshot_b which is being cleaned up,
1282+
# its a table and not a view so will not be invalid the second snapshot_b is cleaned up
1283+
assert len(cleanup_targets) == 3
1284+
1285+
assert set(t.snapshot.snapshot_id for t in cleanup_targets) == set(
1286+
[snapshot_a.snapshot_id, snapshot_b.snapshot_id, snapshot_c.snapshot_id]
1287+
)
1288+
1289+
12111290
def test_delete_expired_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable):
12121291
now_ts = now_timestamp()
12131292

0 commit comments

Comments
 (0)