Skip to content

Commit fd60806

Browse files
authored
Include a model's CTEs when testing one of its CTEs (#920)
1 parent f72a237 commit fd60806

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

sqlmesh/core/test/definition.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class TestError(SQLMeshError):
1818

1919

2020
class ModelTest(unittest.TestCase):
21+
__test__ = False
2122
view_names: list[str] = []
2223

2324
def __init__(
@@ -156,7 +157,9 @@ def __init__(
156157
engine_adapter=engine_adapter,
157158
)
158159
# For tests we just use the model name for the table reference and we don't want to expand
159-
mapping = {name: _test_fixture_name(name) for name in models}
160+
mapping = {
161+
name: _test_fixture_name(name) for name in models.keys() | body.get("inputs", {}).keys()
162+
}
160163
if mapping:
161164
self.query = exp.replace_tables(self.query, mapping)
162165

@@ -174,8 +177,11 @@ def test_ctes(self) -> None:
174177
_raise_error(
175178
f"No CTE named {cte_name} found in model {self.model.name}", self.path
176179
)
180+
cte_query = self.ctes[cte_name].this
181+
for alias, cte in self.ctes.items():
182+
cte_query = cte_query.with_(alias, cte.this)
177183
expected_df = pd.DataFrame.from_records(value["rows"])
178-
actual_df = self.execute(self.ctes[cte_name].this)
184+
actual_df = self.execute(cte_query)
179185
self.assert_equal(expected_df, actual_df)
180186

181187
def runTest(self) -> None:

tests/core/test_test.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
5+
from sqlmesh.core.context import Context
6+
from sqlmesh.core.dialect import parse
7+
from sqlmesh.core.model import SqlModel, load_model
8+
from sqlmesh.core.test.definition import SqlModelTest
9+
from sqlmesh.utils.yaml import load as load_yaml
10+
11+
12+
def test_ctes(sushi_context: Context) -> None:
13+
model = t.cast(
14+
SqlModel,
15+
sushi_context.upsert_model(
16+
load_model(
17+
parse(
18+
"""
19+
MODEL (
20+
name sushi.foo,
21+
kind FULL,
22+
);
23+
24+
WITH source AS (
25+
SELECT id FROM raw
26+
),
27+
renamed AS (
28+
SELECT id as fid FROM source
29+
)
30+
SELECT fid FROM renamed;
31+
"""
32+
)
33+
)
34+
),
35+
)
36+
37+
body = load_yaml(
38+
"""
39+
test_foo:
40+
model: sushi.foo
41+
inputs:
42+
raw:
43+
rows:
44+
- id: 1
45+
outputs:
46+
ctes:
47+
source:
48+
rows:
49+
- id: 1
50+
renamed:
51+
rows:
52+
- fid: 1
53+
query:
54+
rows:
55+
- fid: 1
56+
vars:
57+
start: 2022-01-01
58+
end: 2022-01-01
59+
latest: 2022-01-01
60+
"""
61+
)
62+
result = SqlModelTest(
63+
body=body["test_foo"],
64+
test_name="test_foo",
65+
model=model,
66+
models=sushi_context._models,
67+
engine_adapter=sushi_context._test_engine_adapter,
68+
path=None,
69+
).run()
70+
assert result and result.wasSuccessful()

0 commit comments

Comments
 (0)