Skip to content

Commit d6af3f9

Browse files
Fix: normalization issue in table diff (#2872)
1 parent 89bbdb3 commit d6af3f9

File tree

3 files changed

+24
-13
lines changed

3 files changed

+24
-13
lines changed

sqlmesh/core/context.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,7 @@ def table_diff(
12991299
The TableDiff object containing schema and summary differences.
13001300
"""
13011301
source_alias, target_alias = source, target
1302+
13021303
if model_or_snapshot:
13031304
model = self.get_model(model_or_snapshot, raise_if_missing=True)
13041305
source_env = self.state_reader.get_environment(source)
@@ -1321,7 +1322,13 @@ def table_diff(
13211322
if not on:
13221323
for ref in model.all_references:
13231324
if ref.unique:
1324-
on = ref.columns
1325+
expr = ref.expression
1326+
1327+
if isinstance(expr, exp.Tuple):
1328+
on = [key.this.sql() for key in expr.expressions]
1329+
else:
1330+
# Handle a single Column or Paren expression
1331+
on = [expr.this.sql()]
13251332

13261333
if not on:
13271334
raise SQLMeshError(
@@ -1337,6 +1344,7 @@ def table_diff(
13371344
source_alias=source_alias,
13381345
target_alias=target_alias,
13391346
model_name=model.name if model_or_snapshot else None,
1347+
model_dialect=model.dialect if model_or_snapshot else None,
13401348
limit=limit,
13411349
decimals=decimals,
13421350
)

sqlmesh/core/table_diff.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def __init__(
144144
source_alias: t.Optional[str] = None,
145145
target_alias: t.Optional[str] = None,
146146
model_name: t.Optional[str] = None,
147+
model_dialect: t.Optional[str] = None,
147148
decimals: int = 3,
148149
):
149150
self.adapter = adapter
@@ -153,13 +154,15 @@ def __init__(
153154
self.where = exp.condition(where, dialect=self.dialect) if where else None
154155
self.limit = limit
155156
self.model_name = model_name
157+
self.model_dialect = model_dialect
156158
self.decimals = decimals
157159

158160
# Support environment aliases for diff output improvement in certain cases
159161
self.source_alias = source_alias
160162
self.target_alias = target_alias
161163

162164
if isinstance(on, (list, tuple)):
165+
join_condition = [exp.parse_identifier(key) for key in on]
163166
s_table = exp.to_identifier("s", quoted=True)
164167
t_table = exp.to_identifier("t", quoted=True)
165168

@@ -170,13 +173,13 @@ def __init__(
170173
exp.column(c, s_table).is_(exp.null())
171174
& exp.column(c, t_table).is_(exp.null())
172175
)
173-
for c in on
176+
for c in join_condition
174177
)
175178
)
176179
else:
177180
self.on = on
178181

179-
normalize_identifiers(self.on, dialect=self.dialect)
182+
normalize_identifiers(self.on, dialect=self.model_dialect or self.dialect)
180183

181184
self._source_schema: t.Optional[t.Dict[str, exp.DataType]] = None
182185
self._target_schema: t.Optional[t.Dict[str, exp.DataType]] = None
@@ -321,7 +324,7 @@ def name(e: exp.Expression) -> str:
321324
.as_("row_full_match"),
322325
).from_(query.subquery("stats"))
323326

324-
query = quote_identifiers(query, dialect=self.dialect)
327+
query = quote_identifiers(query, dialect=self.model_dialect or self.dialect)
325328
temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True)
326329

327330
with self.adapter.temp_table(query, name=temp_table) as table:

tests/core/test_table_diff.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,10 @@ def test_data_diff_decimals(sushi_context_fixed_date):
142142
def test_grain_check(sushi_context_fixed_date):
143143
expressions = d.parse(
144144
"""
145-
MODEL (name memory.sushi.grain_items, kind full, grain(key_1, key_2));
145+
MODEL (name memory.sushi.grain_items, kind full, grain("key_1", KEY_2));
146146
SELECT
147-
key_1,
148-
key_2,
147+
key_1 as "key_1",
148+
KEY_2,
149149
value,
150150
FROM
151151
(VALUES
@@ -156,10 +156,10 @@ def test_grain_check(sushi_context_fixed_date):
156156
(1, 2, 2),
157157
(4, NULL, 3),
158158
(2, 3, 2),
159-
) AS t (key_1,key_2, value)
159+
) AS t (key_1,KEY_2, value)
160160
"""
161161
)
162-
model_s = load_sql_based_model(expressions)
162+
model_s = load_sql_based_model(expressions, dialect="snowflake")
163163
sushi_context_fixed_date.upsert_model(model_s)
164164
sushi_context_fixed_date.plan(
165165
"source_dev",
@@ -170,14 +170,14 @@ def test_grain_check(sushi_context_fixed_date):
170170
end="2023-01-31",
171171
)
172172

173-
model = sushi_context_fixed_date.models['"memory"."sushi"."grain_items"']
173+
model = sushi_context_fixed_date.models['"MEMORY"."SUSHI"."GRAIN_ITEMS"']
174174

175175
modified_model = model.dict()
176176
modified_model["query"] = (
177177
exp.select("*")
178178
.from_(model.query.subquery())
179179
.union(
180-
"SELECT key_1, key_2, value FROM (VALUES (1, 6, 1),(1, 5, 3),(NULL, 2, 3),) AS t (key_1, key_2, value)"
180+
'SELECT key_1 as "key_1", KEY_2, value FROM (VALUES (1, 6, 1),(1, 5, 3),(NULL, 2, 3),) AS t (key_1, KEY_2, value)'
181181
)
182182
)
183183

@@ -200,8 +200,8 @@ def test_grain_check(sushi_context_fixed_date):
200200
diff = sushi_context_fixed_date.table_diff(
201201
source="source_dev",
202202
target="target_dev",
203-
on=["key_1", "key_2"],
204-
model_or_snapshot="sushi.grain_items",
203+
on=["'key_1'", "key_2"],
204+
model_or_snapshot="SUSHI.GRAIN_ITEMS",
205205
skip_grain_check=False,
206206
)
207207

0 commit comments

Comments
 (0)