@@ -325,6 +325,8 @@ def __init__(
325325
326326 self .path , self .config = t .cast (t .Tuple [Path , C ], next (iter (self .configs .items ())))
327327
328+ self ._all_dialects : t .Set [str ] = {self .config .dialect or "" }
329+
328330 # This allows overriding the default dialect's normalization strategy, so for example
329331 # one can do `dialect="duckdb,normalization_strategy=lowercase"` and this will be
330332 # applied to the DuckDB dialect globally
@@ -435,6 +437,9 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
435437 self .path ,
436438 )
437439
440+ if model .dialect :
441+ self ._all_dialects .add (model .dialect )
442+
438443 model .validate_definition ()
439444
440445 return model
@@ -517,6 +522,10 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
517522 f"Models and Standalone audits cannot have the same name: { duplicates } "
518523 )
519524
525+ self ._all_dialects = {m .dialect for m in self ._models .values () if m .dialect } | {
526+ self .default_dialect or ""
527+ }
528+
520529 analytics .collector .on_project_loaded (
521530 project_type = (
522531 c .DBT if type (self ._loader ).__name__ .lower ().startswith (c .DBT ) else c .NATIVE
@@ -624,21 +633,24 @@ def get_model(
624633 The expected model.
625634 """
626635 if isinstance (model_or_snapshot , str ):
627- normalized_name = normalize_model_name (
628- model_or_snapshot ,
629- dialect = self .default_dialect ,
630- default_catalog = self .default_catalog ,
631- )
632- model = self ._models .get (normalized_name )
636+ # We should try all dialects referenced in the project for cases when models use mixed dialects.
637+ for dialect in self ._all_dialects :
638+ normalized_name = normalize_model_name (
639+ model_or_snapshot ,
640+ dialect = dialect ,
641+ default_catalog = self .default_catalog ,
642+ )
643+ if normalized_name in self ._models :
644+ return self ._models [normalized_name ]
633645 elif isinstance (model_or_snapshot , Snapshot ):
634- model = model_or_snapshot .model
646+ return model_or_snapshot .model
635647 else :
636- model = model_or_snapshot
648+ return model_or_snapshot
637649
638- if raise_if_missing and not model :
650+ if raise_if_missing :
639651 raise SQLMeshError (f"Cannot find model for '{ model_or_snapshot } '" )
640652
641- return model
653+ return None
642654
643655 @t .overload
644656 def get_snapshot (self , node_or_snapshot : NodeOrSnapshot ) -> t .Optional [Snapshot ]: ...
0 commit comments