Skip to content

Commit 0b65947

Browse files
committed
Refactor parent signals
1 parent cfa86dd commit 0b65947

File tree

4 files changed

+27
-28
lines changed

4 files changed

+27
-28
lines changed

sqlmesh/core/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@
153153
)
154154
from sqlmesh.core.snapshot import Node
155155

156+
from sqlmesh.core.snapshot.definition import Intervals
157+
156158
ModelOrSnapshot = t.Union[str, Model, Snapshot]
157159
NodeOrSnapshot = t.Union[str, Model, StandaloneAudit, Snapshot]
158160

@@ -275,6 +277,7 @@ def __init__(
275277
default_dialect: t.Optional[str] = None,
276278
default_catalog: t.Optional[str] = None,
277279
is_restatement: t.Optional[bool] = None,
280+
parent_intervals: t.Optional[t.List[Intervals]] = None,
278281
variables: t.Optional[t.Dict[str, t.Any]] = None,
279282
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
280283
):
@@ -286,6 +289,7 @@ def __init__(
286289
self._variables = variables or {}
287290
self._blueprint_variables = blueprint_variables or {}
288291
self._is_restatement = is_restatement
292+
self._parent_intervals = parent_intervals
289293

290294
@property
291295
def default_dialect(self) -> t.Optional[str]:
@@ -314,6 +318,10 @@ def gateway(self) -> t.Optional[str]:
314318
def is_restatement(self) -> t.Optional[bool]:
315319
return self._is_restatement
316320

321+
@property
322+
def parent_intervals(self) -> t.Optional[t.List[Intervals]]:
323+
return self._parent_intervals
324+
317325
def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
318326
"""Returns a variable value."""
319327
return self._variables.get(var_name.lower(), default)

sqlmesh/core/scheduler.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -364,28 +364,28 @@ def batch_intervals(
364364

365365
adapter = self.snapshot_evaluator.get_adapter(snapshot.model_gateway)
366366

367+
parent_intervals = []
368+
for parent in snapshot.parents:
369+
if parent.snapshot_id not in snapshot_intervals:
370+
continue
371+
_, p_intervals = snapshot_intervals[parent.snapshot_id]
372+
parent_intervals.append(p_intervals)
373+
367374
context = ExecutionContext(
368375
adapter,
369376
self.snapshots_by_name,
370377
deployability_index,
371378
default_dialect=adapter.dialect,
372379
default_catalog=self.default_catalog,
373380
is_restatement=is_restatement,
381+
parent_intervals=parent_intervals,
374382
)
375383

376-
parent_intervals = []
377-
for parent in snapshot.parents:
378-
if parent.snapshot_id not in snapshot_intervals:
379-
continue
380-
_, p_intervals = snapshot_intervals[parent.snapshot_id]
381-
parent_intervals.append(p_intervals)
382-
383384
intervals = self._check_ready_intervals(
384385
snapshot,
385386
intervals,
386387
context,
387388
environment_naming_info,
388-
parent_intervals=parent_intervals,
389389
)
390390
unready -= set(intervals)
391391

@@ -931,7 +931,6 @@ def _check_ready_intervals(
931931
intervals: Intervals,
932932
context: ExecutionContext,
933933
environment_naming_info: EnvironmentNamingInfo,
934-
parent_intervals: t.Optional[t.List[Intervals]] = None,
935934
) -> Intervals:
936935
"""Checks if the intervals are ready for evaluation for the given snapshot.
937936
@@ -974,7 +973,6 @@ def _check_ready_intervals(
974973
dialect=snapshot.model.dialect,
975974
path=snapshot.model._path,
976975
snapshot=snapshot,
977-
parent_intervals=parent_intervals,
978976
kwargs=kwargs,
979977
)
980978
except SQLMeshError as e:

sqlmesh/core/signal.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from sqlmesh.core.snapshot.definition import Snapshot
1010
from sqlmesh.utils.date import DatetimeRanges
1111
from sqlmesh.core.snapshot.definition import DeployabilityIndex
12-
from sqlmesh.core.snapshot.definition import Intervals
1312

1413

1514
class signal(registry_decorator):
@@ -48,7 +47,6 @@ def freshness(
4847
batch: DatetimeRanges,
4948
snapshot: Snapshot,
5049
context: ExecutionContext,
51-
parent_intervals: t.Optional[t.List[Intervals]] = None,
5250
) -> bool:
5351
"""
5452
Implements model freshness as a signal, i.e it considers this model to be fresh if:
@@ -71,33 +69,30 @@ def freshness(
7169
return True
7270

7371
parent_snapshots = {context.snapshots[p.name] for p in snapshot.parents}
74-
if len(parent_snapshots) != len(snapshot.node.depends_on):
75-
# The mismatch can happen if e.g an external model is not registered in the project
76-
return True
7772

78-
external_parent_snapshots = {p for p in parent_snapshots if p.is_external}
79-
upstream_parent_snapshots = parent_snapshots - external_parent_snapshots
73+
upstream_parent_snapshots = {p for p in parent_snapshots if not p.is_external}
74+
external_parents = snapshot.node.depends_on - {p.name for p in upstream_parent_snapshots}
8075

81-
if upstream_parent_snapshots and parent_intervals:
82-
# At least one upstream sqlmesh model has intervals to compute (i.e is not fresh),
83-
# so the current model should be considered fresh
76+
if context.parent_intervals:
77+
# At least one upstream sqlmesh model has intervals to compute (i.e is fresh),
78+
# so the current model is considered fresh too
8479
return True
8580

86-
if external_parent_snapshots:
81+
if external_parents:
8782
external_last_altered_timestamps = adapter.get_table_last_modified_ts(
88-
[sp.name for sp in external_parent_snapshots]
83+
list(external_parents)
8984
)
9085

91-
if len(external_last_altered_timestamps) != len(external_parent_snapshots):
86+
if len(external_last_altered_timestamps) != len(external_parents):
9287
raise MissingSourceError(
93-
f"Expected {len(external_parent_snapshots)} sources to be present, but got {len(external_last_altered_timestamps)}."
88+
f"Expected {len(external_parents)} sources to be present, but got {len(external_last_altered_timestamps)}."
9489
)
9590

9691
# Finding new data means that the upstream depedencies have been altered
9792
# since the last time the model was evaluated
9893
return any(
99-
upstream_last_altered_ts > last_altered_ts
100-
for upstream_last_altered_ts in external_last_altered_timestamps
94+
external_last_altered_ts > last_altered_ts
95+
for external_last_altered_ts in external_last_altered_timestamps
10196
)
10297

10398
return False

sqlmesh/core/snapshot/definition.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2451,7 +2451,6 @@ def check_ready_intervals(
24512451
dialect: DialectType = None,
24522452
path: t.Optional[Path] = None,
24532453
snapshot: t.Optional[Snapshot] = None,
2454-
parent_intervals: t.Optional[t.List[Intervals]] = None,
24552454
kwargs: t.Optional[t.Dict] = None,
24562455
) -> Intervals:
24572456
checked_intervals: Intervals = []
@@ -2468,7 +2467,6 @@ def check_ready_intervals(
24682467
provided_kwargs=(kwargs or {}),
24692468
context=context,
24702469
snapshot=snapshot,
2471-
parent_intervals=parent_intervals,
24722470
)
24732471
except Exception as ex:
24742472
raise SignalEvalError(format_evaluated_code_exception(ex, python_env))

0 commit comments

Comments
 (0)