Skip to content

Commit 8f6e2aa

Browse files
committed
feat: dbt microbatch ref filter support
1 parent a6468f5 commit 8f6e2aa

File tree

6 files changed

+180
-1
lines changed

6 files changed

+180
-1
lines changed

sqlmesh/core/renderer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,13 @@ def _resolve_table(table: str | exp.Table) -> str:
214214
dialect=self._dialect, identify=True, comments=False
215215
)
216216

217+
all_refs = list(self._jinja_macro_registry.global_objs.get("sources", {}).values()) + list( # type: ignore
218+
self._jinja_macro_registry.global_objs.get("refs", {}).values() # type: ignore
219+
)
220+
for ref in all_refs:
221+
if ref.event_time_filter:
222+
ref.event_time_filter["start"] = render_kwargs["start_tstz"]
223+
ref.event_time_filter["end"] = render_kwargs["end_tstz"]
217224
jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)
218225

219226
expressions = []

sqlmesh/dbt/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,25 @@
22
create_builtin_filters as create_builtin_filters,
33
create_builtin_globals as create_builtin_globals,
44
)
5+
from sqlmesh.dbt.util import DBT_VERSION
6+
7+
8+
if DBT_VERSION >= (1, 9, 0):
9+
from dbt.adapters.base.relation import BaseRelation, EventTimeFilter
10+
11+
def _render_event_time_filtered_inclusive(
12+
self: BaseRelation, event_time_filter: EventTimeFilter
13+
) -> str:
14+
"""
15+
Returns "" if start and end are both None
16+
"""
17+
filter = ""
18+
if event_time_filter.start and event_time_filter.end:
19+
filter = f"{event_time_filter.field_name} BETWEEN '{event_time_filter.start}' and '{event_time_filter.end}'"
20+
elif event_time_filter.start:
21+
filter = f"{event_time_filter.field_name} >= '{event_time_filter.start}'"
22+
elif event_time_filter.end:
23+
filter = f"{event_time_filter.field_name} <= '{event_time_filter.end}'"
24+
return filter
25+
26+
BaseRelation._render_event_time_filtered = _render_event_time_filtered_inclusive # type: ignore

sqlmesh/dbt/basemodel.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from sqlmesh.dbt.relation import Policy, RelationType
3030
from sqlmesh.dbt.test import TestConfig
31+
from sqlmesh.dbt.util import DBT_VERSION
3132
from sqlmesh.utils import AttributeDict
3233
from sqlmesh.utils.errors import ConfigError
3334
from sqlmesh.utils.pydantic import field_validator
@@ -130,6 +131,7 @@ class BaseModelConfig(GeneralConfig):
130131
grants: t.Dict[str, t.List[str]] = {}
131132
columns: t.Dict[str, ColumnConfig] = {}
132133
quoting: t.Dict[str, t.Optional[bool]] = {}
134+
event_time: t.Optional[str] = None
133135

134136
version: t.Optional[int] = None
135137
latest_version: t.Optional[int] = None
@@ -222,13 +224,20 @@ def relation_info(self) -> AttributeDict[str, t.Any]:
222224
else:
223225
relation_type = RelationType.Table
224226

227+
extras = {}
228+
if DBT_VERSION >= (1, 9, 0) and self.event_time:
229+
extras["event_time_filter"] = {
230+
"field_name": self.event_time,
231+
}
232+
225233
return AttributeDict(
226234
{
227235
"database": self.database,
228236
"schema": self.table_schema,
229237
"identifier": self.table_name,
230238
"type": relation_type.value,
231239
"quote_policy": AttributeDict(self.quoting),
240+
**extras,
232241
}
233242
)
234243

sqlmesh/dbt/source.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlmesh.dbt.column import ColumnConfig
99
from sqlmesh.dbt.common import GeneralConfig
1010
from sqlmesh.dbt.relation import RelationType
11+
from sqlmesh.dbt.util import DBT_VERSION
1112
from sqlmesh.utils import AttributeDict
1213
from sqlmesh.utils.errors import ConfigError
1314

@@ -46,6 +47,7 @@ class SourceConfig(GeneralConfig):
4647
external: t.Optional[t.Dict[str, t.Any]] = {}
4748
source_meta: t.Optional[t.Dict[str, t.Any]] = {}
4849
columns: t.Dict[str, ColumnConfig] = {}
50+
event_time: t.Optional[str] = None
4951

5052
_canonical_name: t.Optional[str] = None
5153

@@ -94,6 +96,11 @@ def relation_info(self) -> AttributeDict:
9496
if external_location:
9597
extras["external"] = external_location.replace("{name}", self.table_name)
9698

