Skip to content

Commit 70c8deb

Browse files
authored
chore: add dbt microbatch interface (#5272)
1 parent 37523dc commit 70c8deb

File tree

3 files changed

+278
-69
lines changed

3 files changed

+278
-69
lines changed

sqlmesh/dbt/common.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ def _validate_meta(cls, v: t.Dict[str, t.Union[str, t.Any]]) -> t.Dict[str, t.An
132132
def config_attribute_dict(self) -> AttributeDict[str, t.Any]:
133133
return AttributeDict(self.dict(exclude=EXCLUDED_CONFIG_ATTRIBUTE_KEYS))
134134

135+
def _get_field_value(self, field: str) -> t.Optional[t.Any]:
136+
field_val = getattr(self, field, None)
137+
return field_val if field_val is not None else self.meta.get(field, None)
138+
135139
def replace(self, other: T) -> None:
136140
"""
137141
Replace the contents of this instance with the passed in instance.
@@ -152,9 +156,7 @@ def sqlmesh_config_kwargs(self) -> t.Dict[str, t.Any]:
152156
"""
153157
kwargs = {}
154158
for field in self.sqlmesh_config_fields:
155-
field_val = getattr(self, field, None)
156-
if field_val is None:
157-
field_val = self.meta.get(field, None)
159+
field_val = self._get_field_value(field)
158160
if field_val is not None:
159161
kwargs[field] = field_val
160162
return kwargs

sqlmesh/dbt/model.py

Lines changed: 102 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import datetime
34
import typing as t
45

56
from sqlglot import exp
@@ -34,7 +35,7 @@
3435
from sqlmesh.dbt.context import DbtContext
3536

3637

37-
INCREMENTAL_BY_TIME_STRATEGIES = set(["delete+insert", "insert_overwrite"])
38+
INCREMENTAL_BY_TIME_STRATEGIES = set(["delete+insert", "insert_overwrite", "microbatch"])
3839
INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES = set(["merge"])
3940

4041

@@ -73,9 +74,7 @@ class ModelConfig(BaseModelConfig):
7374
time_column: t.Optional[str] = None
7475
cron: t.Optional[str] = None
7576
interval_unit: t.Optional[str] = None
76-
batch_size: t.Optional[int] = None
7777
batch_concurrency: t.Optional[int] = None
78-
lookback: t.Optional[int] = None
7978
forward_only: bool = True
8079
disable_restatement: t.Optional[bool] = None
8180
allow_partials: t.Optional[bool] = None
@@ -100,6 +99,15 @@ class ModelConfig(BaseModelConfig):
10099
target_schema: t.Optional[str] = None
101100
check_cols: t.Optional[t.Union[t.List[str], str]] = None
102101

102+
# Microbatch Fields
103+
event_time: t.Optional[str] = None
104+
begin: t.Optional[datetime.datetime] = None
105+
concurrent_batches: t.Optional[bool] = None
106+
107+
# Shared SQLMesh and DBT configuration fields
108+
batch_size: t.Optional[t.Union[int, str]] = None
109+
lookback: t.Optional[int] = None
110+
103111
# redshift
104112
bind: t.Optional[bool] = None
105113

@@ -220,6 +228,17 @@ def snapshot_strategy(self) -> t.Optional[SnapshotStrategy]:
220228
def table_schema(self) -> str:
221229
return self.target_schema or super().table_schema
222230

231+
def _get_overlapping_field_value(
232+
self, context: DbtContext, dbt_field_name: str, sqlmesh_field_name: str
233+
) -> t.Optional[t.Any]:
234+
dbt_field = self._get_field_value(dbt_field_name)
235+
sqlmesh_field = getattr(self, sqlmesh_field_name, None)
236+
if dbt_field is not None and sqlmesh_field is not None:
237+
get_console().log_warning(
238+
f"Both '{dbt_field_name}' and '{sqlmesh_field_name}' are set for model '{self.canonical_name(context)}'. '{sqlmesh_field_name}' will be used."
239+
)
240+
return sqlmesh_field if sqlmesh_field is not None else dbt_field
241+
223242
def model_kind(self, context: DbtContext) -> ModelKind:
224243
"""
225244
Get the sqlmesh ModelKind
@@ -256,27 +275,44 @@ def model_kind(self, context: DbtContext) -> ModelKind:
256275

257276
incremental_kind_kwargs["on_destructive_change"] = on_destructive_change
258277
incremental_kind_kwargs["on_additive_change"] = on_additive_change
259-
for field in ("forward_only", "auto_restatement_cron"):
260-
field_val = getattr(self, field, None)
261-
if field_val is None:
262-
field_val = self.meta.get(field, None)
263-
if field_val is not None:
264-
incremental_kind_kwargs[field] = field_val
278+
auto_restatement_cron_value = self._get_field_value("auto_restatement_cron")
279+
if auto_restatement_cron_value is not None:
280+
incremental_kind_kwargs["auto_restatement_cron"] = auto_restatement_cron_value
265281

266282
if materialization == Materialization.TABLE:
267283
return FullKind()
268284
if materialization == Materialization.VIEW:
269285
return ViewKind()
270286
if materialization == Materialization.INCREMENTAL:
271287
incremental_by_kind_kwargs: t.Dict[str, t.Any] = {"dialect": self.dialect(context)}
288+
forward_only_value = self._get_field_value("forward_only")
289+
if forward_only_value is not None:
290+
incremental_kind_kwargs["forward_only"] = forward_only_value
291+
292+
is_incremental_by_time_range = self.time_column or (
293+
self.incremental_strategy and self.incremental_strategy == "microbatch"
294+
)
295+
# Collect the kwargs shared by the incremental-by model kinds
272296
for field in ("batch_size", "batch_concurrency", "lookback"):
273-
field_val = getattr(self, field, None)
274-
if field_val is None:
275-
field_val = self.meta.get(field, None)
297+
field_val = self._get_field_value(field)
276298
if field_val is not None:
299+
# A string `batch_size` represents an interval unit; skip it here because it is handled at the model level
300+
if field == "batch_size" and isinstance(field_val, str):
301+
continue
277302
incremental_by_kind_kwargs[field] = field_val
278303

279-
if self.time_column:
304+
disable_restatement = self.disable_restatement
305+
if disable_restatement is None:
306+
if is_incremental_by_time_range:
307+
disable_restatement = False
308+
else:
309+
disable_restatement = (
310+
not self.full_refresh if self.full_refresh is not None else False
311+
)
312+
incremental_by_kind_kwargs["disable_restatement"] = disable_restatement
313+
314+
# Incremental by time range, which includes microbatch
315+
if is_incremental_by_time_range:
280316
strategy = self.incremental_strategy or target.default_incremental_strategy(
281317
IncrementalByTimeRangeKind
282318
)
@@ -287,22 +323,37 @@ def model_kind(self, context: DbtContext) -> ModelKind:
287323
f"Supported strategies include {collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES)}."
288324
)
289325

326+
if strategy == "microbatch":
327+
time_column = self._get_overlapping_field_value(
328+
context, "event_time", "time_column"
329+
)
330+
if not time_column:
331+
raise ConfigError(
332+
f"{self.canonical_name(context)}: 'event_time' is required for microbatch incremental strategy."
333+
)
334+
concurrent_batches = self._get_field_value("concurrent_batches")
335+
if concurrent_batches is True:
336+
if incremental_by_kind_kwargs.get("batch_size"):
337+
get_console().log_warning(
338+
f"'concurrent_batches' is set to True and 'batch_size' are defined in '{self.canonical_name(context)}'. The batch size will be set to the value of `batch_size`."
339+
)
340+
incremental_by_kind_kwargs["batch_size"] = incremental_by_kind_kwargs.get(
341+
"batch_size", 1
342+
)
343+
else:
344+
if not self.time_column:
345+
raise ConfigError(
346+
f"{self.canonical_name(context)}: 'time_column' is required for incremental by time range models not defined using microbatch."
347+
)
348+
time_column = self.time_column
349+
290350
return IncrementalByTimeRangeKind(
291-
time_column=self.time_column,
292-
disable_restatement=(
293-
self.disable_restatement if self.disable_restatement is not None else False
294-
),
351+
time_column=time_column,
295352
auto_restatement_intervals=self.auto_restatement_intervals,
296353
**incremental_kind_kwargs,
297354
**incremental_by_kind_kwargs,
298355
)
299356

300-
disable_restatement = self.disable_restatement
301-
if disable_restatement is None:
302-
disable_restatement = (
303-
not self.full_refresh if self.full_refresh is not None else False
304-
)
305-
306357
if self.unique_key:
307358
strategy = self.incremental_strategy or target.default_incremental_strategy(
308359
IncrementalByUniqueKeyKind
@@ -315,11 +366,11 @@ def model_kind(self, context: DbtContext) -> ModelKind:
315366
f"Unique key is not compatible with '{strategy}' incremental strategy in model '{self.canonical_name(context)}'. "
316367
f"Supported strategies include {collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES)}. Falling back to 'merge' strategy."
317368
)
318-
strategy = "merge"
319369

370+
merge_filter = None
320371
if self.incremental_predicates:
321372
dialect = self.dialect(context)
322-
incremental_kind_kwargs["merge_filter"] = exp.and_(
373+
merge_filter = exp.and_(
323374
*[
324375
d.parse_one(predicate, dialect=dialect)
325376
for predicate in self.incremental_predicates
@@ -329,7 +380,7 @@ def model_kind(self, context: DbtContext) -> ModelKind:
329380

330381
return IncrementalByUniqueKeyKind(
331382
unique_key=self.unique_key,
332-
disable_restatement=disable_restatement,
383+
merge_filter=merge_filter,
333384
**incremental_kind_kwargs,
334385
**incremental_by_kind_kwargs,
335386
)
@@ -339,7 +390,7 @@ def model_kind(self, context: DbtContext) -> ModelKind:
339390
)
340391
return IncrementalUnmanagedKind(
341392
insert_overwrite=strategy in INCREMENTAL_BY_TIME_STRATEGIES,
342-
disable_restatement=disable_restatement,
393+
disable_restatement=incremental_by_kind_kwargs["disable_restatement"],
343394
**incremental_kind_kwargs,
344395
)
345396
if materialization == Materialization.EPHEMERAL:
@@ -438,6 +489,9 @@ def sqlmesh_config_fields(self) -> t.Set[str]:
438489
"interval_unit",
439490
"allow_partials",
440491
"physical_version",
492+
"start",
493+
# In microbatch models `begin` is the same as `start`
494+
"begin",
441495
}
442496

443497
def to_sqlmesh(
@@ -583,12 +637,32 @@ def to_sqlmesh(
583637
# Set allow_partials to True for dbt models to preserve the original semantics.
584638
allow_partials = True
585639

640+
if kind.is_incremental:
641+
if self.batch_size and isinstance(self.batch_size, str):
642+
if "interval_unit" in model_kwargs:
643+
get_console().log_warning(
644+
f"Both 'interval_unit' and 'batch_size' are set for model '{self.canonical_name(context)}'. 'interval_unit' will be used."
645+
)
646+
else:
647+
model_kwargs["interval_unit"] = self.batch_size
648+
self.batch_size = None
649+
if begin := model_kwargs.pop("begin", None):
650+
if "start" in model_kwargs:
651+
get_console().log_warning(
652+
f"Both 'begin' and 'start' are set for model '{self.canonical_name(context)}'. 'start' will be used."
653+
)
654+
else:
655+
model_kwargs["start"] = begin
656+
657+
model_kwargs["start"] = model_kwargs.get(
658+
"start", context.sqlmesh_config.model_defaults.start
659+
)
660+
586661
model = create_sql_model(
587662
self.canonical_name(context),
588663
query,
589664
dialect=model_dialect,
590665
kind=kind,
591-
start=self.start or context.sqlmesh_config.model_defaults.start,
592666
audit_definitions=audit_definitions,
593667
# This ensures that we bypass query rendering that would otherwise be required to extract additional
594668
# dependencies from the model's SQL.

0 commit comments

Comments
 (0)