diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index bd8647f811..d118116f7f 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -107,9 +107,8 @@ CachingStateSync, StateReader, StateSync, - cleanup_expired_views, ) -from sqlmesh.core.state_sync.common import delete_expired_snapshots +from sqlmesh.core.janitor import cleanup_expired_views, delete_expired_snapshots from sqlmesh.core.table_diff import TableDiff from sqlmesh.core.test import ( ModelTextTestResult, diff --git a/sqlmesh/core/janitor.py b/sqlmesh/core/janitor.py new file mode 100644 index 0000000000..e050d6ef6c --- /dev/null +++ b/sqlmesh/core/janitor.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.console import Console +from sqlmesh.core.dialect import schema_ +from sqlmesh.core.environment import Environment +from sqlmesh.core.snapshot import SnapshotEvaluator +from sqlmesh.core.state_sync import StateSync +from sqlmesh.core.state_sync.common import ( + logger, + iter_expired_snapshot_batches, + RowBoundary, + ExpiredBatchRange, +) +from sqlmesh.utils.errors import SQLMeshError + + +def cleanup_expired_views( + default_adapter: EngineAdapter, + engine_adapters: t.Dict[str, EngineAdapter], + environments: t.List[Environment], + warn_on_delete_failure: bool = False, + console: t.Optional[Console] = None, +) -> None: + expired_schema_or_catalog_environments = [ + environment + for environment in environments + if environment.suffix_target.is_schema or environment.suffix_target.is_catalog + ] + expired_table_environments = [ + environment for environment in environments if environment.suffix_target.is_table + ] + + # We have to use the corresponding adapter if the virtual layer is gateway managed + def get_adapter(gateway_managed: bool, gateway: t.Optional[str] = None) -> EngineAdapter: + if gateway_managed and gateway: + return 
engine_adapters.get(gateway, default_adapter) + return default_adapter + + catalogs_to_drop: t.Set[t.Tuple[EngineAdapter, str]] = set() + schemas_to_drop: t.Set[t.Tuple[EngineAdapter, exp.Table]] = set() + + # Collect schemas and catalogs to drop + for engine_adapter, expired_catalog, expired_schema, suffix_target in { + ( + (engine_adapter := get_adapter(environment.gateway_managed, snapshot.model_gateway)), + snapshot.qualified_view_name.catalog_for_environment( + environment.naming_info, dialect=engine_adapter.dialect + ), + snapshot.qualified_view_name.schema_for_environment( + environment.naming_info, dialect=engine_adapter.dialect + ), + environment.suffix_target, + ) + for environment in expired_schema_or_catalog_environments + for snapshot in environment.snapshots + if snapshot.is_model and not snapshot.is_symbolic + }: + if suffix_target.is_catalog: + if expired_catalog: + catalogs_to_drop.add((engine_adapter, expired_catalog)) + else: + schema = schema_(expired_schema, expired_catalog) + schemas_to_drop.add((engine_adapter, schema)) + + # Drop the views for the expired environments + for engine_adapter, expired_view in { + ( + (engine_adapter := get_adapter(environment.gateway_managed, snapshot.model_gateway)), + snapshot.qualified_view_name.for_environment( + environment.naming_info, dialect=engine_adapter.dialect + ), + ) + for environment in expired_table_environments + for snapshot in environment.snapshots + if snapshot.is_model and not snapshot.is_symbolic + }: + try: + engine_adapter.drop_view(expired_view, ignore_if_not_exists=True) + if console: + console.update_cleanup_progress(expired_view) + except Exception as e: + message = f"Failed to drop the expired environment view '{expired_view}': {e}" + if warn_on_delete_failure: + logger.warning(message) + else: + raise SQLMeshError(message) from e + + # Drop the schemas for the expired environments + for engine_adapter, schema in schemas_to_drop: + try: + engine_adapter.drop_schema( + schema, + 
ignore_if_not_exists=True, + cascade=True, + ) + if console: + console.update_cleanup_progress(schema.sql(dialect=engine_adapter.dialect)) + except Exception as e: + message = f"Failed to drop the expired environment schema '{schema}': {e}" + if warn_on_delete_failure: + logger.warning(message) + else: + raise SQLMeshError(message) from e + + # Drop any catalogs that were associated with a snapshot where the engine adapter supports dropping catalogs + # catalogs_to_drop is only populated when environment_suffix_target is set to 'catalog' + for engine_adapter, catalog in catalogs_to_drop: + if engine_adapter.SUPPORTS_CREATE_DROP_CATALOG: + try: + engine_adapter.drop_catalog(catalog) + if console: + console.update_cleanup_progress(catalog) + except Exception as e: + message = f"Failed to drop the expired environment catalog '{catalog}': {e}" + if warn_on_delete_failure: + logger.warning(message) + else: + raise SQLMeshError(message) from e + + +def delete_expired_snapshots( + state_sync: StateSync, + snapshot_evaluator: SnapshotEvaluator, + *, + current_ts: int, + ignore_ttl: bool = False, + batch_size: t.Optional[int] = None, + console: t.Optional[Console] = None, +) -> None: + """Delete all expired snapshots in batches. + + This helper function encapsulates the logic for deleting expired snapshots in batches, + eliminating code duplication across different use cases. + + Args: + state_sync: StateSync instance to query and delete expired snapshots from. + snapshot_evaluator: SnapshotEvaluator instance to clean up tables associated with snapshots. + current_ts: Timestamp used to evaluate expiration. + ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). + batch_size: Maximum number of snapshots to fetch per batch. + console: Optional console for reporting progress. + + Returns: + The total number of deleted expired snapshots. 
+ """ + num_expired_snapshots = 0 + for batch in iter_expired_snapshot_batches( + state_reader=state_sync, + current_ts=current_ts, + ignore_ttl=ignore_ttl, + batch_size=batch_size, + ): + end_info = ( + f"updated_ts={batch.batch_range.end.updated_ts}" + if isinstance(batch.batch_range.end, RowBoundary) + else f"limit={batch.batch_range.end.batch_size}" + ) + logger.info( + "Processing batch of size %s with end %s", + len(batch.expired_snapshot_ids), + end_info, + ) + snapshot_evaluator.cleanup( + target_snapshots=batch.cleanup_tasks, + on_complete=console.update_cleanup_progress if console else None, + ) + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange( + start=RowBoundary.lowest_boundary(), + end=batch.batch_range.end, + ), + ignore_ttl=ignore_ttl, + ) + logger.info("Cleaned up expired snapshots batch") + num_expired_snapshots += len(batch.expired_snapshot_ids) + logger.info("Cleaned up %s expired snapshots", num_expired_snapshots) diff --git a/sqlmesh/core/state_sync/__init__.py b/sqlmesh/core/state_sync/__init__.py index 1585d6211f..12ea77ac8f 100644 --- a/sqlmesh/core/state_sync/__init__.py +++ b/sqlmesh/core/state_sync/__init__.py @@ -20,5 +20,4 @@ Versions as Versions, ) from sqlmesh.core.state_sync.cache import CachingStateSync as CachingStateSync -from sqlmesh.core.state_sync.common import cleanup_expired_views as cleanup_expired_views from sqlmesh.core.state_sync.db import EngineAdapterStateSync as EngineAdapterStateSync diff --git a/sqlmesh/core/state_sync/common.py b/sqlmesh/core/state_sync/common.py index 3fdd0bc015..056565b060 100644 --- a/sqlmesh/core/state_sync/common.py +++ b/sqlmesh/core/state_sync/common.py @@ -11,132 +11,23 @@ from pydantic_core.core_schema import ValidationInfo from sqlglot import exp -from sqlmesh.core.console import Console -from sqlmesh.core.dialect import schema_ from sqlmesh.utils.pydantic import PydanticModel, field_validator from sqlmesh.core.environment import Environment, EnvironmentStatements, 
EnvironmentNamingInfo -from sqlmesh.utils.errors import SQLMeshError from sqlmesh.core.snapshot import ( Snapshot, - SnapshotEvaluator, SnapshotId, SnapshotTableCleanupTask, SnapshotTableInfo, ) if t.TYPE_CHECKING: - from sqlmesh.core.engine_adapter.base import EngineAdapter - from sqlmesh.core.state_sync.base import Versions, StateReader, StateSync + from sqlmesh.core.state_sync.base import Versions, StateReader logger = logging.getLogger(__name__) EXPIRED_SNAPSHOT_DEFAULT_BATCH_SIZE = 200 -def cleanup_expired_views( - default_adapter: EngineAdapter, - engine_adapters: t.Dict[str, EngineAdapter], - environments: t.List[Environment], - warn_on_delete_failure: bool = False, - console: t.Optional[Console] = None, -) -> None: - expired_schema_or_catalog_environments = [ - environment - for environment in environments - if environment.suffix_target.is_schema or environment.suffix_target.is_catalog - ] - expired_table_environments = [ - environment for environment in environments if environment.suffix_target.is_table - ] - - # We have to use the corresponding adapter if the virtual layer is gateway managed - def get_adapter(gateway_managed: bool, gateway: t.Optional[str] = None) -> EngineAdapter: - if gateway_managed and gateway: - return engine_adapters.get(gateway, default_adapter) - return default_adapter - - catalogs_to_drop: t.Set[t.Tuple[EngineAdapter, str]] = set() - schemas_to_drop: t.Set[t.Tuple[EngineAdapter, exp.Table]] = set() - - # Collect schemas and catalogs to drop - for engine_adapter, expired_catalog, expired_schema, suffix_target in { - ( - (engine_adapter := get_adapter(environment.gateway_managed, snapshot.model_gateway)), - snapshot.qualified_view_name.catalog_for_environment( - environment.naming_info, dialect=engine_adapter.dialect - ), - snapshot.qualified_view_name.schema_for_environment( - environment.naming_info, dialect=engine_adapter.dialect - ), - environment.suffix_target, - ) - for environment in expired_schema_or_catalog_environments - 
for snapshot in environment.snapshots - if snapshot.is_model and not snapshot.is_symbolic - }: - if suffix_target.is_catalog: - if expired_catalog: - catalogs_to_drop.add((engine_adapter, expired_catalog)) - else: - schema = schema_(expired_schema, expired_catalog) - schemas_to_drop.add((engine_adapter, schema)) - - # Drop the views for the expired environments - for engine_adapter, expired_view in { - ( - (engine_adapter := get_adapter(environment.gateway_managed, snapshot.model_gateway)), - snapshot.qualified_view_name.for_environment( - environment.naming_info, dialect=engine_adapter.dialect - ), - ) - for environment in expired_table_environments - for snapshot in environment.snapshots - if snapshot.is_model and not snapshot.is_symbolic - }: - try: - engine_adapter.drop_view(expired_view, ignore_if_not_exists=True) - if console: - console.update_cleanup_progress(expired_view) - except Exception as e: - message = f"Failed to drop the expired environment view '{expired_view}': {e}" - if warn_on_delete_failure: - logger.warning(message) - else: - raise SQLMeshError(message) from e - - # Drop the schemas for the expired environments - for engine_adapter, schema in schemas_to_drop: - try: - engine_adapter.drop_schema( - schema, - ignore_if_not_exists=True, - cascade=True, - ) - if console: - console.update_cleanup_progress(schema.sql(dialect=engine_adapter.dialect)) - except Exception as e: - message = f"Failed to drop the expired environment schema '{schema}': {e}" - if warn_on_delete_failure: - logger.warning(message) - else: - raise SQLMeshError(message) from e - - # Drop any catalogs that were associated with a snapshot where the engine adapter supports dropping catalogs - # catalogs_to_drop is only populated when environment_suffix_target is set to 'catalog' - for engine_adapter, catalog in catalogs_to_drop: - if engine_adapter.SUPPORTS_CREATE_DROP_CATALOG: - try: - engine_adapter.drop_catalog(catalog) - if console: - console.update_cleanup_progress(catalog) - 
except Exception as e: - message = f"Failed to drop the expired environment catalog '{catalog}': {e}" - if warn_on_delete_failure: - logger.warning(message) - else: - raise SQLMeshError(message) from e - - def transactional() -> t.Callable[[t.Callable], t.Callable]: def decorator(func: t.Callable) -> t.Callable: @wraps(func) @@ -429,61 +320,3 @@ def iter_expired_snapshot_batches( start=batch.batch_range.end, end=LimitBoundary(batch_size=batch_size), ) - - -def delete_expired_snapshots( - state_sync: StateSync, - snapshot_evaluator: SnapshotEvaluator, - *, - current_ts: int, - ignore_ttl: bool = False, - batch_size: t.Optional[int] = None, - console: t.Optional[Console] = None, -) -> None: - """Delete all expired snapshots in batches. - - This helper function encapsulates the logic for deleting expired snapshots in batches, - eliminating code duplication across different use cases. - - Args: - state_sync: StateSync instance to query and delete expired snapshots from. - snapshot_evaluator: SnapshotEvaluator instance to clean up tables associated with snapshots. - current_ts: Timestamp used to evaluate expiration. - ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). - batch_size: Maximum number of snapshots to fetch per batch. - console: Optional console for reporting progress. - - Returns: - The total number of deleted expired snapshots. 
- """ - num_expired_snapshots = 0 - for batch in iter_expired_snapshot_batches( - state_reader=state_sync, - current_ts=current_ts, - ignore_ttl=ignore_ttl, - batch_size=batch_size, - ): - end_info = ( - f"updated_ts={batch.batch_range.end.updated_ts}" - if isinstance(batch.batch_range.end, RowBoundary) - else f"limit={batch.batch_range.end.batch_size}" - ) - logger.info( - "Processing batch of size %s with end %s", - len(batch.expired_snapshot_ids), - end_info, - ) - snapshot_evaluator.cleanup( - target_snapshots=batch.cleanup_tasks, - on_complete=console.update_cleanup_progress if console else None, - ) - state_sync.delete_expired_snapshots( - batch_range=ExpiredBatchRange( - start=RowBoundary.lowest_boundary(), - end=batch.batch_range.end, - ), - ignore_ttl=ignore_ttl, - ) - logger.info("Cleaned up expired snapshots batch") - num_expired_snapshots += len(batch.expired_snapshot_ids) - logger.info("Cleaned up %s expired snapshots", num_expired_snapshots) diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index 199ca43ee9..bd01dfc652 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -13,19 +13,17 @@ from sqlmesh.core import constants as c from sqlmesh.core.config import EnvironmentSuffixTarget -from sqlmesh.core.dialect import parse_one, schema_ +from sqlmesh.core.dialect import parse_one from sqlmesh.core.engine_adapter import create_engine_adapter from sqlmesh.core.environment import Environment, EnvironmentStatements from sqlmesh.core.model import ( FullKind, IncrementalByTimeRangeKind, - ModelKindName, Seed, SeedKind, SeedModel, SqlModel, ) -from sqlmesh.core.model.definition import ExternalModel from sqlmesh.core.snapshot import ( Snapshot, SnapshotChangeCategory, @@ -38,7 +36,6 @@ from sqlmesh.core.state_sync import ( CachingStateSync, EngineAdapterStateSync, - cleanup_expired_views, ) from sqlmesh.core.state_sync.base import ( SCHEMA_VERSION, @@ -1524,154 
+1521,6 @@ def test_expired_batch_range_where_filter_with_limit(): ) -def test_delete_expired_snapshots_common_function_batching( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture -): - """Test that the common delete_expired_snapshots function properly pages through batches and deletes them.""" - from sqlmesh.core.state_sync.common import delete_expired_snapshots - from sqlmesh.core.state_sync.common import ExpiredBatchRange, RowBoundary, LimitBoundary - from unittest.mock import MagicMock - - now_ts = now_timestamp() - - # Create 5 expired snapshots with different timestamps - snapshots = [] - for idx in range(5): - snapshot = make_snapshot( - SqlModel( - name=f"model_{idx}", - query=parse_one("select 1 as a, ds"), - ), - ) - snapshot.ttl = "in 10 seconds" - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.updated_ts = now_ts - (20000 + idx * 1000) - snapshots.append(snapshot) - - state_sync.push_snapshots(snapshots) - - # Spy on get_expired_snapshots and delete_expired_snapshots methods - get_expired_spy = mocker.spy(state_sync, "get_expired_snapshots") - delete_expired_spy = mocker.spy(state_sync, "delete_expired_snapshots") - - # Mock snapshot evaluator - mock_evaluator = MagicMock() - mock_evaluator.cleanup = MagicMock() - - # Run delete_expired_snapshots with batch_size=2 - delete_expired_snapshots( - state_sync, - mock_evaluator, - current_ts=now_ts, - batch_size=2, - ) - - # Verify get_expired_snapshots was called the correct number of times: - # - 3 batches (2+2+1): each batch triggers 2 calls (one from iter_expired_snapshot_batches, one from delete_expired_snapshots) - # - Plus 1 final call that returns empty to exit the loop - # Total: 3 * 2 + 1 = 7 calls - assert get_expired_spy.call_count == 7 - - # Verify the progression of batch_range calls from the iter_expired_snapshot_batches loop - # (calls at indices 0, 2, 4, 6 are from iter_expired_snapshot_batches) - # (calls at indices 1, 3, 5 are from 
delete_expired_snapshots in facade.py) - calls = get_expired_spy.call_args_list - - # First call from iterator should have a batch_range starting from the beginning - first_call_kwargs = calls[0][1] - assert "batch_range" in first_call_kwargs - first_range = first_call_kwargs["batch_range"] - assert isinstance(first_range, ExpiredBatchRange) - assert isinstance(first_range.start, RowBoundary) - assert isinstance(first_range.end, LimitBoundary) - assert first_range.end.batch_size == 2 - assert first_range.start.updated_ts == 0 - assert first_range.start.name == "" - assert first_range.start.identifier == "" - - # Third call (second batch from iterator) should have a batch_range from the first batch's range - third_call_kwargs = calls[2][1] - assert "batch_range" in third_call_kwargs - second_range = third_call_kwargs["batch_range"] - assert isinstance(second_range, ExpiredBatchRange) - assert isinstance(second_range.start, RowBoundary) - assert isinstance(second_range.end, LimitBoundary) - assert second_range.end.batch_size == 2 - # Should have progressed from the first batch - assert second_range.start.updated_ts > 0 - assert second_range.start.name == '"model_3"' - - # Fifth call (third batch from iterator) should have a batch_range from the second batch's range - fifth_call_kwargs = calls[4][1] - assert "batch_range" in fifth_call_kwargs - third_range = fifth_call_kwargs["batch_range"] - assert isinstance(third_range, ExpiredBatchRange) - assert isinstance(third_range.start, RowBoundary) - assert isinstance(third_range.end, LimitBoundary) - assert third_range.end.batch_size == 2 - # Should have progressed from the second batch - assert third_range.start.updated_ts >= second_range.start.updated_ts - assert third_range.start.name == '"model_1"' - - # Seventh call (final call from iterator) should have a batch_range from the third batch's range - seventh_call_kwargs = calls[6][1] - assert "batch_range" in seventh_call_kwargs - fourth_range = 
seventh_call_kwargs["batch_range"] - assert isinstance(fourth_range, ExpiredBatchRange) - assert isinstance(fourth_range.start, RowBoundary) - assert isinstance(fourth_range.end, LimitBoundary) - assert fourth_range.end.batch_size == 2 - # Should have progressed from the third batch - assert fourth_range.start.updated_ts >= third_range.start.updated_ts - assert fourth_range.start.name == '"model_0"' - - # Verify delete_expired_snapshots was called 3 times (once per batch) - assert delete_expired_spy.call_count == 3 - - # Verify each delete call used a batch_range - delete_calls = delete_expired_spy.call_args_list - - # First call should have a batch_range matching the first batch - first_delete_kwargs = delete_calls[0][1] - assert "batch_range" in first_delete_kwargs - first_delete_range = first_delete_kwargs["batch_range"] - assert isinstance(first_delete_range, ExpiredBatchRange) - assert isinstance(first_delete_range.start, RowBoundary) - assert first_delete_range.start.updated_ts == 0 - assert isinstance(first_delete_range.end, RowBoundary) - assert first_delete_range.end.updated_ts == second_range.start.updated_ts - assert first_delete_range.end.name == second_range.start.name - assert first_delete_range.end.identifier == second_range.start.identifier - - second_delete_kwargs = delete_calls[1][1] - assert "batch_range" in second_delete_kwargs - second_delete_range = second_delete_kwargs["batch_range"] - assert isinstance(second_delete_range, ExpiredBatchRange) - assert isinstance(second_delete_range.start, RowBoundary) - assert second_delete_range.start.updated_ts == 0 - assert isinstance(second_delete_range.end, RowBoundary) - assert second_delete_range.end.updated_ts == third_range.start.updated_ts - assert second_delete_range.end.name == third_range.start.name - assert second_delete_range.end.identifier == third_range.start.identifier - - third_delete_kwargs = delete_calls[2][1] - assert "batch_range" in third_delete_kwargs - third_delete_range = 
third_delete_kwargs["batch_range"] - assert isinstance(third_delete_range, ExpiredBatchRange) - assert isinstance(third_delete_range.start, RowBoundary) - assert third_delete_range.start.updated_ts == 0 - assert isinstance(third_delete_range.end, RowBoundary) - assert third_delete_range.end.updated_ts == fourth_range.start.updated_ts - assert third_delete_range.end.name == fourth_range.start.name - assert third_delete_range.end.identifier == fourth_range.start.identifier - # Verify the cleanup method was called for each batch that had cleanup tasks - assert mock_evaluator.cleanup.call_count >= 1 - - # Verify all snapshots were deleted in the end - remaining = state_sync.get_snapshots(snapshots) - assert len(remaining) == 0 - - def test_delete_expired_snapshots_seed( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable ): @@ -3089,105 +2938,6 @@ def test_cache(state_sync, make_snapshot, mocker): mock.assert_called() -def test_cleanup_expired_views( - mocker: MockerFixture, state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - adapter = mocker.MagicMock() - adapter.dialect = None - snapshot_a = make_snapshot(SqlModel(name="catalog.schema.a", query=parse_one("select 1, ds"))) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot_b = make_snapshot(SqlModel(name="catalog.schema.b", query=parse_one("select 1, ds"))) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - # Make sure that we don't drop schemas from external models - snapshot_external_model = make_snapshot( - ExternalModel(name="catalog.external_schema.external_table", kind=ModelKindName.EXTERNAL) - ) - snapshot_external_model.categorize_as(SnapshotChangeCategory.BREAKING) - schema_environment = Environment( - name="test_environment", - suffix_target=EnvironmentSuffixTarget.SCHEMA, - snapshots=[ - snapshot_a.table_info, - snapshot_b.table_info, - snapshot_external_model.table_info, - ], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - 
previous_plan_id="test_plan_id", - catalog_name_override="catalog_override", - ) - snapshot_c = make_snapshot(SqlModel(name="catalog.schema.c", query=parse_one("select 1, ds"))) - snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot_d = make_snapshot(SqlModel(name="catalog.schema.d", query=parse_one("select 1, ds"))) - snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) - table_environment = Environment( - name="test_environment", - suffix_target=EnvironmentSuffixTarget.TABLE, - snapshots=[ - snapshot_c.table_info, - snapshot_d.table_info, - snapshot_external_model.table_info, - ], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - catalog_name_override="catalog_override", - ) - cleanup_expired_views(adapter, {}, [schema_environment, table_environment]) - assert adapter.drop_schema.called - assert adapter.drop_view.called - assert adapter.drop_schema.call_args_list == [ - call( - schema_("schema__test_environment", "catalog_override"), - ignore_if_not_exists=True, - cascade=True, - ) - ] - assert sorted(adapter.drop_view.call_args_list) == [ - call("catalog_override.schema.c__test_environment", ignore_if_not_exists=True), - call("catalog_override.schema.d__test_environment", ignore_if_not_exists=True), - ] - - -@pytest.mark.parametrize( - "suffix_target", [EnvironmentSuffixTarget.SCHEMA, EnvironmentSuffixTarget.TABLE] -) -def test_cleanup_expired_environment_schema_warn_on_delete_failure( - mocker: MockerFixture, make_snapshot: t.Callable, suffix_target: EnvironmentSuffixTarget -): - adapter = mocker.MagicMock() - adapter.dialect = None - adapter.drop_schema.side_effect = Exception("Failed to drop the schema") - adapter.drop_view.side_effect = Exception("Failed to drop the view") - - snapshot = make_snapshot( - SqlModel(name="test_catalog.test_schema.test_model", query=parse_one("select 1, ds")) - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - schema_environment = 
Environment( - name="test_environment", - suffix_target=suffix_target, - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - catalog_name_override="catalog_override", - ) - - with pytest.raises(SQLMeshError, match="Failed to drop the expired environment .*"): - cleanup_expired_views(adapter, {}, [schema_environment], warn_on_delete_failure=False) - - cleanup_expired_views(adapter, {}, [schema_environment], warn_on_delete_failure=True) - - if suffix_target == EnvironmentSuffixTarget.SCHEMA: - assert adapter.drop_schema.called - else: - assert adapter.drop_view.called - - def test_max_interval_end_per_model( state_sync: EngineAdapterStateSync, make_snapshot: t.Callable ) -> None: diff --git a/tests/core/test_janitor.py b/tests/core/test_janitor.py new file mode 100644 index 0000000000..e5e209f2cc --- /dev/null +++ b/tests/core/test_janitor.py @@ -0,0 +1,282 @@ +import typing as t +from unittest.mock import call + +import pytest +from pytest_mock.plugin import MockerFixture + +from sqlmesh.core.config import EnvironmentSuffixTarget +from sqlmesh.core import constants as c +from sqlmesh.core.dialect import parse_one, schema_ +from sqlmesh.core.engine_adapter import create_engine_adapter +from sqlmesh.core.environment import Environment +from sqlmesh.core.model import ( + ModelKindName, + SqlModel, +) +from sqlmesh.core.model.definition import ExternalModel +from sqlmesh.core.snapshot import ( + SnapshotChangeCategory, +) +from sqlmesh.core.state_sync import ( + EngineAdapterStateSync, +) +from sqlmesh.core.janitor import cleanup_expired_views, delete_expired_snapshots +from sqlmesh.utils.date import now_timestamp +from sqlmesh.utils.errors import SQLMeshError + +pytestmark = pytest.mark.slow + + +@pytest.fixture +def state_sync(duck_conn, tmp_path): + state_sync = EngineAdapterStateSync( + create_engine_adapter(lambda: duck_conn, "duckdb"), + schema=c.SQLMESH, + 
cache_dir=tmp_path / c.CACHE, + ) + state_sync.migrate() + return state_sync + + +def test_cleanup_expired_views(mocker: MockerFixture, make_snapshot: t.Callable): + adapter = mocker.MagicMock() + adapter.dialect = None + snapshot_a = make_snapshot(SqlModel(name="catalog.schema.a", query=parse_one("select 1, ds"))) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b = make_snapshot(SqlModel(name="catalog.schema.b", query=parse_one("select 1, ds"))) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + # Make sure that we don't drop schemas from external models + snapshot_external_model = make_snapshot( + ExternalModel(name="catalog.external_schema.external_table", kind=ModelKindName.EXTERNAL) + ) + snapshot_external_model.categorize_as(SnapshotChangeCategory.BREAKING) + schema_environment = Environment( + name="test_environment", + suffix_target=EnvironmentSuffixTarget.SCHEMA, + snapshots=[ + snapshot_a.table_info, + snapshot_b.table_info, + snapshot_external_model.table_info, + ], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + catalog_name_override="catalog_override", + ) + snapshot_c = make_snapshot(SqlModel(name="catalog.schema.c", query=parse_one("select 1, ds"))) + snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_d = make_snapshot(SqlModel(name="catalog.schema.d", query=parse_one("select 1, ds"))) + snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) + table_environment = Environment( + name="test_environment", + suffix_target=EnvironmentSuffixTarget.TABLE, + snapshots=[ + snapshot_c.table_info, + snapshot_d.table_info, + snapshot_external_model.table_info, + ], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + catalog_name_override="catalog_override", + ) + cleanup_expired_views(adapter, {}, [schema_environment, table_environment]) + assert adapter.drop_schema.called + assert 
# NOTE(review): the assertions just below are the tail of a test function whose
# definition begins before this excerpt; left untouched.
    adapter.drop_view.called
    assert adapter.drop_schema.call_args_list == [
        call(
            schema_("schema__test_environment", "catalog_override"),
            ignore_if_not_exists=True,
            cascade=True,
        )
    ]
    assert sorted(adapter.drop_view.call_args_list) == [
        call("catalog_override.schema.c__test_environment", ignore_if_not_exists=True),
        call("catalog_override.schema.d__test_environment", ignore_if_not_exists=True),
    ]


@pytest.mark.parametrize(
    "suffix_target", [EnvironmentSuffixTarget.SCHEMA, EnvironmentSuffixTarget.TABLE]
)
def test_cleanup_expired_environment_schema_warn_on_delete_failure(
    mocker: MockerFixture, make_snapshot: t.Callable, suffix_target: EnvironmentSuffixTarget
):
    """Failed drops raise SQLMeshError by default but are tolerated when
    ``warn_on_delete_failure=True``, for both SCHEMA and TABLE suffix targets."""
    # Adapter whose drop operations always fail.
    adapter = mocker.MagicMock()
    adapter.dialect = None
    adapter.drop_schema.side_effect = Exception("Failed to drop the schema")
    adapter.drop_view.side_effect = Exception("Failed to drop the view")

    snapshot = make_snapshot(
        SqlModel(name="test_catalog.test_schema.test_model", query=parse_one("select 1, ds"))
    )
    snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
    schema_environment = Environment(
        name="test_environment",
        suffix_target=suffix_target,
        snapshots=[snapshot.table_info],
        start_at="2022-01-01",
        end_at="2022-01-01",
        plan_id="test_plan_id",
        previous_plan_id="test_plan_id",
        catalog_name_override="catalog_override",
    )

    # Default behavior: the drop failure is fatal.
    with pytest.raises(SQLMeshError, match="Failed to drop the expired environment .*"):
        cleanup_expired_views(adapter, {}, [schema_environment], warn_on_delete_failure=False)

    # With warn_on_delete_failure=True the same failure must not raise.
    cleanup_expired_views(adapter, {}, [schema_environment], warn_on_delete_failure=True)

    # The matching drop API must still have been attempted.
    if suffix_target == EnvironmentSuffixTarget.SCHEMA:
        assert adapter.drop_schema.called
    else:
        assert adapter.drop_view.called


def test_delete_expired_snapshots_common_function_batching(
    state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture
):
    """Test that the common delete_expired_snapshots function properly pages
    through batches and deletes them."""
    from sqlmesh.core.state_sync.common import ExpiredBatchRange, RowBoundary, LimitBoundary
    from unittest.mock import MagicMock

    now_ts = now_timestamp()

    # Create 5 expired snapshots with different timestamps.
    # Larger idx => larger subtrahend => smaller updated_ts, so model_4 is the
    # oldest and model_0 the most recently updated of the expired set.
    snapshots = []
    for idx in range(5):
        snapshot = make_snapshot(
            SqlModel(
                name=f"model_{idx}",
                query=parse_one("select 1 as a, ds"),
            ),
        )
        snapshot.ttl = "in 10 seconds"
        snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
        snapshot.updated_ts = now_ts - (20000 + idx * 1000)
        snapshots.append(snapshot)

    state_sync.push_snapshots(snapshots)

    # Spy on get_expired_snapshots and delete_expired_snapshots methods
    get_expired_spy = mocker.spy(state_sync, "get_expired_snapshots")
    delete_expired_spy = mocker.spy(state_sync, "delete_expired_snapshots")

    # Mock snapshot evaluator
    mock_evaluator = MagicMock()
    mock_evaluator.cleanup = MagicMock()

    # Run delete_expired_snapshots with batch_size=2
    delete_expired_snapshots(
        state_sync,
        mock_evaluator,
        current_ts=now_ts,
        batch_size=2,
    )

    # Verify get_expired_snapshots was called the correct number of times:
    # - 3 batches (2+2+1): each batch triggers 2 calls (one from iter_expired_snapshot_batches, one from delete_expired_snapshots)
    # - Plus 1 final call that returns empty to exit the loop
    # Total: 3 * 2 + 1 = 7 calls
    assert get_expired_spy.call_count == 7

    # Verify the progression of batch_range calls from the iter_expired_snapshot_batches loop
    # (calls at indices 0, 2, 4, 6 are from iter_expired_snapshot_batches)
    # (calls at indices 1, 3, 5 are from delete_expired_snapshots in facade.py)
    calls = get_expired_spy.call_args_list

    # First call from iterator should have a batch_range starting from the beginning
    first_call_kwargs = calls[0][1]
    assert "batch_range" in first_call_kwargs
    first_range = first_call_kwargs["batch_range"]
    assert isinstance(first_range, ExpiredBatchRange)
    assert isinstance(first_range.start, RowBoundary)
    assert isinstance(first_range.end, LimitBoundary)
    assert first_range.end.batch_size == 2
    # The initial RowBoundary is the zero/empty sentinel.
    assert first_range.start.updated_ts == 0
    assert first_range.start.name == ""
    assert first_range.start.identifier == ""

    # Third call (second batch from iterator) should have a batch_range from the first batch's range
    third_call_kwargs = calls[2][1]
    assert "batch_range" in third_call_kwargs
    second_range = third_call_kwargs["batch_range"]
    assert isinstance(second_range, ExpiredBatchRange)
    assert isinstance(second_range.start, RowBoundary)
    assert isinstance(second_range.end, LimitBoundary)
    assert second_range.end.batch_size == 2
    # Should have progressed from the first batch
    assert second_range.start.updated_ts > 0
    assert second_range.start.name == '"model_3"'

    # Fifth call (third batch from iterator) should have a batch_range from the second batch's range
    fifth_call_kwargs = calls[4][1]
    assert "batch_range" in fifth_call_kwargs
    third_range = fifth_call_kwargs["batch_range"]
    assert isinstance(third_range, ExpiredBatchRange)
    assert isinstance(third_range.start, RowBoundary)
    assert isinstance(third_range.end, LimitBoundary)
    assert third_range.end.batch_size == 2
    # Should have progressed from the second batch
    assert third_range.start.updated_ts >= second_range.start.updated_ts
    assert third_range.start.name == '"model_1"'

    # Seventh call (final call from iterator) should have a batch_range from the third batch's range
    seventh_call_kwargs = calls[6][1]
    assert "batch_range" in seventh_call_kwargs
    fourth_range = seventh_call_kwargs["batch_range"]
    assert isinstance(fourth_range, ExpiredBatchRange)
    assert isinstance(fourth_range.start, RowBoundary)
    assert isinstance(fourth_range.end, LimitBoundary)
    assert fourth_range.end.batch_size == 2
    # Should have progressed from the third batch
    assert fourth_range.start.updated_ts >= third_range.start.updated_ts
    assert fourth_range.start.name == '"model_0"'

    # Verify delete_expired_snapshots was called 3 times (once per batch)
    assert delete_expired_spy.call_count == 3

    # Verify each delete call used a batch_range
    delete_calls = delete_expired_spy.call_args_list

    # First call should have a batch_range matching the first batch
    first_delete_kwargs = delete_calls[0][1]
    assert "batch_range" in first_delete_kwargs
    first_delete_range = first_delete_kwargs["batch_range"]
    assert isinstance(first_delete_range, ExpiredBatchRange)
    assert isinstance(first_delete_range.start, RowBoundary)
    assert first_delete_range.start.updated_ts == 0
    # The delete range's end is the RowBoundary where the NEXT iterator call resumed.
    assert isinstance(first_delete_range.end, RowBoundary)
    assert first_delete_range.end.updated_ts == second_range.start.updated_ts
    assert first_delete_range.end.name == second_range.start.name
    assert first_delete_range.end.identifier == second_range.start.identifier

    second_delete_kwargs = delete_calls[1][1]
    assert "batch_range" in second_delete_kwargs
    second_delete_range = second_delete_kwargs["batch_range"]
    assert isinstance(second_delete_range, ExpiredBatchRange)
    assert isinstance(second_delete_range.start, RowBoundary)
    assert second_delete_range.start.updated_ts == 0
    assert isinstance(second_delete_range.end, RowBoundary)
    assert second_delete_range.end.updated_ts == third_range.start.updated_ts
    assert second_delete_range.end.name == third_range.start.name
    assert second_delete_range.end.identifier == third_range.start.identifier

    third_delete_kwargs = delete_calls[2][1]
    assert "batch_range" in third_delete_kwargs
    third_delete_range = third_delete_kwargs["batch_range"]
    assert isinstance(third_delete_range, ExpiredBatchRange)
    assert isinstance(third_delete_range.start, RowBoundary)
    assert third_delete_range.start.updated_ts == 0
    assert isinstance(third_delete_range.end, RowBoundary)
    assert third_delete_range.end.updated_ts == fourth_range.start.updated_ts
    assert third_delete_range.end.name == fourth_range.start.name
    assert third_delete_range.end.identifier == fourth_range.start.identifier
    # Verify the cleanup method was called for each batch that had cleanup tasks
    assert mock_evaluator.cleanup.call_count >= 1

    # Verify all snapshots were deleted in the end
    remaining = state_sync.get_snapshots(snapshots)
    assert len(remaining) == 0