Skip to content

Commit 0418cac

Browse files
authored
fix: load multi repo models earlier to ensure schema is correct (#3778)
1 parent 0b4ee9e commit 0418cac

File tree

8 files changed

+78
-50
lines changed

8 files changed

+78
-50
lines changed

examples/multi/repo_1/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ gateways:
44
local:
55
connection:
66
type: duckdb
7-
database: db.db
7+
database: db.duckdb
88

99
memory:
1010
connection:

examples/multi/repo_2/config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ gateways:
44
local:
55
connection:
66
type: duckdb
7-
database: db.db
7+
database: db.duckdb
88

99
memory:
1010
connection:
@@ -13,4 +13,4 @@ gateways:
1313
default_gateway: local
1414

1515
model_defaults:
16-
dialect: 'duckdb'
16+
dialect: 'duckdb'

examples/multi/repo_2/models/e.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
MODEL (
2+
name silver.e
3+
);
4+
5+
SELECT
6+
* EXCEPT(dup)
7+
FROM bronze.a

sqlmesh/core/context.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def __init__(
334334
self.configs = (
335335
config if isinstance(config, dict) else load_configs(config, self.CONFIG_TYPE, paths)
336336
)
337+
self._projects = {config.project for config in self.configs.values()}
337338
self.dag: DAG[str] = DAG()
338339
self._models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
339340
self._audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("audits")
@@ -553,7 +554,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
553554
"""Load all files in the context's path."""
554555
load_start_ts = time.perf_counter()
555556

556-
projects = [loader.load() for loader in self._loaders]
557+
loaded_projects = [loader.load() for loader in self._loaders]
557558

558559
self.dag = DAG()
559560
self._standalone_audits.clear()
@@ -564,7 +565,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
564565
self._requirements.clear()
565566
self._excluded_requirements.clear()
566567

567-
for project in projects:
568+
for project in loaded_projects:
568569
self._jinja_macros = self._jinja_macros.merge(project.jinja_macros)
569570
self._macros.update(project.macros)
570571
self._models.update(project.models)
@@ -574,15 +575,39 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
574575
self._requirements.update(project.requirements)
575576
self._excluded_requirements.update(project.excluded_requirements)
576577

578+
uncached = set()
579+
580+
if any(self._projects):
581+
prod = self.state_reader.get_environment(c.PROD)
582+
583+
if prod:
584+
for snapshot in self.state_reader.get_snapshots(prod.snapshots).values():
585+
if snapshot.node.project in self._projects:
586+
uncached.add(snapshot.name)
587+
else:
588+
store = self._standalone_audits if snapshot.is_audit else self._models
589+
store[snapshot.name] = snapshot.node # type: ignore
590+
577591
for model in self._models.values():
578592
self.dag.add(model.fqn, model.depends_on)
579593

580-
# This topologically sorts the DAG & caches the result in-memory for later;
581-
# we do it here to detect any cycles as early as possible and fail if needed
582-
self.dag.sorted
583-
584594
if update_schemas:
595+
for fqn in self.dag:
596+
model = self._models.get(fqn) # type: ignore
597+
598+
if not model or fqn in uncached:
599+
continue
600+
601+
# make a copy of remote models that depend on local models or in the downstream chain
602+
# without this, a SELECT * FROM local will not propogate properly because the downstream
603+
# model will get mutated (schema changes) but the object is the same as the remote cache
604+
if any(dep in uncached for dep in model.depends_on):
605+
uncached.add(fqn)
606+
self._models.update({fqn: model.copy(update={"mapping_schema": {}})})
607+
continue
608+
585609
update_model_schemas(self.dag, models=self._models, context_path=self.path)
610+
586611
for model in self.models.values():
587612
# The model definition can be validated correctly only after the schema is set.
588613
model.validate_definition()
@@ -2105,63 +2130,39 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
21052130
def _snapshots(
21062131
self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
21072132
) -> t.Dict[str, Snapshot]:
2108-
projects = {config.project for config in self.configs.values()}
2109-
2110-
if any(projects):
2111-
prod = self.state_reader.get_environment(c.PROD)
2112-
remote_snapshots = (
2113-
{
2114-
snapshot.name: snapshot
2115-
for snapshot in self.state_reader.get_snapshots(prod.snapshots).values()
2116-
}
2117-
if prod
2118-
else {}
2119-
)
2120-
else:
2121-
remote_snapshots = {}
2122-
2123-
local_nodes = {**(models_override or self._models), **self._standalone_audits}
2124-
nodes = local_nodes.copy()
2125-
2126-
for name, snapshot in remote_snapshots.items():
2127-
if name not in nodes and snapshot.node.project not in projects:
2128-
nodes[name] = snapshot.node
2129-
21302133
def _nodes_to_snapshots(nodes: t.Dict[str, Node]) -> t.Dict[str, Snapshot]:
21312134
snapshots: t.Dict[str, Snapshot] = {}
21322135
fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {}
21332136

21342137
for node in nodes.values():
2135-
if node.fqn not in local_nodes and node.fqn in remote_snapshots:
2136-
ttl = remote_snapshots[node.fqn].ttl
2137-
else:
2138-
config = self.config_for_node(node)
2139-
ttl = config.snapshot_ttl
2138+
kwargs = {}
2139+
if node.project in self._projects:
2140+
kwargs["ttl"] = self.config_for_node(node).snapshot_ttl
21402141

21412142
snapshot = Snapshot.from_node(
21422143
node,
21432144
nodes=nodes,
21442145
cache=fingerprint_cache,
2145-
ttl=ttl,
2146-
config=self.config_for_node(node),
2146+
**kwargs,
21472147
)
21482148
snapshots[snapshot.name] = snapshot
21492149
return snapshots
21502150

2151+
nodes = {**(models_override or self._models), **self._standalone_audits}
21512152
snapshots = _nodes_to_snapshots(nodes)
21522153
stored_snapshots = self.state_reader.get_snapshots(snapshots.values())
21532154

21542155
unrestorable_snapshots = {
21552156
snapshot
21562157
for snapshot in stored_snapshots.values()
2157-
if snapshot.name in local_nodes and snapshot.unrestorable
2158+
if snapshot.name in nodes and snapshot.unrestorable
21582159
}
21592160
if unrestorable_snapshots:
21602161
for snapshot in unrestorable_snapshots:
21612162
logger.info(
21622163
"Found a unrestorable snapshot %s. Restamping the model...", snapshot.name
21632164
)
2164-
node = local_nodes[snapshot.name]
2165+
node = nodes[snapshot.name]
21652166
nodes[snapshot.name] = node.copy(
21662167
update={"stamp": f"revert to {snapshot.identifier}"}
21672168
)

sqlmesh/core/model/definition.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -687,13 +687,23 @@ def text_diff(self, other: Node, rendered: bool = False) -> str:
687687
f"Cannot diff model '{self.name} against a non-model node '{other.name}'"
688688
)
689689

690-
return d.text_diff(
690+
text_diff = d.text_diff(
691691
self.render_definition(render_query=rendered),
692692
other.render_definition(render_query=rendered),
693693
self.dialect,
694694
other.dialect,
695695
).strip()
696696

697+
if not text_diff and not rendered:
698+
text_diff = d.text_diff(
699+
self.render_definition(render_query=True),
700+
other.render_definition(render_query=True),
701+
self.dialect,
702+
other.dialect,
703+
).strip()
704+
705+
return text_diff
706+
697707
def set_time_format(self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMAT) -> None:
698708
"""Sets the default time format for a model.
699709
@@ -1255,8 +1265,12 @@ def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]:
12551265
if query is None:
12561266
return None
12571267

1268+
unknown = exp.DataType.build("unknown")
1269+
12581270
self._columns_to_types = {
1259-
select.output_name: select.type or exp.DataType.build("unknown")
1271+
# copy data type because it is used in the engine to build CTAS and other queries
1272+
# this can change the parent which will mess up the diffing algo
1273+
select.output_name: (select.type or unknown).copy()
12601274
for select in query.selects
12611275
}
12621276

@@ -1351,9 +1365,16 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
13511365
# Can't determine if there's a breaking change if we can't render the query.
13521366
return None
13531367

1354-
edits = diff(
1355-
previous_query, this_query, matchings=[(previous_query, this_query)], delta_only=True
1356-
)
1368+
if previous_query is this_query:
1369+
edits = []
1370+
else:
1371+
edits = diff(
1372+
previous_query,
1373+
this_query,
1374+
matchings=[(previous_query, this_query)],
1375+
delta_only=True,
1376+
copy=False,
1377+
)
13571378
inserted_expressions = {e.expression for e in edits if isinstance(e, Insert)}
13581379

13591380
for edit in edits:

sqlmesh/core/snapshot/definition.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
if t.TYPE_CHECKING:
4747
from sqlglot.dialects.dialect import DialectType
4848
from sqlmesh.core.environment import EnvironmentNamingInfo
49-
from sqlmesh.core.config import Config
5049

5150
Interval = t.Tuple[int, int]
5251
Intervals = t.List[Interval]
@@ -596,7 +595,6 @@ def from_node(
596595
ttl: str = c.DEFAULT_SNAPSHOT_TTL,
597596
version: t.Optional[str] = None,
598597
cache: t.Optional[t.Dict[str, SnapshotFingerprint]] = None,
599-
config: t.Optional[Config] = None,
600598
) -> Snapshot:
601599
"""Creates a new snapshot for a node.
602600

tests/core/test_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1119,7 +1119,7 @@ def test_wildcard(copy_to_temp_path: t.Callable):
11191119
parent_path = copy_to_temp_path("examples/multi")[0]
11201120

11211121
context = Context(paths=f"{parent_path}/*")
1122-
assert len(context.models) == 4
1122+
assert len(context.models) == 5
11231123

11241124

11251125
def test_duckdb_state_connection_automatic_multithreaded_mode(tmp_path):

tests/core/test_integration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3837,7 +3837,7 @@ def test_multi(mocker):
38373837
)
38383838
context._new_state_sync().reset(default_catalog=context.default_catalog)
38393839
plan = context.plan_builder().build()
3840-
assert len(plan.new_snapshots) == 4
3840+
assert len(plan.new_snapshots) == 5
38413841
context.apply(plan)
38423842

38433843
adapter = context.engine_adapter
@@ -3856,12 +3856,13 @@ def test_multi(mocker):
38563856
assert set(snapshot.name for snapshot in plan.directly_modified) == {
38573857
'"memory"."bronze"."a"',
38583858
'"memory"."bronze"."b"',
3859+
'"memory"."silver"."e"',
38593860
}
38603861
assert sorted([x.name for x in list(plan.indirectly_modified.values())[0]]) == [
38613862
'"memory"."silver"."c"',
38623863
'"memory"."silver"."d"',
38633864
]
3864-
assert len(plan.missing_intervals) == 2
3865+
assert len(plan.missing_intervals) == 3
38653866
context.apply(plan)
38663867
validate_apply_basics(context, c.PROD, plan.snapshots.values())
38673868

0 commit comments

Comments
 (0)