Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 15 additions & 42 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2241,6 +2241,11 @@ def insert(
column_descriptions=model.column_descriptions,
truncate=is_first_insert,
source_columns=source_columns,
storage_format=model.storage_format,
partitioned_by=model.partitioned_by,
partition_interval_unit=model.partition_interval_unit,
clustered_by=model.clustered_by,
table_properties=kwargs.get("physical_properties", model.physical_properties),
)
elif isinstance(model.kind, SCDType2ByColumnKind):
self.adapter.scd_type_2_by_column(
Expand All @@ -2259,6 +2264,11 @@ def insert(
column_descriptions=model.column_descriptions,
truncate=is_first_insert,
source_columns=source_columns,
storage_format=model.storage_format,
partitioned_by=model.partitioned_by,
partition_interval_unit=model.partition_interval_unit,
clustered_by=model.clustered_by,
table_properties=kwargs.get("physical_properties", model.physical_properties),
)
else:
raise SQLMeshError(
Expand All @@ -2273,51 +2283,14 @@ def append(
render_kwargs: t.Dict[str, t.Any],
**kwargs: t.Any,
) -> None:
# Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models.
columns_to_types, source_columns = self._get_target_and_source_columns(
model,
return self.insert(
table_name,
query_or_df,
model,
is_first_insert=False,
render_kwargs=render_kwargs,
force_get_columns_from_target=True,
**kwargs,
)
Comment on lines +2286 to 2293
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An append in this context is simply an insert with is_first_insert set to False, so I updated it to reflect that and removed the duplicated code.

if isinstance(model.kind, SCDType2ByTimeKind):
self.adapter.scd_type_2_by_time(
target_table=table_name,
source_table=query_or_df,
unique_key=model.unique_key,
valid_from_col=model.kind.valid_from_name,
valid_to_col=model.kind.valid_to_name,
updated_at_col=model.kind.updated_at_name,
invalidate_hard_deletes=model.kind.invalidate_hard_deletes,
updated_at_as_valid_from=model.kind.updated_at_as_valid_from,
target_columns_to_types=columns_to_types,
table_format=model.table_format,
table_description=model.description,
column_descriptions=model.column_descriptions,
source_columns=source_columns,
**kwargs,
)
elif isinstance(model.kind, SCDType2ByColumnKind):
self.adapter.scd_type_2_by_column(
target_table=table_name,
source_table=query_or_df,
unique_key=model.unique_key,
valid_from_col=model.kind.valid_from_name,
valid_to_col=model.kind.valid_to_name,
check_columns=model.kind.columns,
target_columns_to_types=columns_to_types,
table_format=model.table_format,
invalidate_hard_deletes=model.kind.invalidate_hard_deletes,
execution_time_as_valid_from=model.kind.execution_time_as_valid_from,
table_description=model.description,
column_descriptions=model.column_descriptions,
source_columns=source_columns,
**kwargs,
)
else:
raise SQLMeshError(
f"Unexpected SCD Type 2 kind: {model.kind}. This is not expected and please report this as a bug."
)


class ViewStrategy(PromotableStrategy):
Expand Down
34 changes: 34 additions & 0 deletions tests/core/engine_adapter/test_bigquery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# type: ignore
import typing as t
from datetime import datetime

import pandas as pd # noqa: TID253
import pytest
Expand Down Expand Up @@ -1173,3 +1174,36 @@ def test_drop_cascade(adapter: BigQueryEngineAdapter):
"DROP SCHEMA IF EXISTS `foo` CASCADE",
"DROP SCHEMA IF EXISTS `foo`",
]


def test_scd_type_2_by_partitioning(adapter: BigQueryEngineAdapter):
    """Verify the partition clause is propagated to the DDL emitted by an SCD Type 2 merge."""
    source_query = t.cast(
        exp.Select, parse_one("SELECT id, name, price, test_UPDATED_at FROM source")
    )
    columns_to_types = {
        "id": exp.DataType.build("INT"),
        "name": exp.DataType.build("VARCHAR"),
        "price": exp.DataType.build("DOUBLE"),
        "test_UPDATED_at": exp.DataType.build("TIMESTAMP"),
        "valid_from": exp.DataType.build("TIMESTAMP"),
        "valid_to": exp.DataType.build("TIMESTAMP"),
    }

    adapter.scd_type_2_by_time(
        target_table="target",
        source_table=source_query,
        unique_key=[
            exp.to_column("id"),
        ],
        updated_at_col=exp.column("test_UPDATED_at", quoted=True),
        valid_from_col=exp.to_column("valid_from", quoted=True),
        valid_to_col=exp.to_column("valid_to", quoted=True),
        target_columns_to_types=columns_to_types,
        execution_time=datetime(2020, 1, 1, 0, 0, 0),
        partitioned_by=[parse_one("TIMESTAMP_TRUNC(valid_from, DAY)")],
    )

    calls = _to_sql_calls(adapter)

    # One statement to create the table, then a replace because the merge is self-referencing.
    assert len(calls) == 2
    # Both statements must carry the partition logic; SCD semantics are covered by other tests.
    partition_clause = "PARTITION BY TIMESTAMP_TRUNC(`valid_from`, DAY)"
    assert all(partition_clause in sql for sql in calls)
53 changes: 50 additions & 3 deletions tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import typing as t

from typing_extensions import Self
from unittest.mock import call, patch, Mock
import re
Expand Down Expand Up @@ -2062,7 +2063,7 @@ def test_create_scd_type_2_by_time(adapter_mock, make_snapshot):
)


def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot):
def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot, mocker):
evaluator = SnapshotEvaluator(adapter_mock)
model = load_sql_based_model(
parse( # type: ignore
Expand All @@ -2073,7 +2074,8 @@ def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot):
unique_key id,
time_data_type TIMESTAMPTZ,
invalidate_hard_deletes false
)
),
partitioned_by cola
);

