Skip to content

Commit 47b5000

Browse files
authored
Fix: Invalidate the cached model's rendered query after updating the mapping schema (#2884)
1 parent acb2a01 commit 47b5000

File tree

5 files changed

+113 -29 lines changed

5 files changed

+113 -29 lines changed

sqlmesh/core/model/cache.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
class SqlModelCacheEntry(PydanticModel):
1616
model: SqlModel
17-
rendered_query: t.Optional[exp.Expression] = None
17+
full_depends_on: t.Set[str]
1818

1919

2020
class ModelCache:
@@ -46,21 +46,21 @@ def get_or_load(self, name: str, entry_id: str = "", *, loader: t.Callable[[], M
4646
cache_entry = self._file_cache.get(name, entry_id)
4747
if cache_entry:
4848
model = cache_entry.model
49-
model._query_renderer.update_cache(cache_entry.rendered_query, optimized=False)
49+
model._full_depends_on = cache_entry.full_depends_on
5050
return model
5151

5252
loaded_model = loader()
5353
if isinstance(loaded_model, SqlModel):
5454
new_entry = SqlModelCacheEntry(
55-
model=loaded_model, rendered_query=loaded_model.render_query(optimize=False)
55+
model=loaded_model, full_depends_on=loaded_model.full_depends_on
5656
)
5757
self._file_cache.put(name, entry_id, value=new_entry)
5858

5959
return loaded_model
6060

6161

6262
class OptimizedQueryCacheEntry(PydanticModel):
63-
optimized_rendered_query: exp.Expression
63+
optimized_rendered_query: t.Optional[exp.Expression]
6464

6565

6666
class OptimizedQueryCache:
@@ -85,23 +85,29 @@ def with_optimized_query(self, model: Model) -> bool:
8585
if not isinstance(model, SqlModel):
8686
return False
8787

88-
unoptimized_query = model.render_query(optimize=False)
89-
if unoptimized_query is None:
90-
return False
91-
9288
hash_data = _mapping_schema_hash_data(model.mapping_schema)
93-
hash_data.append(gen(unoptimized_query))
89+
hash_data.append(gen(model.query))
90+
hash_data.append(str([(k, v) for k, v in model.sorted_python_env]))
91+
hash_data.extend(model.jinja_macros.data_hash_values)
92+
9493
name = f"{model.name}_{crc32(hash_data)}"
9594
cache_entry = self._file_cache.get(name)
9695

9796
if cache_entry:
98-
model._query_renderer.update_cache(cache_entry.optimized_rendered_query, optimized=True)
97+
if cache_entry.optimized_rendered_query:
98+
model._query_renderer.update_cache(
99+
cache_entry.optimized_rendered_query, optimized=True
100+
)
101+
else:
102+
# If the optimized rendered query is None, then there are likely adapter calls in the query
103+
# that prevent us from rendering it at load time. This means that we can safely set the
104+
# unoptimized cache to None as well to prevent attempts to render it downstream.
105+
model._query_renderer.update_cache(None, optimized=False)
99106
return True
100107

101-
optimized_query = model.render_query(optimize=True)
102-
if optimized_query is not None:
103-
new_entry = OptimizedQueryCacheEntry(optimized_rendered_query=optimized_query)
104-
self._file_cache.put(name, value=new_entry)
108+
optimized_query = model.render_query()
109+
new_entry = OptimizedQueryCacheEntry(optimized_rendered_query=optimized_query)
110+
self._file_cache.put(name, value=new_entry)
105111

106112
return False
107113

sqlmesh/core/model/definition.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ class _Model(ModelMeta, frozen=True):
115115
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
116116
mapping_schema: t.Dict[str, t.Any] = {}
117117

118+
_full_depends_on: t.Optional[t.Set[str]] = None
119+
118120
_expressions_validator = expression_validator
119121

120122
def render(
@@ -560,7 +562,7 @@ def depends_on(self) -> t.Set[str]:
560562
Returns:
561563
A list of all the upstream table names.
562564
"""
563-
return self._full_depends_on - {self.fqn}
565+
return self.full_depends_on - {self.fqn}
564566

565567
@property
566568
def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]:
@@ -624,7 +626,7 @@ def is_seed(self) -> bool:
624626

625627
@cached_property
626628
def depends_on_self(self) -> bool:
627-
return self.fqn in self._full_depends_on
629+
return self.fqn in self.full_depends_on
628630

629631
@property
630632
def forward_only(self) -> bool:
@@ -834,17 +836,19 @@ def _additional_metadata(self) -> t.List[str]:
834836

835837
return additional_metadata
836838

837-
@cached_property
838-
def _full_depends_on(self) -> t.Set[str]:
839-
depends_on = self.depends_on_ or set()
840-
841-
query = self.render_query(optimize=False)
842-
if query is not None:
843-
depends_on |= d.find_tables(
844-
query, default_catalog=self.default_catalog, dialect=self.dialect
845-
)
839+
@property
840+
def full_depends_on(self) -> t.Set[str]:
841+
if not self._full_depends_on:
842+
depends_on = self.depends_on_ or set()
843+
844+
query = self.render_query(optimize=False)
845+
if query is not None:
846+
depends_on |= d.find_tables(
847+
query, default_catalog=self.default_catalog, dialect=self.dialect
848+
)
849+
self._full_depends_on = depends_on
846850

847-
return depends_on
851+
return self._full_depends_on
848852

849853

850854
class _SqlBasedModel(_Model):
@@ -1092,7 +1096,7 @@ def column_descriptions(self) -> t.Dict[str, str]:
10921096
if self.column_descriptions_ is not None:
10931097
return self.column_descriptions_
10941098

1095-
query = self.render_query(optimize=False)
1099+
query = self.render_query()
10961100
if query is None:
10971101
return {}
10981102

sqlmesh/utils/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def put(self, name: str, entry_id: str = "", *, value: T) -> None:
119119
raise SQLMeshError(f"Cache path '{self._path}' is not a directory.")
120120

121121
with gzip.open(self._cache_entry_path(name, entry_id), "wb", compresslevel=1) as fd:
122-
pickle.dump(value.dict(), fd)
122+
pickle.dump(value.dict(exclude_none=False), fd)
123123

124124
def _cache_entry_path(self, name: str, entry_id: str = "") -> Path:
125125
entry_file_name = "__".join(p for p in (self._cache_version, name, entry_id) if p)

tests/core/test_integration.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,23 @@
66

77
import pandas as pd
88
import pytest
9+
from pathlib import Path
910
from freezegun import freeze_time
1011
from pytest_mock.plugin import MockerFixture
1112
from sqlglot import exp
1213
from sqlglot.expressions import DataType
1314

1415
from sqlmesh import CustomMaterialization
16+
from sqlmesh.cli.example_project import init_example_project
1517
from sqlmesh.core import constants as c
1618
from sqlmesh.core import dialect as d
17-
from sqlmesh.core.config import AutoCategorizationMode
19+
from sqlmesh.core.config import (
20+
AutoCategorizationMode,
21+
Config,
22+
GatewayConfig,
23+
ModelDefaultsConfig,
24+
DuckDBConnectionConfig,
25+
)
1826
from sqlmesh.core.console import Console
1927
from sqlmesh.core.context import Context
2028
from sqlmesh.core.engine_adapter import EngineAdapter
@@ -1391,6 +1399,43 @@ def test_restatement_plan_ignores_changes(init_and_plan_context: t.Callable):
13911399
context.apply(plan)
13921400

13931401

1402+
def test_plan_twice_with_star_macro_yields_no_diff(tmp_path: Path):
1403+
init_example_project(tmp_path, dialect="duckdb")
1404+
1405+
star_model_definition = """
1406+
MODEL (
1407+
name sqlmesh_example.star_model,
1408+
kind FULL
1409+
);
1410+
1411+
SELECT @STAR(sqlmesh_example.full_model) FROM sqlmesh_example.full_model
1412+
"""
1413+
1414+
star_model_path = tmp_path / "models" / "star_model.sql"
1415+
star_model_path.write_text(star_model_definition)
1416+
1417+
db_path = str(tmp_path / "db.db")
1418+
config = Config(
1419+
gateways={"main": GatewayConfig(connection=DuckDBConnectionConfig(database=db_path))},
1420+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
1421+
)
1422+
context = Context(paths=tmp_path, config=config)
1423+
context.plan(auto_apply=True, no_prompts=True)
1424+
1425+
# Instantiate new context to remove caches etc
1426+
new_context = Context(paths=tmp_path, config=config)
1427+
1428+
star_model = new_context.get_model("sqlmesh_example.star_model")
1429+
assert (
1430+
star_model.render_query_or_raise().sql()
1431+
== 'SELECT CAST("full_model"."item_id" AS INT) AS "item_id", CAST("full_model"."num_orders" AS BIGINT) AS "num_orders" FROM "db"."sqlmesh_example"."full_model" AS "full_model"'
1432+
)
1433+
1434+
new_plan = new_context.plan(no_prompts=True)
1435+
assert not new_plan.has_changes
1436+
assert not new_plan.new_snapshots
1437+
1438+
13941439
@pytest.mark.parametrize(
13951440
"context_fixture",
13961441
["sushi_context", "sushi_dbt_context", "sushi_test_dbt_context", "sushi_no_default_catalog"],

tests/utils/test_cache.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,33 @@ def test_optimized_query_cache(tmp_path: Path, mocker: MockerFixture):
4949
cache = OptimizedQueryCache(tmp_path)
5050

5151
assert not cache.with_optimized_query(model)
52+
53+
model._query_renderer._cache = []
54+
model._query_renderer._optimized_cache = None
55+
56+
assert cache.with_optimized_query(model)
57+
58+
assert not model._query_renderer._cache
59+
assert model._query_renderer._optimized_cache is not None
60+
61+
62+
def test_optimized_query_cache_missing_rendered_query(tmp_path: Path, mocker: MockerFixture):
63+
model = SqlModel(
64+
name="test_model",
65+
query=parse_one("SELECT a FROM tbl"),
66+
mapping_schema={"tbl": {"a": "int"}},
67+
)
68+
render_mock = mocker.patch.object(model._query_renderer, "render")
69+
render_mock.return_value = None
70+
71+
cache = OptimizedQueryCache(tmp_path)
72+
73+
assert not cache.with_optimized_query(model)
74+
75+
model._query_renderer._cache = []
76+
model._query_renderer._optimized_cache = None
77+
5278
assert cache.with_optimized_query(model)
79+
80+
assert model._query_renderer._cache == [None]
81+
assert model._query_renderer._optimized_cache is None

0 commit comments

Comments (0)