Skip to content

Commit c97cd15

Browse files
committed
feat: dbt microbatch ref filter support
1 parent 58e3eca commit c97cd15

File tree

6 files changed

+180
-1
lines changed

6 files changed

+180
-1
lines changed

sqlmesh/core/renderer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,13 @@ def _resolve_table(table: str | exp.Table) -> str:
214214
dialect=self._dialect, identify=True, comments=False
215215
)
216216

217+
all_refs = list(self._jinja_macro_registry.global_objs.get("sources", {}).values()) + list( # type: ignore
218+
self._jinja_macro_registry.global_objs.get("refs", {}).values() # type: ignore
219+
)
220+
for ref in all_refs:
221+
if ref.event_time_filter:
222+
ref.event_time_filter["start"] = render_kwargs["start_tstz"]
223+
ref.event_time_filter["end"] = render_kwargs["end_tstz"]
217224
jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)
218225

219226
expressions = []

sqlmesh/dbt/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,25 @@
22
create_builtin_filters as create_builtin_filters,
33
create_builtin_globals as create_builtin_globals,
44
)
5+
from sqlmesh.dbt.util import DBT_VERSION


if DBT_VERSION >= (1, 9, 0):
    from dbt.adapters.base.relation import BaseRelation, EventTimeFilter

    def _render_event_time_filtered_inclusive(
        self: BaseRelation, event_time_filter: EventTimeFilter
    ) -> str:
        """Render an event-time predicate with inclusive bounds.

        Replacement for ``BaseRelation._render_event_time_filtered`` used when
        dbt >= 1.9 microbatch models are loaded, so that the rendered filter
        covers the full [start, end] interval.

        Args:
            self: The relation being rendered (unused beyond dbt's call convention).
            event_time_filter: Carries ``field_name`` and optional ``start``/``end``.

        Returns:
            A SQL predicate string, or ``""`` when both ``start`` and ``end``
            are unset (falsy).
        """
        # `predicate` rather than `filter` — avoid shadowing the builtin.
        predicate = ""
        if event_time_filter.start and event_time_filter.end:
            predicate = f"{event_time_filter.field_name} BETWEEN '{event_time_filter.start}' and '{event_time_filter.end}'"
        elif event_time_filter.start:
            predicate = f"{event_time_filter.field_name} >= '{event_time_filter.start}'"
        elif event_time_filter.end:
            predicate = f"{event_time_filter.field_name} <= '{event_time_filter.end}'"
        return predicate

    # Monkeypatch dbt so all relations render inclusive event-time filters.
    BaseRelation._render_event_time_filtered = _render_event_time_filtered_inclusive  # type: ignore

sqlmesh/dbt/basemodel.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from sqlmesh.dbt.relation import Policy, RelationType
3030
from sqlmesh.dbt.test import TestConfig
31+
from sqlmesh.dbt.util import DBT_VERSION
3132
from sqlmesh.utils import AttributeDict
3233
from sqlmesh.utils.errors import ConfigError
3334
from sqlmesh.utils.pydantic import field_validator
@@ -130,6 +131,7 @@ class BaseModelConfig(GeneralConfig):
130131
grants: t.Dict[str, t.List[str]] = {}
131132
columns: t.Dict[str, ColumnConfig] = {}
132133
quoting: t.Dict[str, t.Optional[bool]] = {}
134+
event_time: t.Optional[str] = None
133135

134136
version: t.Optional[int] = None
135137
latest_version: t.Optional[int] = None
@@ -222,13 +224,20 @@ def relation_info(self) -> AttributeDict[str, t.Any]:
222224
else:
223225
relation_type = RelationType.Table
224226

227+
extras = {}
228+
if DBT_VERSION >= (1, 9, 0) and self.event_time:
229+
extras["event_time_filter"] = {
230+
"field_name": self.event_time,
231+
}
232+
225233
return AttributeDict(
226234
{
227235
"database": self.database,
228236
"schema": self.table_schema,
229237
"identifier": self.table_name,
230238
"type": relation_type.value,
231239
"quote_policy": AttributeDict(self.quoting),
240+
**extras,
232241
}
233242
)
234243

sqlmesh/dbt/source.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlmesh.dbt.column import ColumnConfig
99
from sqlmesh.dbt.common import GeneralConfig
1010
from sqlmesh.dbt.relation import RelationType
11+
from sqlmesh.dbt.util import DBT_VERSION
1112
from sqlmesh.utils import AttributeDict
1213
from sqlmesh.utils.errors import ConfigError
1314

@@ -46,6 +47,7 @@ class SourceConfig(GeneralConfig):
4647
external: t.Optional[t.Dict[str, t.Any]] = {}
4748
source_meta: t.Optional[t.Dict[str, t.Any]] = {}
4849
columns: t.Dict[str, ColumnConfig] = {}
50+
event_time: t.Optional[str] = None
4951

5052
_canonical_name: t.Optional[str] = None
5153

@@ -94,6 +96,11 @@ def relation_info(self) -> AttributeDict:
9496
if external_location:
9597
extras["external"] = external_location.replace("{name}", self.table_name)
9698

