Skip to content

Commit 7982c31

Browse files
committed
Refactor logic into a built-in signal
1 parent 9a76625 commit 7982c31

File tree

6 files changed

+138
-50
lines changed

6 files changed

+138
-50
lines changed

.circleci/continue_config.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,10 @@ workflows:
310310
- athena
311311
- fabric
312312
- gcp-postgres
313-
filters:
314-
branches:
315-
only:
316-
- main
313+
# filters:
314+
# branches:
315+
# only:
316+
# - main
317317
- ui_style
318318
- ui_test
319319
- vscode_test

sqlmesh/core/scheduler.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -267,14 +267,6 @@ def evaluate(
267267

268268
snapshots = parent_snapshots_by_name(snapshot, self.snapshots)
269269

270-
if not is_restatement_plan and self.can_skip_evaluation(snapshot, snapshots):
271-
logger.info(f"""
272-
Skipping evaluation for snapshot {snapshot.name} as it depends on external models
273-
that have not been updated since the last run.
274-
""")
275-
276-
return []
277-
278270
is_deployable = deployability_index.is_deployable(snapshot)
279271

280272
wap_id = self.snapshot_evaluator.evaluate(
@@ -388,6 +380,7 @@ def batch_intervals(
388380
deployability_index: t.Optional[DeployabilityIndex],
389381
environment_naming_info: EnvironmentNamingInfo,
390382
dag: t.Optional[DAG[SnapshotId]] = None,
383+
is_restatement_plan: bool = False,
391384
) -> t.Dict[Snapshot, Intervals]:
392385
dag = dag or snapshots_to_dag(merged_intervals)
393386

@@ -427,6 +420,7 @@ def batch_intervals(
427420
intervals,
428421
context,
429422
environment_naming_info,
423+
is_restatement_plan=is_restatement_plan,
430424
)
431425
unready -= set(intervals)
432426

@@ -509,9 +503,12 @@ def run_merged_intervals(
509503
snapshot_dag = full_dag.subdag(*selected_snapshot_ids_set)
510504

511505
batched_intervals = self.batch_intervals(
512-
merged_intervals, deployability_index, environment_naming_info, dag=snapshot_dag
506+
merged_intervals,
507+
deployability_index,
508+
environment_naming_info,
509+
dag=snapshot_dag,
510+
is_restatement_plan=is_restatement_plan,
513511
)
514-
515512
self.console.start_evaluation_progress(
516513
batched_intervals,
517514
environment_naming_info,
@@ -968,6 +965,7 @@ def _check_ready_intervals(
968965
intervals: Intervals,
969966
context: ExecutionContext,
970967
environment_naming_info: EnvironmentNamingInfo,
968+
is_restatement_plan: bool = False,
971969
) -> Intervals:
972970
"""Checks if the intervals are ready for evaluation for the given snapshot.
973971
@@ -989,13 +987,27 @@ def _check_ready_intervals(
989987
if not (signals and signals.signals_to_kwargs):
990988
return intervals
991989

990+
signal_names = signals.signals_to_kwargs.keys()
991+
992+
if (
993+
is_restatement_plan
994+
and len(signal_names) == 1
995+
and next(iter(signal_names)) == "freshness"
996+
):
997+
# Freshness signal is not checked for restatement plans to allow users
998+
# for an escape hatch in reevaluating models
999+
return intervals
1000+
9921001
self.console.start_signal_progress(
9931002
snapshot,
9941003
self.default_catalog,
9951004
environment_naming_info or EnvironmentNamingInfo(),
9961005
)
9971006

9981007
for signal_idx, (signal_name, kwargs) in enumerate(signals.signals_to_kwargs.items()):
1008+
if is_restatement_plan and signal_name == "freshness":
1009+
continue
1010+
9991011
# Capture intervals before signal check for display
10001012
intervals_to_check = merge_intervals(intervals)
10011013

@@ -1009,6 +1021,7 @@ def _check_ready_intervals(
10091021
python_env=signals.python_env,
10101022
dialect=snapshot.model.dialect,
10111023
path=snapshot.model._path,
1024+
snapshot=snapshot,
10121025
kwargs=kwargs,
10131026
)
10141027
except SQLMeshError as e:

sqlmesh/core/signal.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from __future__ import annotations
22

3-
3+
import typing as t
44
from sqlmesh.utils import UniqueKeyDict, registry_decorator
55

6+
if t.TYPE_CHECKING:
7+
from sqlmesh.core.context import ExecutionContext
8+
from sqlmesh.core.snapshot.definition import Snapshot
9+
from sqlmesh.utils.date import DatetimeRanges
10+
611

712
class signal(registry_decorator):
813
"""Specifies a function which intervals are ready from a list of scheduled intervals.
@@ -33,3 +38,30 @@ class signal(registry_decorator):
3338

3439

3540
SignalRegistry = UniqueKeyDict[str, signal]
41+
42+
43+
@signal()
44+
def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool:
45+
adapter = context.engine_adapter
46+
if not snapshot.last_altered_ts or not adapter.SUPPORTS_EXTERNAL_MODEL_FRESHNESS:
47+
return True
48+
49+
adapter = context.engine_adapter
50+
parent_snapshots = {context.snapshots[p.name] for p in snapshot.parents}
51+
if len(parent_snapshots) != len(snapshot.node.depends_on) or not all(
52+
p.is_external for p in parent_snapshots
53+
):
54+
# The mismatch can happen if e.g an external model is not registered in the project
55+
return True
56+
57+
# Finding new data means that the upstream depedencies have been altered
58+
# since the last time the model was evaluated
59+
upstream_dep_has_new_data = any(
60+
upstream_last_altered_ts > snapshot.last_altered_ts
61+
for upstream_last_altered_ts in adapter.get_external_model_freshness(
62+
[p.name for p in parent_snapshots]
63+
)
64+
)
65+
66+
# Returning true is a no-op, returning False nullifies the batch so the model will not be evaluated.
67+
return upstream_dep_has_new_data

sqlmesh/core/snapshot/definition.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,9 @@ def hydrate_with_intervals_by_version(
756756
for interval in snapshot_intervals:
757757
snapshot.merge_intervals(interval)
758758

759-
if interval.last_altered_ts:
759+
# Differentiate last_altered_ts between snapshots with shared version but
760+
# different dev versions e.g prod vs FORWARD_ONLY dev
761+
if snapshot.dev_version == interval.dev_version and interval.last_altered_ts:
760762
snapshot.last_altered_ts = max(
761763
snapshot.last_altered_ts or -1, interval.last_altered_ts
762764
)
@@ -1091,6 +1093,7 @@ def check_ready_intervals(
10911093
python_env=signals.python_env,
10921094
dialect=self.model.dialect,
10931095
path=self.model._path,
1096+
snapshot=self,
10941097
kwargs=kwargs,
10951098
)
10961099
except SQLMeshError as e:
@@ -2431,6 +2434,7 @@ def check_ready_intervals(
24312434
python_env: t.Dict[str, Executable],
24322435
dialect: DialectType = None,
24332436
path: t.Optional[Path] = None,
2437+
snapshot: t.Optional[Snapshot] = None,
24342438
kwargs: t.Optional[t.Dict] = None,
24352439
) -> Intervals:
24362440
checked_intervals: Intervals = []
@@ -2446,6 +2450,7 @@ def check_ready_intervals(
24462450
provided_args=(batch,),
24472451
provided_kwargs=(kwargs or {}),
24482452
context=context,
2453+
snapshot=snapshot,
24492454
)
24502455
except Exception as ex:
24512456
raise SignalEvalError(format_evaluated_code_exception(ex, python_env))

sqlmesh/migrations/v0093_add_last_altered_to_intervals.py renamed to sqlmesh/migrations/v0098_add_last_altered_to_intervals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sqlglot import exp
44

55

6-
def migrate(state_sync, **kwargs): # type: ignore
6+
def migrate_schemas(state_sync, **kwargs): # type: ignore
77
engine_adapter = state_sync.engine_adapter
88
schema = state_sync.schema
99
intervals_table = "_intervals"

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 71 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from unittest import mock
1111
from unittest.mock import patch
1212
import logging
13+
from IPython.utils.capture import capture_output
14+
1315

1416
import time_machine
1517
from pytest_mock.plugin import MockerFixture
@@ -3846,36 +3848,45 @@ def test_external_model_freshness(ctx: TestContext, mocker: MockerFixture, tmp_p
38463848
if not adapter.SUPPORTS_EXTERNAL_MODEL_FRESHNESS:
38473849
pytest.skip("This test only runs for engines that support external model freshness")
38483850

3849-
def _run_plan(
3850-
sqlmesh_context: Context, restate_models: t.Optional[t.List[str]] = None
3851-
) -> PlanResults:
3852-
plan: Plan = sqlmesh_context.plan(
3853-
auto_apply=True, no_prompts=True, restate_models=restate_models
3851+
def _assert_snapshot_last_altered_ts(context: Context, snapshot_id: str, timestamp: datetime):
3852+
from sqlmesh.utils.date import to_datetime
3853+
3854+
snapshot = context.state_sync.get_snapshots([snapshot_id])[snapshot_id]
3855+
assert to_datetime(snapshot.last_altered_ts).replace(microsecond=0) == timestamp.replace(
3856+
microsecond=0
38543857
)
3855-
return PlanResults.create(plan, ctx, schema)
38563858

38573859
import sqlmesh
38583860

38593861
spy = mocker.spy(sqlmesh.core.snapshot.evaluator.SnapshotEvaluator, "evaluate")
38603862

38613863
def _assert_model_evaluation(lambda_func, was_evaluated, day_delta=0):
3862-
call_count_before = spy.call_count
3863-
logger = logging.getLogger("sqlmesh.core.scheduler")
3864-
3865-
with time_machine.travel(now(minute_floor=False) + timedelta(days=day_delta)):
3866-
with mock.patch.object(logger, "info") as mock_logger:
3867-
lambda_func()
3868-
3869-
evaluation_skipped_log = any(
3870-
"Skipping evaluation for snapshot" in call[0][0] for call in mock_logger.call_args_list
3871-
)
3872-
3864+
spy.reset_mock()
3865+
timestamp = now(minute_floor=False) + timedelta(days=day_delta)
3866+
with time_machine.travel(timestamp, tick=False):
3867+
with capture_output() as output:
3868+
plan_or_run_result = lambda_func()
3869+
3870+
evaluate_function_called = spy.call_count == 1
3871+
signal_was_checked = "Checking signals for" in output.stdout
3872+
restatement_plan = isinstance(plan_or_run_result, Plan) and plan_or_run_result.restatements
3873+
if restatement_plan:
3874+
# Restatement plans exclude this signal so we expect the actual evaluation
3875+
# to happen but not through the signal
3876+
assert evaluate_function_called
3877+
assert not signal_was_checked
3878+
return
3879+
3880+
# All other cases (e.g normal plans or runs) will check the freshness signal
3881+
assert signal_was_checked
38733882
if was_evaluated:
3874-
assert not evaluation_skipped_log
3875-
assert spy.call_count == call_count_before + 1
3883+
assert "All ready" in output.stdout
3884+
assert evaluate_function_called
38763885
else:
3877-
assert evaluation_skipped_log
3878-
assert spy.call_count == call_count_before
3886+
assert "None ready" in output.stdout
3887+
assert not evaluate_function_called
3888+
3889+
return timestamp, plan_or_run_result
38793890

38803891
# Create & initialize schema
38813892
schema = ctx.add_test_suffix(TEST_SCHEMA)
@@ -3912,7 +3923,10 @@ def _assert_model_evaluation(lambda_func, was_evaluated, day_delta=0):
39123923
MODEL (
39133924
name {model_name},
39143925
start '2024-01-01',
3915-
kind FULL
3926+
kind FULL,
3927+
signals (
3928+
freshness(),
3929+
)
39163930
);
39173931
39183932
SELECT col1 * col2 AS col FROM {external_table1}, {external_table2};
@@ -3924,23 +3938,47 @@ def _set_config(gateway: str, config: Config) -> None:
39243938

39253939
context = ctx.create_context(path=tmp_path, config_mutator=_set_config)
39263940

3927-
# Case 1: Model is evaluated on first insertion
3928-
_assert_model_evaluation(lambda: _run_plan(context), was_evaluated=True)
3941+
# Case 1: Model is evaluated for the first plan
3942+
prod_plan_ts, prod_plan = _assert_model_evaluation(
3943+
lambda: context.plan(auto_apply=True, no_prompts=True), was_evaluated=True
3944+
)
3945+
3946+
prod_snapshot_id = next(iter(prod_plan.context_diff.new_snapshots))
3947+
_assert_snapshot_last_altered_ts(context, prod_snapshot_id, prod_plan_ts)
39293948

39303949
# Case 2: Model is NOT evaluated on run if external models are not fresh
3931-
_assert_model_evaluation(lambda: context.run(), was_evaluated=False, day_delta=2)
3950+
_assert_model_evaluation(lambda: context.run(), was_evaluated=False, day_delta=1)
39323951

3933-
# Case 3: Model is evaluated on run if any external model is fresh
3934-
adapter.execute(f"INSERT INTO {external_table2} (col2) VALUES (3)", quote_identifiers=False)
3952+
# Case 3: Differentiate last_altered_ts between snapshots with shared version
3953+
# For instance, creating a FORWARD_ONLY change in dev (reusing the version but creating a dev preview) should not cause
3954+
# the prod snapshot's last_altered_ts to be updated when fetched from the state sync
3955+
model_path.write_text(model_path.read_text().replace("col1 * col2", "col1 + col2"))
3956+
context.load()
3957+
dev_plan_ts = now(minute_floor=False) + timedelta(days=2)
3958+
with time_machine.travel(dev_plan_ts, tick=False):
3959+
dev_plan = context.plan(
3960+
environment="dev", forward_only=True, auto_apply=True, no_prompts=True
3961+
)
3962+
3963+
context.state_sync.clear_cache()
3964+
dev_snapshot_id = next(iter(dev_plan.context_diff.new_snapshots))
3965+
_assert_snapshot_last_altered_ts(context, dev_snapshot_id, dev_plan_ts)
3966+
_assert_snapshot_last_altered_ts(context, prod_snapshot_id, prod_plan_ts)
39353967

3968+
# Case 4: Model is evaluated on run if any external model is fresh
3969+
adapter.execute(f"INSERT INTO {external_table2} (col2) VALUES (3)", quote_identifiers=False)
39363970
_assert_model_evaluation(lambda: context.run(), was_evaluated=True, day_delta=2)
39373971

3938-
# Case 4: Model is evaluated on a restatement plan even if the external model is not fresh
3972+
# Case 5: Model is evaluated if changed (case 3) even if the external model is not fresh
3973+
model_path.write_text(model_path.read_text().replace("col1 + col2", "col1 * col2 * 5"))
3974+
context.load()
39393975
_assert_model_evaluation(
3940-
lambda: _run_plan(context, restate_models=[model_name]), was_evaluated=True, day_delta=3
3976+
lambda: context.plan(auto_apply=True, no_prompts=True), was_evaluated=True, day_delta=3
39413977
)
39423978

3943-
# Case 5: Model is evaluated if changed even if the external model is not fresh
3944-
model_path.write_text(model_path.read_text().replace("col1 * col2", "col1 + col2"))
3945-
context.load()
3946-
_assert_model_evaluation(lambda: _run_plan(context), was_evaluated=True, day_delta=2)
3979+
# Case 6: Model is evaluated on a restatement plan even if the external model is not fresh
3980+
_assert_model_evaluation(
3981+
lambda: context.plan(restate_models=[model_name], auto_apply=True, no_prompts=True),
3982+
was_evaluated=True,
3983+
day_delta=4,
3984+
)

0 commit comments

Comments
 (0)