Commit f96488c

feat: add batch_size support to scd type 2 kinds (#4220)

1 parent 98519c2 commit f96488c

File tree

8 files changed: +160 additions, -17 deletions

docs/concepts/models/model_kinds.md

Lines changed: 67 additions & 10 deletions
@@ -1241,12 +1241,13 @@ This is the most accurate representation of the menu based on the source data pr

 ### Shared Configuration Options

 | Name                    | Description                                                                                                     | Type                      |
 |-------------------------|-----------------------------------------------------------------------------------------------------------------|---------------------------|
 | unique_key              | Unique key used for identifying rows between source and target                                                  | List of strings or string |
 | valid_from_name         | The name of the `valid_from` column to create in the target table. Default: `valid_from`                        | string                    |
 | valid_to_name           | The name of the `valid_to` column to create in the target table. Default: `valid_to`                            | string                    |
 | invalidate_hard_deletes | If set to `true`, when a record is missing from the source table it will be marked as invalid. Default: `false` | bool                      |
+| batch_size              | The maximum number of intervals that can be evaluated in a single backfill task. If `None`, all intervals are processed as part of a single task. See [Processing Source Table with Historical Data](#processing-source-table-with-historical-data) for more on this use case. Default: `None` | int |

 !!! tip "Important"

@@ -1273,10 +1274,66 @@ This is the most accurate representation of the menu based on the source data pr

 ### SCD Type 2 By Column Configuration Options

 | Name                         | Description                                                                                                                                                                                                                                | Type                      |
 |------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------|
 | columns                      | The name of the columns to check for changes. `*` to represent that all columns should be checked.                                                                                                                                         | List of strings or string |
 | execution_time_as_valid_from | By default, when the model is first loaded `valid_from` is set to `1970-01-01 00:00:00` and future new rows will have `execution_time` of when the pipeline ran. This changes the behavior to always use `execution_time`. Default: `false` | bool                      |
+| updated_at_name              | If sourcing from a table that includes a timestamp to use as `valid_from`, set this property to that column. See [Processing Source Table with Historical Data](#processing-source-table-with-historical-data) for more on this use case. Default: `None` | string |
+
+### Processing Source Table with Historical Data
+
+The most common use case for SCD Type 2 is creating history for a table that doesn't already have it.
+In the restaurant menu example, the menu only tells you what is offered right now, but you want to know what was offered over time.
+In this case, the default setting of `None` for `batch_size` is the best option.
+
+Another use case is processing a source table that already contains history.
+A common example is a "daily snapshot" table created by a source system that takes a snapshot of the data at the end of each day.
+If your source table has historical records, like a "daily snapshot" table, set `batch_size` to `1` to process each interval (each day for a `@daily` cron) in sequential order.
+That way the historical records will be properly captured in the SCD Type 2 table.
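The batching rule described above (all intervals in one task when `batch_size` is `None`, one interval per task when it is `1`) can be sketched in plain Python. This is an illustrative model only — `batch_intervals` is a hypothetical helper, not SQLMesh's scheduler code:

```python
from datetime import date, timedelta

def batch_intervals(intervals, batch_size=None):
    """Group contiguous (start, end) intervals into evaluation batches.

    batch_size=None -> one batch spanning every interval (the default);
    batch_size=1    -> one batch per interval, evaluated in order, so each
    day of a daily snapshot source is processed sequentially.
    """
    if not intervals:
        return []
    if batch_size is None:
        return [(intervals[0][0], intervals[-1][1])]
    return [
        (intervals[i][0], intervals[min(i + batch_size, len(intervals)) - 1][1])
        for i in range(0, len(intervals), batch_size)
    ]

# Three daily intervals covering 2023-01-01 .. 2023-01-04
days = [
    (date(2023, 1, 1) + timedelta(d), date(2023, 1, 1) + timedelta(d + 1))
    for d in range(3)
]
```

With the three daily intervals above, this mirrors the expectations in the scheduler test added by this commit: one batch per day for `batch_size 1`, a single spanning batch for the default.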
+#### Example - Source from Daily Snapshot Table
+
+```sql linenums="1"
+MODEL (
+  name db.table,
+  kind SCD_TYPE_2_BY_COLUMN (
+    unique_key id,
+    columns [some_value],
+    updated_at_name ds,
+    batch_size 1
+  ),
+  start '2025-01-01',
+  cron '@daily'
+);
+SELECT
+  id,
+  some_value,
+  ds
+FROM source_table
+WHERE ds BETWEEN @start_ds AND @end_ds
+```
+
+This will process each day of the source table in sequential order (if there is more than one day to process), checking the `some_value` column for changes. If it changed, `valid_from` will be set to match the `ds` column (except for the first value, which would be `1970-01-01 00:00:00`).
+
+If the source data were the following:
+
+| id | some_value |     ds     |
+|----|------------|:----------:|
+| 1  | 1          | 2025-01-01 |
+| 1  | 2          | 2025-01-02 |
+| 1  | 3          | 2025-01-03 |
+| 1  | 3          | 2025-01-04 |
+
+then the resulting SCD Type 2 table would be:
+
+| id | some_value |     ds     |     valid_from      |      valid_to       |
+|----|------------|:----------:|:-------------------:|:-------------------:|
+| 1  | 1          | 2025-01-01 | 1970-01-01 00:00:00 | 2025-01-02 00:00:00 |
+| 1  | 2          | 2025-01-02 | 2025-01-02 00:00:00 | 2025-01-03 00:00:00 |
+| 1  | 3          | 2025-01-03 | 2025-01-03 00:00:00 | NULL                |
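The snapshot-to-SCD transformation in the two tables above can be sketched in Python. This is an illustrative reimplementation under simplified assumptions (in-memory rows, `%Y-%m-%d` dates) — SQLMesh actually performs this merge in SQL via the engine adapter, and `scd2_by_column` is a hypothetical name:

```python
from datetime import datetime

EPOCH = datetime(1970, 1, 1)

def scd2_by_column(rows, key="id", check="some_value", ts="ds"):
    """Collapse snapshot rows (sorted by key, then ts) into SCD Type 2
    records: the first version of each key opens at the epoch, and each
    change closes the prior record and opens a new one at the snapshot date."""
    out, open_idx = [], {}  # open_idx: key -> index of its open record
    for row in rows:
        k = row[key]
        if k in open_idx and out[open_idx[k]][check] == row[check]:
            continue  # unchanged: the open record still covers this snapshot
        when = datetime.strptime(row[ts], "%Y-%m-%d")
        valid_from = EPOCH
        if k in open_idx:
            out[open_idx[k]]["valid_to"] = when  # close the prior version
            valid_from = when
        out.append({key: k, check: row[check], ts: row[ts],
                    "valid_from": valid_from, "valid_to": None})
        open_idx[k] = len(out) - 1
    return out

source = [
    {"id": 1, "some_value": 1, "ds": "2025-01-01"},
    {"id": 1, "some_value": 2, "ds": "2025-01-02"},
    {"id": 1, "some_value": 3, "ds": "2025-01-03"},
    {"id": 1, "some_value": 3, "ds": "2025-01-04"},
]
```

Running `scd2_by_column(source)` reproduces the three rows of the result table above, including the epoch `valid_from` on the first version and the `NULL` `valid_to` on the still-current one.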

 ### Querying SCD Type 2 Models

sqlmesh/core/engine_adapter/base.py

Lines changed: 8 additions & 4 deletions
@@ -1411,7 +1411,7 @@ def scd_type_2_by_time(
         unique_key: t.Sequence[exp.Expression],
         valid_from_col: exp.Column,
         valid_to_col: exp.Column,
-        execution_time: TimeLike,
+        execution_time: t.Union[TimeLike, exp.Column],
         updated_at_col: exp.Column,
         invalidate_hard_deletes: bool = True,
         updated_at_as_valid_from: bool = False,
@@ -1445,7 +1445,7 @@ def scd_type_2_by_column(
         unique_key: t.Sequence[exp.Expression],
         valid_from_col: exp.Column,
         valid_to_col: exp.Column,
-        execution_time: TimeLike,
+        execution_time: t.Union[TimeLike, exp.Column],
         check_columns: t.Union[exp.Star, t.Sequence[exp.Column]],
         invalidate_hard_deletes: bool = True,
         execution_time_as_valid_from: bool = False,
@@ -1479,7 +1479,7 @@ def _scd_type_2(
         unique_key: t.Sequence[exp.Expression],
         valid_from_col: exp.Column,
         valid_to_col: exp.Column,
-        execution_time: TimeLike,
+        execution_time: t.Union[TimeLike, exp.Column],
         invalidate_hard_deletes: bool = True,
         updated_at_col: t.Optional[exp.Column] = None,
         check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
@@ -1554,7 +1554,11 @@ def remove_managed_columns(
         # column names and then remove them from the unmanaged_columns
         if check_columns and check_columns == exp.Star():
             check_columns = [exp.column(col) for col in unmanaged_columns_to_types]
-        execution_ts = to_time_column(execution_time, time_data_type, self.dialect, nullable=True)
+        execution_ts = (
+            exp.cast(execution_time, time_data_type, dialect=self.dialect)
+            if isinstance(execution_time, exp.Column)
+            else to_time_column(execution_time, time_data_type, self.dialect, nullable=True)
+        )
         if updated_at_as_valid_from:
             if not updated_at_col:
                 raise SQLMeshError(
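The branch added to `_scd_type_2` above is what lets `valid_from` come from a source column (the model's `updated_at_name`) rather than the run's execution timestamp. A rough sketch of that dispatch using plain string-building — the real code builds sqlglot expressions, and every name below is illustrative:

```python
from dataclasses import dataclass

@dataclass
class Column:
    """Stand-in for a column reference expression (illustrative only)."""
    name: str

def execution_ts_sql(execution_time, time_data_type="TIMESTAMP"):
    """If execution_time is a column reference (e.g. the model's
    updated_at_name), cast the column itself; otherwise render the
    concrete run timestamp as a literal of the time data type."""
    if isinstance(execution_time, Column):
        return f'CAST("{execution_time.name}" AS {time_data_type})'
    return f"CAST('{execution_time}' AS {time_data_type})"
```

With a column, every evaluated row carries its own timestamp from the source data; with a plain timestamp, all rows in the batch share the run's execution time.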

sqlmesh/core/engine_adapter/trino.py

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ def _scd_type_2(
         unique_key: t.Sequence[exp.Expression],
         valid_from_col: exp.Column,
         valid_to_col: exp.Column,
-        execution_time: TimeLike,
+        execution_time: t.Union[TimeLike, exp.Column],
         invalidate_hard_deletes: bool = True,
         updated_at_col: t.Optional[exp.Column] = None,
         check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,

sqlmesh/core/model/kind.py

Lines changed: 9 additions & 1 deletion
@@ -672,6 +672,7 @@ class _SCDType2Kind(_Incremental):
     valid_to_name: SQLGlotColumn = Field(exp.column("valid_to"), validate_default=True)
     invalidate_hard_deletes: SQLGlotBool = False
     time_data_type: exp.DataType = Field(exp.DataType.build("TIMESTAMP"), validate_default=True)
+    batch_size: t.Optional[SQLGlotPositiveInt] = None

     forward_only: SQLGlotBool = True
     disable_restatement: SQLGlotBool = True
@@ -711,6 +712,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
             gen(self.valid_to_name),
             str(self.invalidate_hard_deletes),
             gen(self.time_data_type),
+            gen(self.batch_size) if self.batch_size is not None else None,
         ]

     @property
@@ -781,6 +783,7 @@ class SCDType2ByColumnKind(_SCDType2Kind):
     name: t.Literal[ModelKindName.SCD_TYPE_2_BY_COLUMN] = ModelKindName.SCD_TYPE_2_BY_COLUMN
     columns: SQLGlotListOfColumnsOrStar
     execution_time_as_valid_from: SQLGlotBool = False
+    updated_at_name: t.Optional[SQLGlotColumn] = None

     @property
     def data_hash_values(self) -> t.List[t.Optional[str]]:
@@ -789,7 +792,12 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
             if isinstance(self.columns, list)
             else [gen(self.columns)]
         )
-        return [*super().data_hash_values, *columns_sql, str(self.execution_time_as_valid_from)]
+        return [
+            *super().data_hash_values,
+            *columns_sql,
+            str(self.execution_time_as_valid_from),
+            gen(self.updated_at_name) if self.updated_at_name is not None else None,
+        ]

     def to_expression(
         self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
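Appending the new fields to `data_hash_values` changes the data hash of every model that sets them, which is why this commit also ships a state migration. A minimal sketch of the idea — illustrative only, not SQLMesh's actual hashing scheme:

```python
import hashlib

def data_hash(values):
    """Hash an ordered list of optional strings; None is encoded
    distinctly from any string so that appending a new field changes
    the hash exactly when the field is set."""
    h = hashlib.sha256()
    for v in values:
        # \x00 marks an unset field; set values are terminated by \x01
        h.update(b"\x00" if v is None else v.encode() + b"\x01")
    return h.hexdigest()
```

Any model whose hash changes is treated as a new snapshot, so the state store must be migrated to keep existing snapshots valid.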

sqlmesh/core/snapshot/evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -1761,7 +1761,7 @@ def insert(
             unique_key=model.unique_key,
             valid_from_col=model.kind.valid_from_name,
             valid_to_col=model.kind.valid_to_name,
-            execution_time=kwargs["execution_time"],
+            execution_time=model.kind.updated_at_name or kwargs["execution_time"],
             check_columns=model.kind.columns,
             invalidate_hard_deletes=model.kind.invalidate_hard_deletes,
             execution_time_as_valid_from=model.kind.execution_time_as_valid_from,
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+"""Add batch_size to SCD Type 2 models and add updated_at_name to the by column kind, which changes their data hash."""
+
+
+def migrate(state_sync, **kwargs):  # type: ignore
+    pass

tests/core/test_model.py

Lines changed: 2 additions & 0 deletions
@@ -4403,6 +4403,7 @@ def test_scd_type_2_by_column_overrides():
         forward_only False,
         disable_restatement False,
         invalidate_hard_deletes False,
+        batch_size 1
     ),
 );
 SELECT
@@ -4428,6 +4429,7 @@ def test_scd_type_2_by_column_overrides():
     assert scd_type_2_model.kind.is_scd_type_2
     assert scd_type_2_model.kind.is_materialized
     assert scd_type_2_model.kind.time_data_type == exp.DataType.build("TIMESTAMPTZ")
+    assert scd_type_2_model.kind.batch_size == 1
     assert not scd_type_2_model.kind.invalidate_hard_deletes
     assert not scd_type_2_model.kind.forward_only
     assert not scd_type_2_model.kind.disable_restatement

tests/core/test_scheduler.py

Lines changed: 67 additions & 0 deletions
@@ -13,6 +13,7 @@
     IncrementalByTimeRangeKind,
     IncrementalByUniqueKeyKind,
     TimeColumn,
+    SCDType2ByColumnKind,
 )
 from sqlmesh.core.node import IntervalUnit
 from sqlmesh.core.scheduler import (
@@ -810,3 +811,69 @@ def signal_base(batch: DatetimeRanges):
         snapshot_b: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))],
         snapshot_c: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))],
     }
+
+
+@pytest.mark.parametrize(
+    "batch_size, expected_batches",
+    [
+        (
+            1,
+            [
+                (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
+                (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")),
+                (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")),
+            ],
+        ),
+        (
+            None,
+            [
+                (to_timestamp("2023-01-01"), to_timestamp("2023-01-04")),
+            ],
+        ),
+    ],
+)
+def test_scd_type_2_batch_size(
+    mocker: MockerFixture,
+    make_snapshot,
+    get_batched_missing_intervals,
+    batch_size: t.Optional[int],
+    expected_batches: t.List[t.Tuple[int, int]],
+):
+    """
+    Test that SCD_TYPE_2_BY_COLUMN models are batched correctly based on batch_size.
+    With batch_size=1, we expect 3 separate batches for 3 days.
+    Without a specified batch_size, we expect a single batch for the entire period.
+    """
+    start = to_datetime("2023-01-01")
+    end = to_datetime("2023-01-04")
+
+    # Configure kind params
+    kind_params = {}
+    if batch_size is not None:
+        kind_params["batch_size"] = batch_size
+
+    # Create the model and snapshot
+    model = SqlModel(
+        name="test_scd_model",
+        kind=SCDType2ByColumnKind(columns="valid_to", unique_key=["id"], **kind_params),
+        cron="@daily",
+        start=start,
+        query=parse_one("SELECT id, valid_from, valid_to FROM source"),
+    )
+    snapshot = make_snapshot(model)
+
+    # Setup scheduler
+    snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1)
+    scheduler = Scheduler(
+        snapshots=[snapshot],
+        snapshot_evaluator=snapshot_evaluator,
+        state_sync=mocker.MagicMock(),
+        max_workers=2,
+        default_catalog=None,
+    )
+
+    # Get batches for the time period
+    batches = get_batched_missing_intervals(scheduler, start, end, end)[snapshot]
+
+    # Verify batches match expectations
+    assert batches == expected_batches