99+
if DBT_VERSION >= (1, 9, 0) and self.event_time:
100+
extras["event_time_filter"] = {
101+
"field_name": self.event_time,
102+
}
103+
97104
return AttributeDict(
98105
{
99106
"database": self.database,

tests/core/test_snapshot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,9 @@ def test_fingerprint_jinja_macros_global_objs(model: Model, global_obj_key: str)
10781078
)
10791079
fingerprint = fingerprint_from_node(model, nodes={})
10801080
model = model.copy()
1081-
model.jinja_macros.global_objs[global_obj_key] = AttributeDict({"test": "test"})
1081+
model.jinja_macros.global_objs[global_obj_key] = AttributeDict(
1082+
{"test": AttributeDict({"test": "test"})}
1083+
)
10821084
updated_fingerprint = fingerprint_from_node(model, nodes={})
10831085
assert updated_fingerprint.data_hash != fingerprint.data_hash
10841086
assert updated_fingerprint.metadata_hash == fingerprint.metadata_hash

tests/dbt/test_model.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,135 @@ def test_load_microbatch_required_only(
301301
)
302302
assert model.kind.batch_size == 1
303303
assert model.depends_on_self is False
304+
305+
306+
@pytest.mark.slow
def test_load_microbatch_with_ref(
    tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project
) -> None:
    """Microbatch models filter `source()` / `ref()` relations by event time.

    Builds a project with a source and two chained microbatch models, each
    declaring `event_time='ds'`, then checks that rendering with an explicit
    start/end injects an inclusive time-range predicate on the referenced
    relation in both the source->model and model->model cases.
    """
    yaml = YAML()
    project_dir, model_dir = create_empty_project()

    # Source table declares `event_time` so downstream microbatch refs filter on it.
    source_schema = {
        "version": 2,
        "sources": [
            {
                "name": "my_source",
                "tables": [{"name": "my_table", "config": {"event_time": "ds"}}],
            }
        ],
    }
    source_schema_file = model_dir / "source_schema.yml"
    with open(source_schema_file, "w", encoding="utf-8") as f:
        yaml.dump(source_schema, f)

    microbatch_contents = """
{{
    config(
        materialized='incremental',
        incremental_strategy='microbatch',
        event_time='ds',
        begin='2020-01-01',
        batch_size='day'
    )
}}

SELECT cola, ds FROM {{ source('my_source', 'my_table') }}
"""
    microbatch_model_file = model_dir / "microbatch.sql"
    with open(microbatch_model_file, "w", encoding="utf-8") as f:
        f.write(microbatch_contents)

    microbatch_two_contents = """
{{
    config(
        materialized='incremental',
        incremental_strategy='microbatch',
        event_time='ds',
        begin='2020-01-05',
        batch_size='day'
    )
}}

SELECT cola, ds FROM {{ ref('microbatch') }}
"""
    microbatch_two_model_file = model_dir / "microbatch_two.sql"
    with open(microbatch_two_model_file, "w", encoding="utf-8") as f:
        f.write(microbatch_two_contents)

    microbatch_snapshot_fqn = '"local"."main"."microbatch"'
    microbatch_two_snapshot_fqn = '"local"."main"."microbatch_two"'
    context = Context(paths=project_dir)
    # Source reference: filtered with an inclusive BETWEEN over the run interval.
    assert (
        context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
        == 'SELECT "cola" AS "cola", "ds" AS "ds" FROM (SELECT * FROM "local"."my_source"."my_table" AS "my_table" WHERE "ds" BETWEEN \'2025-01-01 00:00:00+00:00\' AND \'2025-01-10 23:59:59.999999+00:00\') AS "_q_0"'
    )
    # Model-to-model ref: same interval expressed as >= start AND <= end.
    assert (
        context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
        == 'SELECT "_q_0"."cola" AS "cola", "_q_0"."ds" AS "ds" FROM (SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch" WHERE "microbatch"."ds" <= \'2025-01-10 23:59:59.999999+00:00\' AND "microbatch"."ds" >= \'2025-01-01 00:00:00+00:00\') AS "_q_0"'
    )
370+
371+
372+
@pytest.mark.slow
def test_load_microbatch_with_ref_no_filter(
    tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project
) -> None:
    """Calling `.render()` on a `source()` / `ref()` skips event-time filtering.

    Same project layout as `test_load_microbatch_with_ref`, but the models use
    `source(...).render()` / `ref(...).render()` — rendering the relation
    explicitly should yield the bare table reference with no time predicate.
    """
    yaml = YAML()
    project_dir, model_dir = create_empty_project()

    # Source still declares `event_time`; `.render()` should bypass the filter anyway.
    source_schema = {
        "version": 2,
        "sources": [
            {
                "name": "my_source",
                "tables": [{"name": "my_table", "config": {"event_time": "ds"}}],
            }
        ],
    }
    source_schema_file = model_dir / "source_schema.yml"
    with open(source_schema_file, "w", encoding="utf-8") as f:
        yaml.dump(source_schema, f)

    microbatch_contents = """
{{
    config(
        materialized='incremental',
        incremental_strategy='microbatch',
        event_time='ds',
        begin='2020-01-01',
        batch_size='day'
    )
}}

SELECT cola, ds FROM {{ source('my_source', 'my_table').render() }}
"""
    microbatch_model_file = model_dir / "microbatch.sql"
    with open(microbatch_model_file, "w", encoding="utf-8") as f:
        f.write(microbatch_contents)

    microbatch_two_contents = """
{{
    config(
        materialized='incremental',
        incremental_strategy='microbatch',
        event_time='ds',
        begin='2020-01-01',
        batch_size='day'
    )
}}

SELECT cola, ds FROM {{ ref('microbatch').render() }}
"""
    microbatch_two_model_file = model_dir / "microbatch_two.sql"
    with open(microbatch_two_model_file, "w", encoding="utf-8") as f:
        f.write(microbatch_two_contents)

    microbatch_snapshot_fqn = '"local"."main"."microbatch"'
    microbatch_two_snapshot_fqn = '"local"."main"."microbatch_two"'
    context = Context(paths=project_dir)
    # No WHERE clause expected in either rendered query.
    assert (
        context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
        == 'SELECT "cola" AS "cola", "ds" AS "ds" FROM "local"."my_source"."my_table" AS "my_table"'
    )
    assert (
        context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
        == 'SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch"'
    )

0 commit comments

Comments
 (0)