SELECT * FROM tbl;
Expand All @@ -2086,6 +2088,7 @@ def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot):

evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable())

source_query = parse_one('SELECT * FROM "tbl" AS "tbl"')
query = parse_one(
"""SELECT *, CAST(NULL AS TIMESTAMPTZ) AS valid_from, CAST(NULL AS TIMESTAMPTZ) AS valid_to FROM "tbl" AS "tbl" WHERE FALSE LIMIT 0"""
)
Expand All @@ -2094,7 +2097,9 @@ def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot):
common_kwargs = dict(
table_format=None,
storage_format=None,
partitioned_by=[],
partitioned_by=[
exp.to_column("cola", quoted=True),
],
partition_interval_unit=None,
clustered_by=[],
table_properties={},
Expand All @@ -2113,6 +2118,38 @@ def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot):
]
)

adapter_mock.reset_mock()

evaluator.evaluate(
snapshot,
start="2020-01-01",
end="2020-01-02",
execution_time="2020-01-02",
snapshots={},
deployability_index=DeployabilityIndex.none_deployable(),
)

adapter_mock.scd_type_2_by_time.assert_has_calls(
[
call(
column_descriptions={},
execution_time="2020-01-02",
invalidate_hard_deletes=False,
source_columns=None,
source_table=source_query,
target_columns_to_types=mocker.ANY,
target_table=snapshot.table_name(is_deployable=False),
truncate=True,
unique_key=[exp.to_column("id", quoted=True)],
updated_at_as_valid_from=False,
updated_at_col=exp.column("updated_at", quoted=True),
valid_from_col=exp.column("valid_from", quoted=True),
valid_to_col=exp.column("valid_to", quoted=True),
**common_kwargs,
),
]
)


@pytest.mark.parametrize(
"intervals,truncate",
Expand Down Expand Up @@ -2178,6 +2215,11 @@ def test_insert_into_scd_type_2_by_time(
updated_at_as_valid_from=False,
truncate=truncate,
source_columns=None,
clustered_by=[],
partition_interval_unit=None,
partitioned_by=[],
storage_format=None,
table_properties={},
)
adapter_mock.columns.assert_called_once_with(snapshot.table_name())

Expand Down Expand Up @@ -2347,6 +2389,11 @@ def test_insert_into_scd_type_2_by_column(
column_descriptions={},
truncate=truncate,
source_columns=None,
clustered_by=[],
partition_interval_unit=None,
partitioned_by=[],
storage_format=None,
table_properties={},
)
adapter_mock.columns.assert_called_once_with(snapshot.table_name())

Expand Down