Skip to content

Commit 41ae931

Browse files
Fix!: Allow expressions in scd type 2 models columns, check_cols in dbt
1 parent 5ef3125 commit 41ae931

File tree

9 files changed

+81
-16
lines changed

9 files changed

+81
-16
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,7 +1772,7 @@ def scd_type_2_by_column(
17721772
valid_from_col: exp.Column,
17731773
valid_to_col: exp.Column,
17741774
execution_time: t.Union[TimeLike, exp.Column],
1775-
check_columns: t.Union[exp.Star, t.Sequence[exp.Column]],
1775+
check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]],
17761776
invalidate_hard_deletes: bool = True,
17771777
execution_time_as_valid_from: bool = False,
17781778
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1810,7 +1810,7 @@ def _scd_type_2(
18101810
execution_time: t.Union[TimeLike, exp.Column],
18111811
invalidate_hard_deletes: bool = True,
18121812
updated_at_col: t.Optional[exp.Column] = None,
1813-
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
1813+
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
18141814
updated_at_as_valid_from: bool = False,
18151815
execution_time_as_valid_from: bool = False,
18161816
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,

sqlmesh/core/engine_adapter/trino.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def _scd_type_2(
302302
execution_time: t.Union[TimeLike, exp.Column],
303303
invalidate_hard_deletes: bool = True,
304304
updated_at_col: t.Optional[exp.Column] = None,
305-
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
305+
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None,
306306
updated_at_as_valid_from: bool = False,
307307
execution_time_as_valid_from: bool = False,
308308
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,

sqlmesh/core/model/kind.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
PydanticModel,
2424
SQLGlotBool,
2525
SQLGlotColumn,
26-
SQLGlotListOfColumnsOrStar,
26+
SQLGlotListOfFieldsOrStar,
2727
SQLGlotListOfFields,
2828
SQLGlotPositiveInt,
2929
SQLGlotString,
@@ -852,7 +852,7 @@ def to_expression(
852852

853853
class SCDType2ByColumnKind(_SCDType2Kind):
854854
name: t.Literal[ModelKindName.SCD_TYPE_2_BY_COLUMN] = ModelKindName.SCD_TYPE_2_BY_COLUMN
855-
columns: SQLGlotListOfColumnsOrStar
855+
columns: SQLGlotListOfFieldsOrStar
856856
execution_time_as_valid_from: SQLGlotBool = False
857857
updated_at_name: t.Optional[SQLGlotColumn] = None
858858

sqlmesh/utils/pydantic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -289,13 +289,13 @@ def column_validator(v: t.Any, values: t.Any) -> exp.Column:
289289
return expression
290290

291291

292-
def list_of_columns_or_star_validator(
292+
def list_of_fields_or_star_validator(
293293
v: t.Any, values: t.Any
294-
) -> t.Union[exp.Star, t.List[exp.Column]]:
294+
) -> t.Union[exp.Star, t.List[exp.Expression]]:
295295
expressions = _get_fields(v, values)
296296
if len(expressions) == 1 and isinstance(expressions[0], exp.Star):
297297
return t.cast(exp.Star, expressions[0])
298-
return t.cast(t.List[exp.Column], expressions)
298+
return t.cast(t.List[exp.Expression], expressions)
299299

300300

301301
def cron_validator(v: t.Any) -> str:
@@ -339,7 +339,7 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]:
339339
SQLGlotPositiveInt = int
340340
SQLGlotColumn = exp.Column
341341
SQLGlotListOfFields = t.List[exp.Expression]
342-
SQLGlotListOfColumnsOrStar = t.Union[t.List[exp.Column], exp.Star]
342+
SQLGlotListOfFieldsOrStar = t.Union[SQLGlotListOfFields, exp.Star]
343343
SQLGlotCron = str
344344
else:
345345
from pydantic.functional_validators import BeforeValidator
@@ -352,7 +352,7 @@ def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]:
352352
SQLGlotListOfFields = t.Annotated[
353353
t.List[exp.Expression], BeforeValidator(list_of_fields_validator)
354354
]
355-
SQLGlotListOfColumnsOrStar = t.Annotated[
356-
t.Union[t.List[exp.Column], exp.Star], BeforeValidator(list_of_columns_or_star_validator)
355+
SQLGlotListOfFieldsOrStar = t.Annotated[
356+
t.Union[SQLGlotListOfFields, exp.Star], BeforeValidator(list_of_fields_or_star_validator)
357357
]
358358
SQLGlotCron = t.Annotated[str, BeforeValidator(cron_validator)]

tests/core/test_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ def test_scd_type_2_by_col_serde():
949949
model_json_parsed = json.loads(model.json())
950950
assert model_json_parsed["kind"]["dialect"] == "bigquery"
951951
assert model_json_parsed["kind"]["unique_key"] == ["`a`"]
952-
assert model_json_parsed["kind"]["columns"] == "*"
952+
assert model_json_parsed["kind"]["columns"] == ["*"]
953953
# Bigquery converts TIMESTAMP -> DATETIME
954954
assert model_json_parsed["kind"]["time_data_type"] == "DATETIME"
955955

@@ -5427,7 +5427,7 @@ def scd_type_2_model(context, **kwargs):
54275427
'["col1"]',
54285428
[exp.to_column("col1", quoted=True)],
54295429
),
5430-
("*", exp.Star()),
5430+
("*", [exp.Star()]),
54315431
],
54325432
)
54335433
def test_check_column_variants(input_columns, expected_columns):
@@ -8360,7 +8360,7 @@ def test_model_kind_to_expression():
83608360
.kind.to_expression()
83618361
.sql()
83628362
== """SCD_TYPE_2_BY_COLUMN (
8363-
columns *,
8363+
columns (*),
83648364
execution_time_as_valid_from FALSE,
83658365
unique_key ("a", "b"),
83668366
valid_from_name "valid_from",

tests/core/test_snapshot_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2490,7 +2490,7 @@ def test_insert_into_scd_type_2_by_column(
24902490
target_columns_to_types=table_columns,
24912491
table_format=None,
24922492
unique_key=[exp.to_column("id", quoted=True)],
2493-
check_columns=exp.Star(),
2493+
check_columns=[exp.Star()],
24942494
valid_from_col=exp.column("valid_from", quoted=True),
24952495
valid_to_col=exp.column("valid_to", quoted=True),
24962496
execution_time="2020-01-02",

tests/dbt/test_model.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sqlglot.errors import SchemaError
88
from sqlmesh import Context
99
from sqlmesh.core.model import TimeColumn, IncrementalByTimeRangeKind
10-
from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange
10+
from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange, SCDType2ByColumnKind
1111
from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json
1212
from sqlmesh.core.config.common import VirtualEnvironmentMode
1313
from sqlmesh.core.model.meta import GrantsTargetLayer
@@ -705,6 +705,29 @@ def test_load_multiple_snapshots_defined_in_same_file(sushi_test_dbt_context: Co
705705
assert context.get_model("snapshots.items_check_snapshot")
706706

707707

708+
@pytest.mark.slow
709+
def test_dbt_snapshot_with_check_cols_expressions(sushi_test_dbt_context: Context) -> None:
710+
context = sushi_test_dbt_context
711+
model = context.get_model("snapshots.items_check_with_cast_snapshot")
712+
assert model is not None
713+
assert isinstance(model.kind, SCDType2ByColumnKind)
714+
715+
columns = model.kind.columns
716+
assert isinstance(columns, list)
717+
assert len(columns) == 1
718+
719+
# expression in check_cols is: ds::DATE
720+
assert isinstance(columns[0], exp.Cast)
721+
assert columns[0].sql() == 'CAST("ds" AS DATE)'
722+
723+
context.load()
724+
cached_model = context.get_model("snapshots.items_check_with_cast_snapshot")
725+
assert cached_model is not None
726+
assert isinstance(cached_model.kind, SCDType2ByColumnKind)
727+
assert isinstance(cached_model.kind.columns, list)
728+
assert len(cached_model.kind.columns) == 1
729+
730+
708731
@pytest.mark.slow
709732
def test_dbt_jinja_macro_undefined_variable_error(create_empty_project):
710733
project_dir, model_dir = create_empty_project()

tests/dbt/test_transformation.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,32 @@ def test_model_kind():
292292
on_additive_change=OnAdditiveChange.ALLOW,
293293
)
294294

295+
check_cols_with_cast = ModelConfig(
296+
materialized=Materialization.SNAPSHOT,
297+
unique_key=["id"],
298+
strategy="check",
299+
check_cols=["created_at::TIMESTAMPTZ"],
300+
).model_kind(context)
301+
assert isinstance(check_cols_with_cast, SCDType2ByColumnKind)
302+
assert check_cols_with_cast.execution_time_as_valid_from is True
303+
assert len(check_cols_with_cast.columns) == 1
304+
assert isinstance(check_cols_with_cast.columns[0], exp.Cast)
305+
assert check_cols_with_cast.columns[0].sql() == 'CAST("created_at" AS TIMESTAMPTZ)'
306+
307+
check_cols_multiple_expr = ModelConfig(
308+
materialized=Materialization.SNAPSHOT,
309+
unique_key=["id"],
310+
strategy="check",
311+
check_cols=["created_at::TIMESTAMPTZ", "COALESCE(status, 'active')"],
312+
).model_kind(context)
313+
assert isinstance(check_cols_multiple_expr, SCDType2ByColumnKind)
314+
assert len(check_cols_multiple_expr.columns) == 2
315+
assert isinstance(check_cols_multiple_expr.columns[0], exp.Cast)
316+
assert isinstance(check_cols_multiple_expr.columns[1], exp.Coalesce)
317+
318+
assert check_cols_multiple_expr.columns[0].sql() == 'CAST("created_at" AS TIMESTAMPTZ)'
319+
assert check_cols_multiple_expr.columns[1].sql() == "COALESCE(\"status\", 'active')"
320+
295321
assert ModelConfig(materialized=Materialization.INCREMENTAL, time_column="foo").model_kind(
296322
context
297323
) == IncrementalByTimeRangeKind(

tests/fixtures/dbt/sushi_test/snapshots/items_snapshots.sql

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,19 @@ select * from {{ source('streaming', 'items') }}
3030
select * from {{ source('streaming', 'items') }}
3131

3232
{% endsnapshot %}
33+
34+
{% snapshot items_check_with_cast_snapshot %}
35+
36+
{{
37+
config(
38+
target_schema='snapshots',
39+
unique_key='id',
40+
strategy='check',
41+
check_cols=['ds::DATE'],
42+
invalidate_hard_deletes=True,
43+
)
44+
}}
45+
46+
select * from {{ source('streaming', 'items') }}
47+
48+
{% endsnapshot %}

0 commit comments

Comments
 (0)