@@ -334,6 +334,7 @@ def __init__(
334334 self .configs = (
335335 config if isinstance (config , dict ) else load_configs (config , self .CONFIG_TYPE , paths )
336336 )
337+ self ._projects = {config .project for config in self .configs .values ()}
337338 self .dag : DAG [str ] = DAG ()
338339 self ._models : UniqueKeyDict [str , Model ] = UniqueKeyDict ("models" )
339340 self ._audits : UniqueKeyDict [str , ModelAudit ] = UniqueKeyDict ("audits" )
@@ -553,7 +554,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
553554 """Load all files in the context's path."""
554555 load_start_ts = time .perf_counter ()
555556
556- projects = [loader .load () for loader in self ._loaders ]
557+ loaded_projects = [loader .load () for loader in self ._loaders ]
557558
558559 self .dag = DAG ()
559560 self ._standalone_audits .clear ()
@@ -564,7 +565,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
564565 self ._requirements .clear ()
565566 self ._excluded_requirements .clear ()
566567
567- for project in projects :
568+ for project in loaded_projects :
568569 self ._jinja_macros = self ._jinja_macros .merge (project .jinja_macros )
569570 self ._macros .update (project .macros )
570571 self ._models .update (project .models )
@@ -574,15 +575,39 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
574575 self ._requirements .update (project .requirements )
575576 self ._excluded_requirements .update (project .excluded_requirements )
576577
578+ uncached = set ()
579+
580+ if any (self ._projects ):
581+ prod = self .state_reader .get_environment (c .PROD )
582+
583+ if prod :
584+ for snapshot in self .state_reader .get_snapshots (prod .snapshots ).values ():
585+ if snapshot .node .project in self ._projects :
586+ uncached .add (snapshot .name )
587+ else :
588+ store = self ._standalone_audits if snapshot .is_audit else self ._models
589+ store [snapshot .name ] = snapshot .node # type: ignore
590+
577591 for model in self ._models .values ():
578592 self .dag .add (model .fqn , model .depends_on )
579593
580- # This topologically sorts the DAG & caches the result in-memory for later;
581- # we do it here to detect any cycles as early as possible and fail if needed
582- self .dag .sorted
583-
584594 if update_schemas :
595+ for fqn in self .dag :
596+ model = self ._models .get (fqn ) # type: ignore
597+
598+ if not model or fqn in uncached :
599+ continue
600+
601+ # make a copy of remote models that depend on local models or in the downstream chain
602+ # without this, a SELECT * FROM local will not propogate properly because the downstream
603+ # model will get mutated (schema changes) but the object is the same as the remote cache
604+ if any (dep in uncached for dep in model .depends_on ):
605+ uncached .add (fqn )
606+ self ._models .update ({fqn : model .copy (update = {"mapping_schema" : {}})})
607+ continue
608+
585609 update_model_schemas (self .dag , models = self ._models , context_path = self .path )
610+
586611 for model in self .models .values ():
587612 # The model definition can be validated correctly only after the schema is set.
588613 model .validate_definition ()
@@ -2105,63 +2130,39 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
21052130 def _snapshots (
21062131 self , models_override : t .Optional [UniqueKeyDict [str , Model ]] = None
21072132 ) -> t .Dict [str , Snapshot ]:
2108- projects = {config .project for config in self .configs .values ()}
2109-
2110- if any (projects ):
2111- prod = self .state_reader .get_environment (c .PROD )
2112- remote_snapshots = (
2113- {
2114- snapshot .name : snapshot
2115- for snapshot in self .state_reader .get_snapshots (prod .snapshots ).values ()
2116- }
2117- if prod
2118- else {}
2119- )
2120- else :
2121- remote_snapshots = {}
2122-
2123- local_nodes = {** (models_override or self ._models ), ** self ._standalone_audits }
2124- nodes = local_nodes .copy ()
2125-
2126- for name , snapshot in remote_snapshots .items ():
2127- if name not in nodes and snapshot .node .project not in projects :
2128- nodes [name ] = snapshot .node
2129-
21302133 def _nodes_to_snapshots (nodes : t .Dict [str , Node ]) -> t .Dict [str , Snapshot ]:
21312134 snapshots : t .Dict [str , Snapshot ] = {}
21322135 fingerprint_cache : t .Dict [str , SnapshotFingerprint ] = {}
21332136
21342137 for node in nodes .values ():
2135- if node .fqn not in local_nodes and node .fqn in remote_snapshots :
2136- ttl = remote_snapshots [node .fqn ].ttl
2137- else :
2138- config = self .config_for_node (node )
2139- ttl = config .snapshot_ttl
2138+ kwargs = {}
2139+ if node .project in self ._projects :
2140+ kwargs ["ttl" ] = self .config_for_node (node ).snapshot_ttl
21402141
21412142 snapshot = Snapshot .from_node (
21422143 node ,
21432144 nodes = nodes ,
21442145 cache = fingerprint_cache ,
2145- ttl = ttl ,
2146- config = self .config_for_node (node ),
2146+ ** kwargs ,
21472147 )
21482148 snapshots [snapshot .name ] = snapshot
21492149 return snapshots
21502150
2151+ nodes = {** (models_override or self ._models ), ** self ._standalone_audits }
21512152 snapshots = _nodes_to_snapshots (nodes )
21522153 stored_snapshots = self .state_reader .get_snapshots (snapshots .values ())
21532154
21542155 unrestorable_snapshots = {
21552156 snapshot
21562157 for snapshot in stored_snapshots .values ()
2157- if snapshot .name in local_nodes and snapshot .unrestorable
2158+ if snapshot .name in nodes and snapshot .unrestorable
21582159 }
21592160 if unrestorable_snapshots :
21602161 for snapshot in unrestorable_snapshots :
21612162 logger .info (
21622163 "Found a unrestorable snapshot %s. Restamping the model..." , snapshot .name
21632164 )
2164- node = local_nodes [snapshot .name ]
2165+ node = nodes [snapshot .name ]
21652166 nodes [snapshot .name ] = node .copy (
21662167 update = {"stamp" : f"revert to { snapshot .identifier } " }
21672168 )
0 commit comments