Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@
)
from sqlmesh.core.snapshot import Node

from sqlmesh.core.snapshot.definition import Intervals

ModelOrSnapshot = t.Union[str, Model, Snapshot]
NodeOrSnapshot = t.Union[str, Model, StandaloneAudit, Snapshot]

Expand Down Expand Up @@ -276,6 +278,7 @@ def __init__(
default_dialect: t.Optional[str] = None,
default_catalog: t.Optional[str] = None,
is_restatement: t.Optional[bool] = None,
parent_intervals: t.Optional[Intervals] = None,
variables: t.Optional[t.Dict[str, t.Any]] = None,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
):
Expand All @@ -287,6 +290,7 @@ def __init__(
self._variables = variables or {}
self._blueprint_variables = blueprint_variables or {}
self._is_restatement = is_restatement
self._parent_intervals = parent_intervals

@property
def default_dialect(self) -> t.Optional[str]:
Expand Down Expand Up @@ -315,6 +319,10 @@ def gateway(self) -> t.Optional[str]:
def is_restatement(self) -> t.Optional[bool]:
return self._is_restatement

@property
def parent_intervals(self) -> t.Optional[Intervals]:
return self._parent_intervals

def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
"""Returns a variable value."""
return self._variables.get(var_name.lower(), default)
Expand Down
11 changes: 10 additions & 1 deletion sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def batch_intervals(
)
for snapshot, intervals in merged_intervals.items()
}
snapshot_batches = {}
snapshot_batches: t.Dict[Snapshot, Intervals] = {}
all_unready_intervals: t.Dict[str, set[Interval]] = {}
for snapshot_id in dag:
if snapshot_id not in snapshot_intervals:
Expand All @@ -364,13 +364,22 @@ def batch_intervals(

adapter = self.snapshot_evaluator.get_adapter(snapshot.model_gateway)

parent_intervals: Intervals = []
for parent_id in snapshot.parents:
parent_snapshot, _ = snapshot_intervals.get(parent_id, (None, []))
if not parent_snapshot or parent_snapshot.is_external:
continue

parent_intervals.extend(snapshot_batches[parent_snapshot])

context = ExecutionContext(
adapter,
self.snapshots_by_name,
deployability_index,
default_dialect=adapter.dialect,
default_catalog=self.default_catalog,
is_restatement=is_restatement,
parent_intervals=parent_intervals,
)

intervals = self._check_ready_intervals(
Expand Down
49 changes: 35 additions & 14 deletions sqlmesh/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing as t
from sqlmesh.utils import UniqueKeyDict, registry_decorator
from sqlmesh.utils.errors import MissingSourceError

if t.TYPE_CHECKING:
from sqlmesh.core.context import ExecutionContext
Expand Down Expand Up @@ -42,7 +43,16 @@ class signal(registry_decorator):


@signal()
def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool:
def freshness(
batch: DatetimeRanges,
snapshot: Snapshot,
context: ExecutionContext,
) -> bool:
"""
Implements model freshness as a signal, i.e. it considers this model to be fresh if:
- Any upstream SQLMesh model has available intervals to compute, i.e. is fresh
- Any upstream external model has been altered since the last time the model was evaluated
"""
adapter = context.engine_adapter
if context.is_restatement or not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS:
return True
Expand All @@ -54,24 +64,35 @@ def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionConte
if deployability_index.is_deployable(snapshot)
else snapshot.dev_last_altered_ts
)

if not last_altered_ts:
return True

parent_snapshots = {context.snapshots[p.name] for p in snapshot.parents}
if len(parent_snapshots) != len(snapshot.node.depends_on) or not all(
p.is_external for p in parent_snapshots
):
# The mismatch can happen if e.g. an external model is not registered in the project

upstream_parent_snapshots = {p for p in parent_snapshots if not p.is_external}
external_parents = snapshot.node.depends_on - {p.name for p in upstream_parent_snapshots}

if context.parent_intervals:
# At least one upstream SQLMesh model has intervals to compute (i.e. is fresh),
# so the current model is considered fresh too
return True

# Finding new data means that the upstream dependencies have been altered
# since the last time the model was evaluated
upstream_dep_has_new_data = any(
upstream_last_altered_ts > last_altered_ts
for upstream_last_altered_ts in adapter.get_table_last_modified_ts(
[p.name for p in parent_snapshots]
if external_parents:
external_last_altered_timestamps = adapter.get_table_last_modified_ts(
list(external_parents)
)

if len(external_last_altered_timestamps) != len(external_parents):
raise MissingSourceError(
f"Expected {len(external_parents)} sources to be present, but got {len(external_last_altered_timestamps)}."
)

# Finding new data means that the upstream dependencies have been altered
# since the last time the model was evaluated
return any(
external_last_altered_ts > last_altered_ts
for external_last_altered_ts in external_last_altered_timestamps
)
)

# Returning true is a no-op, returning False nullifies the batch so the model will not be evaluated.
return upstream_dep_has_new_data
return False
Loading