77import abc
88
99from dataclasses import dataclass
10+
11+ from pydantic_core .core_schema import ValidationInfo
1012from sqlglot import exp
1113
1214from sqlmesh .core .console import Console
1315from sqlmesh .core .dialect import schema_
14- from sqlmesh .utils .pydantic import PydanticModel
15- from sqlmesh .core .environment import Environment , EnvironmentStatements
16+ from sqlmesh .utils .pydantic import PydanticModel , field_validator
17+ from sqlmesh .core .environment import Environment , EnvironmentStatements , EnvironmentNamingInfo
1618from sqlmesh .utils .errors import SQLMeshError
17- from sqlmesh .core .snapshot import Snapshot , SnapshotEvaluator
19+ from sqlmesh .core .snapshot import (
20+ Snapshot ,
21+ SnapshotEvaluator ,
22+ SnapshotId ,
23+ SnapshotTableCleanupTask ,
24+ SnapshotTableInfo ,
25+ )
1826
1927if t .TYPE_CHECKING :
2028 from sqlmesh .core .engine_adapter .base import EngineAdapter
21- from sqlmesh .core .state_sync .base import Versions , ExpiredSnapshotBatch , StateReader , StateSync
29+ from sqlmesh .core .state_sync .base import Versions , StateReader , StateSync
2230
2331logger = logging .getLogger (__name__ )
2432
@@ -219,6 +227,70 @@ def __iter__(self) -> t.Iterator[StateStreamContents]:
219227 return _StateStream ()
220228
221229
230+ class ExpiredBatchRange (PydanticModel ):
231+ start : RowBoundary
232+ end : t .Union [RowBoundary , LimitBoundary ]
233+
234+ @classmethod
235+ def init_batch_range (cls , batch_size : int ) -> ExpiredBatchRange :
236+ return ExpiredBatchRange (
237+ start = RowBoundary .lowest_boundary (),
238+ end = LimitBoundary (batch_size = batch_size ),
239+ )
240+
241+ @classmethod
242+ def all_batch_range (cls ) -> ExpiredBatchRange :
243+ return ExpiredBatchRange (
244+ start = RowBoundary .lowest_boundary (),
245+ end = RowBoundary .highest_boundary (),
246+ )
247+
248+
249+ class RowBoundary (PydanticModel ):
250+ updated_ts : int
251+ name : str
252+ identifier : str
253+
254+ @classmethod
255+ def lowest_boundary (cls ) -> RowBoundary :
256+ return RowBoundary (updated_ts = 0 , name = "" , identifier = "" )
257+
258+ @classmethod
259+ def highest_boundary (cls ) -> RowBoundary :
260+ # 9999-12-31T23:59:59.999Z in epoch milliseconds
261+ return RowBoundary (updated_ts = 253_402_300_799_999 , name = "" , identifier = "" )
262+
263+
264+ class LimitBoundary (PydanticModel ):
265+ batch_size : int
266+
267+ @classmethod
268+ def init_batch_boundary (cls , batch_size : int ) -> LimitBoundary :
269+ return LimitBoundary (batch_size = batch_size )
270+
271+
272+ class PromotionResult (PydanticModel ):
273+ added : t .List [SnapshotTableInfo ]
274+ removed : t .List [SnapshotTableInfo ]
275+ removed_environment_naming_info : t .Optional [EnvironmentNamingInfo ]
276+
277+ @field_validator ("removed_environment_naming_info" )
278+ def _validate_removed_environment_naming_info (
279+ cls , v : t .Optional [EnvironmentNamingInfo ], info : ValidationInfo
280+ ) -> t .Optional [EnvironmentNamingInfo ]:
281+ if v and not info .data .get ("removed" ):
282+ raise ValueError ("removed_environment_naming_info must be None if removed is empty" )
283+ return v
284+
285+
286+ class ExpiredSnapshotBatch (PydanticModel ):
287+ """A batch of expired snapshots to be cleaned up."""
288+
289+ expired_snapshot_ids : t .Set [SnapshotId ]
290+ cleanup_tasks : t .List [SnapshotTableCleanupTask ]
291+ batch_range : ExpiredBatchRange
292+
293+
222294def iter_expired_snapshot_batches (
223295 state_reader : StateReader ,
224296 * ,
@@ -234,24 +306,29 @@ def iter_expired_snapshot_batches(
234306 ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced).
235307 batch_size: Maximum number of snapshots to fetch per batch.
236308 """
237- from sqlmesh .core .state_sync .base import LowerBatchBoundary
238309
239310 batch_size = batch_size if batch_size is not None else EXPIRED_SNAPSHOT_DEFAULT_BATCH_SIZE
240- batch_boundary = LowerBatchBoundary . init_batch_boundary (batch_size = batch_size )
311+ batch_range = ExpiredBatchRange . init_batch_range (batch_size = batch_size )
241312
242313 while True :
243314 batch = state_reader .get_expired_snapshots (
244315 current_ts = current_ts ,
245316 ignore_ttl = ignore_ttl ,
246- batch_boundary = batch_boundary ,
317+ batch_range = batch_range ,
247318 )
248319
249320 if batch is None :
250321 return
251322
252323 yield batch
253324
254- batch_boundary = batch .batch_boundary .to_lower_batch_boundary (batch_size = batch_size )
325+ assert isinstance (batch .batch_range .end , RowBoundary ), (
326+ "Only RowBoundary is supported for pagination currently"
327+ )
328+ batch_range = ExpiredBatchRange (
329+ start = batch .batch_range .end ,
330+ end = LimitBoundary (batch_size = batch_size ),
331+ )
255332
256333
257334def delete_expired_snapshots (
@@ -286,17 +363,25 @@ def delete_expired_snapshots(
286363 ignore_ttl = ignore_ttl ,
287364 batch_size = batch_size ,
288365 ):
366+ end_info = (
367+ f"updated_ts={ batch .batch_range .end .updated_ts } "
368+ if isinstance (batch .batch_range .end , RowBoundary )
369+ else f"limit={ batch .batch_range .end .batch_size } "
370+ )
289371 logger .info (
290- "Processing batch of size %s and max_updated_ts of %s" ,
372+ "Processing batch of size %s with end %s" ,
291373 len (batch .expired_snapshot_ids ),
292- batch . batch_boundary . updated_ts ,
374+ end_info ,
293375 )
294376 snapshot_evaluator .cleanup (
295377 target_snapshots = batch .cleanup_tasks ,
296378 on_complete = console .update_cleanup_progress if console else None ,
297379 )
298380 state_sync .delete_expired_snapshots (
299- upper_batch_boundary = batch .batch_boundary .to_upper_batch_boundary (),
381+ batch_range = ExpiredBatchRange (
382+ start = RowBoundary .lowest_boundary (),
383+ end = batch .batch_range .end ,
384+ ),
300385 ignore_ttl = ignore_ttl ,
301386 )
302387 logger .info ("Cleaned up expired snapshots batch" )
0 commit comments