Skip to content

Commit 273e913

Browse files
authored
Fix: Check all dialects used in a project when retrieving a model by its name (#2801)
1 parent e0f5dcc commit 273e913

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

sqlmesh/core/context.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ def __init__(
325325

326326
self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))
327327

328+
self._all_dialects: t.Set[str] = {self.config.dialect or ""}
329+
328330
# This allows overriding the default dialect's normalization strategy, so for example
329331
# one can do `dialect="duckdb,normalization_strategy=lowercase"` and this will be
330332
# applied to the DuckDB dialect globally
@@ -435,6 +437,9 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
435437
self.path,
436438
)
437439

440+
if model.dialect:
441+
self._all_dialects.add(model.dialect)
442+
438443
model.validate_definition()
439444

440445
return model
@@ -517,6 +522,10 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
517522
f"Models and Standalone audits cannot have the same name: {duplicates}"
518523
)
519524

525+
self._all_dialects = {m.dialect for m in self._models.values() if m.dialect} | {
526+
self.default_dialect or ""
527+
}
528+
520529
analytics.collector.on_project_loaded(
521530
project_type=(
522531
c.DBT if type(self._loader).__name__.lower().startswith(c.DBT) else c.NATIVE
@@ -624,21 +633,24 @@ def get_model(
624633
The expected model.
625634
"""
626635
if isinstance(model_or_snapshot, str):
627-
normalized_name = normalize_model_name(
628-
model_or_snapshot,
629-
dialect=self.default_dialect,
630-
default_catalog=self.default_catalog,
631-
)
632-
model = self._models.get(normalized_name)
636+
# We should try all dialects referenced in the project for cases when models use mixed dialects.
637+
for dialect in self._all_dialects:
638+
normalized_name = normalize_model_name(
639+
model_or_snapshot,
640+
dialect=dialect,
641+
default_catalog=self.default_catalog,
642+
)
643+
if normalized_name in self._models:
644+
return self._models[normalized_name]
633645
elif isinstance(model_or_snapshot, Snapshot):
634-
model = model_or_snapshot.model
646+
return model_or_snapshot.model
635647
else:
636-
model = model_or_snapshot
648+
return model_or_snapshot
637649

638-
if raise_if_missing and not model:
650+
if raise_if_missing:
639651
raise SQLMeshError(f"Cannot find model for '{model_or_snapshot}'")
640652

641-
return model
653+
return None
642654

643655
@t.overload
644656
def get_snapshot(self, node_or_snapshot: NodeOrSnapshot) -> t.Optional[Snapshot]: ...

tests/core/test_context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,25 @@ def test_disabled_model(copy_to_temp_path):
765765
assert not context.get_model("sushi.disabled")
766766

767767

768+
def test_get_model_mixed_dialects(copy_to_temp_path):
769+
path = copy_to_temp_path("examples/sushi")
770+
771+
context = Context(paths=path)
772+
expression = d.parse(
773+
"""
774+
MODEL(
775+
name sushi.snowflake_dialect,
776+
dialect snowflake,
777+
);
778+
779+
SELECT 1"""
780+
)
781+
model = load_sql_based_model(expression, default_catalog=context.default_catalog)
782+
context.upsert_model(model)
783+
784+
assert context.get_model("sushi.snowflake_dialect") == model
785+
786+
768787
def test_override_dialect_normalization_strategy():
769788
config = Config(
770789
model_defaults=ModelDefaultsConfig(dialect="duckdb,normalization_strategy=lowercase")

0 commit comments

Comments
 (0)