99+
if DBT_VERSION >= (1, 9, 0) and self.event_time:
100+
extras["event_time_filter"] = {
101+
"field_name": self.event_time,
102+
}
103+
97104
return AttributeDict(
98105
{
99106
"database": self.database,

tests/core/test_snapshot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,9 @@ def test_fingerprint_jinja_macros_global_objs(model: Model, global_obj_key: str)
10791079
)
10801080
fingerprint = fingerprint_from_node(model, nodes={})
10811081
model = model.copy()
1082-
model.jinja_macros.global_objs[global_obj_key] = AttributeDict({"test": "test"})
1082+
model.jinja_macros.global_objs[global_obj_key] = AttributeDict(
1083+
{"test": AttributeDict({"test": "test"})}
1084+
)
10831085
updated_fingerprint = fingerprint_from_node(model, nodes={})
10841086
assert updated_fingerprint.data_hash != fingerprint.data_hash
10851087
assert updated_fingerprint.metadata_hash == fingerprint.metadata_hash

tests/dbt/test_model.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,3 +447,135 @@ def test_load_deprecated_incremental_time_column(
447447
"Using `time_column` on a model with incremental_strategy 'delete+insert' has been deprecated. Please use `incremental_by_time_range` instead in model 'main.incremental_time_range'."
448448
in caplog.text
449449
)
450+
451+
452+
@pytest.mark.slow
453+
def test_load_microbatch_with_ref(
454+
tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project
455+
) -> None:
456+
yaml = YAML()
457+
project_dir, model_dir = create_empty_project()
458+
source_schema = {
459+
"version": 2,
460+
"sources": [
461+
{
462+
"name": "my_source",
463+
"tables": [{"name": "my_table", "config": {"event_time": "ds"}}],
464+
}
465+
],
466+
}
467+
source_schema_file = model_dir / "source_schema.yml"
468+
with open(source_schema_file, "w", encoding="utf-8") as f:
469+
yaml.dump(source_schema, f)
470+
# NOTE(review): this comment looks copy-pasted from another test — no `tests` config is added below; the model defines a microbatch incremental config referencing the event_time source
471+
microbatch_contents = """
472+
{{
473+
config(
474+
materialized='incremental',
475+
incremental_strategy='microbatch',
476+
event_time='ds',
477+
begin='2020-01-01',
478+
batch_size='day'
479+
)
480+
}}
481+
482+
SELECT cola, ds FROM {{ source('my_source', 'my_table') }}
483+
"""
484+
microbatch_model_file = model_dir / "microbatch.sql"
485+
with open(microbatch_model_file, "w", encoding="utf-8") as f:
486+
f.write(microbatch_contents)
487+
488+
microbatch_two_contents = """
489+
{{
490+
config(
491+
materialized='incremental',
492+
incremental_strategy='microbatch',
493+
event_time='ds',
494+
begin='2020-01-05',
495+
batch_size='day'
496+
)
497+
}}
498+
499+
SELECT cola, ds FROM {{ ref('microbatch') }}
500+
"""
501+
microbatch_two_model_file = model_dir / "microbatch_two.sql"
502+
with open(microbatch_two_model_file, "w", encoding="utf-8") as f:
503+
f.write(microbatch_two_contents)
504+
505+
microbatch_snapshot_fqn = '"local"."main"."microbatch"'
506+
microbatch_two_snapshot_fqn = '"local"."main"."microbatch_two"'
507+
context = Context(paths=project_dir)
508+
assert (
509+
context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
510+
== 'SELECT "cola" AS "cola", "ds" AS "ds" FROM (SELECT * FROM "local"."my_source"."my_table" AS "my_table" WHERE "ds" BETWEEN \'2025-01-01 00:00:00+00:00\' AND \'2025-01-10 23:59:59.999999+00:00\') AS "_q_0"'
511+
)
512+
assert (
513+
context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
514+
== 'SELECT "_q_0"."cola" AS "cola", "_q_0"."ds" AS "ds" FROM (SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch" WHERE "microbatch"."ds" <= \'2025-01-10 23:59:59.999999+00:00\' AND "microbatch"."ds" >= \'2025-01-01 00:00:00+00:00\') AS "_q_0"'
515+
)
516+
517+
518+
@pytest.mark.slow
519+
def test_load_microbatch_with_ref_no_filter(
520+
tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project
521+
) -> None:
522+
yaml = YAML()
523+
project_dir, model_dir = create_empty_project()
524+
source_schema = {
525+
"version": 2,
526+
"sources": [
527+
{
528+
"name": "my_source",
529+
"tables": [{"name": "my_table", "config": {"event_time": "ds"}}],
530+
}
531+
],
532+
}
533+
source_schema_file = model_dir / "source_schema.yml"
534+
with open(source_schema_file, "w", encoding="utf-8") as f:
535+
yaml.dump(source_schema, f)
536+
# NOTE(review): stale copy-pasted comment — no `tests` config is added below; this variant calls .render() on source()/ref() so no event-time filter should be injected
537+
microbatch_contents = """
538+
{{
539+
config(
540+
materialized='incremental',
541+
incremental_strategy='microbatch',
542+
event_time='ds',
543+
begin='2020-01-01',
544+
batch_size='day'
545+
)
546+
}}
547+
548+
SELECT cola, ds FROM {{ source('my_source', 'my_table').render() }}
549+
"""
550+
microbatch_model_file = model_dir / "microbatch.sql"
551+
with open(microbatch_model_file, "w", encoding="utf-8") as f:
552+
f.write(microbatch_contents)
553+
554+
microbatch_two_contents = """
555+
{{
556+
config(
557+
materialized='incremental',
558+
incremental_strategy='microbatch',
559+
event_time='ds',
560+
begin='2020-01-01',
561+
batch_size='day'
562+
)
563+
}}
564+
565+
SELECT cola, ds FROM {{ ref('microbatch').render() }}
566+
"""
567+
microbatch_two_model_file = model_dir / "microbatch_two.sql"
568+
with open(microbatch_two_model_file, "w", encoding="utf-8") as f:
569+
f.write(microbatch_two_contents)
570+
571+
microbatch_snapshot_fqn = '"local"."main"."microbatch"'
572+
microbatch_two_snapshot_fqn = '"local"."main"."microbatch_two"'
573+
context = Context(paths=project_dir)
574+
assert (
575+
context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
576+
== 'SELECT "cola" AS "cola", "ds" AS "ds" FROM "local"."my_source"."my_table" AS "my_table"'
577+
)
578+
assert (
579+
context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql()
580+
== 'SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch"'
581+
)

0 commit comments

Comments
 (0)