Skip to content

Commit 1e6e318

Browse files
authored
Fix: Deserialization of the when_matched attribute of the INCREMENTAL_BY_UNIQUE_KEY kind (#1851)
1 parent 5e360a4 commit 1e6e318

File tree

3 files changed

+45
-6
lines changed

3 files changed

+45
-6
lines changed

sqlmesh/core/dialect.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp
1212
from sqlglot.dialects.dialect import DialectType
1313
from sqlglot.dialects.snowflake import Snowflake
14+
from sqlglot.helper import seq_get
1415
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1516
from sqlglot.optimizer.scope import traverse_scope
1617
from sqlglot.tokens import Token
@@ -607,8 +608,10 @@ class ChunkType(Enum):
607608
SQL = auto()
608609

609610

610-
def parse_one(sql: str, dialect: t.Optional[str] = None) -> exp.Expression:
611-
expressions = parse(sql, default_dialect=dialect, match_dialect=False)
611+
def parse_one(
612+
sql: str, dialect: t.Optional[str] = None, into: t.Optional[exp.IntoType] = None
613+
) -> exp.Expression:
614+
expressions = parse(sql, default_dialect=dialect, match_dialect=False, into=into)
612615
if not expressions:
613616
raise SQLMeshError(f"No expressions found in '{sql}'")
614617
elif len(expressions) > 1:
@@ -617,7 +620,10 @@ def parse_one(sql: str, dialect: t.Optional[str] = None) -> exp.Expression:
617620

618621

619622
def parse(
620-
sql: str, default_dialect: t.Optional[str] = None, match_dialect: bool = True
623+
sql: str,
624+
default_dialect: t.Optional[str] = None,
625+
match_dialect: bool = True,
626+
into: t.Optional[exp.IntoType] = None,
621627
) -> t.List[exp.Expression]:
622628
"""Parse a sql string.
623629
@@ -668,7 +674,10 @@ def parse(
668674

669675
for chunk, chunk_type in chunks:
670676
if chunk_type == ChunkType.SQL:
671-
for expression in parser.parse(chunk, sql):
677+
parsed_expressions: t.List[t.Optional[exp.Expression]] = (
678+
parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql)
679+
)
680+
for expression in parsed_expressions:
672681
if expression:
673682
expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1])
674683
expressions.append(expression)
@@ -706,6 +715,10 @@ def extend_sqlglot() -> None:
706715
parser.QUERY_MODIFIER_PARSERS.update(
707716
{TokenType.PARAMETER: lambda self: _parse_body_macro(self)}
708717
)
718+
# FIXME: Delete the extension below after upgrading to SQLGlot >= 20.3.0.
719+
parser.EXPRESSION_PARSERS.update(
720+
{exp.When: lambda self: seq_get(self._parse_when_matched(), 0)}
721+
)
709722

710723
for generator in generators:
711724
if MacroFunc not in generator.TRANSFORMS:

sqlmesh/core/model/kind.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ class IncrementalByUniqueKeyKind(_Incremental):
231231
@field_validator("when_matched", mode="before")
232232
@field_validator_v1_args
233233
def _when_matched_validator(
234-
cls, v: t.Optional[exp.When], values: t.Dict[str, t.Any]
234+
cls, v: t.Optional[t.Union[exp.When, str]], values: t.Dict[str, t.Any]
235235
) -> t.Optional[exp.When]:
236236
def replace_table_references(expression: exp.Expression) -> exp.Expression:
237237
from sqlmesh.core.engine_adapter.base import (
@@ -252,9 +252,12 @@ def replace_table_references(expression: exp.Expression) -> exp.Expression:
252252
)
253253
return expression
254254

255+
if isinstance(v, str):
256+
return t.cast(exp.When, d.parse_one(v, into=exp.When))
257+
255258
if not v:
256259
return v
257-
v.meta["dialect"] = values.get("dialect")
260+
258261
return v.transform(replace_table_references)
259262

260263

tests/core/test_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2556,3 +2556,26 @@ def test_null_column_type():
25562556
"id": exp.DataType.build("int"),
25572557
}
25582558
assert not model.annotated
2559+
2560+
2561+
def test_when_matched():
2562+
expressions = d.parse(
2563+
"""
2564+
MODEL (
2565+
name db.employees,
2566+
kind INCREMENTAL_BY_UNIQUE_KEY (
2567+
unique_key name,
2568+
when_matched WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
2569+
)
2570+
);
2571+
SELECT 'name' AS name, 1 AS salary;
2572+
"""
2573+
)
2574+
2575+
expected_when_matched = "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)"
2576+
2577+
model = load_sql_based_model(expressions, dialect="hive")
2578+
assert model.kind.when_matched.sql() == expected_when_matched
2579+
2580+
model = SqlModel.parse_raw(model.json())
2581+
assert model.kind.when_matched.sql() == expected_when_matched

0 commit comments

Comments
 (0)