From 8f6e2aa76af30295cb6b8f1d0ffd14ba7965a13c Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Wed, 3 Sep 2025 13:49:37 -0700 Subject: [PATCH 1/2] feat: dbt microbatch ref filter support --- sqlmesh/core/renderer.py | 7 ++ sqlmesh/dbt/__init__.py | 22 ++++++ sqlmesh/dbt/basemodel.py | 9 +++ sqlmesh/dbt/source.py | 7 ++ tests/core/test_snapshot.py | 4 +- tests/dbt/test_model.py | 132 ++++++++++++++++++++++++++++++++++++ 6 files changed, 180 insertions(+), 1 deletion(-) diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index e1d7bf4bcf..5eea10577b 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -214,6 +214,13 @@ def _resolve_table(table: str | exp.Table) -> str: dialect=self._dialect, identify=True, comments=False ) + all_refs = list(self._jinja_macro_registry.global_objs.get("sources", {}).values()) + list( # type: ignore + self._jinja_macro_registry.global_objs.get("refs", {}).values() # type: ignore + ) + for ref in all_refs: + if ref.event_time_filter: + ref.event_time_filter["start"] = render_kwargs["start_tstz"] + ref.event_time_filter["end"] = render_kwargs["end_tstz"] jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs) expressions = [] diff --git a/sqlmesh/dbt/__init__.py b/sqlmesh/dbt/__init__.py index 690b1f5289..7c1c2ba0d4 100644 --- a/sqlmesh/dbt/__init__.py +++ b/sqlmesh/dbt/__init__.py @@ -2,3 +2,25 @@ create_builtin_filters as create_builtin_filters, create_builtin_globals as create_builtin_globals, ) +from sqlmesh.dbt.util import DBT_VERSION + + +if DBT_VERSION >= (1, 9, 0): + from dbt.adapters.base.relation import BaseRelation, EventTimeFilter + + def _render_event_time_filtered_inclusive( + self: BaseRelation, event_time_filter: EventTimeFilter + ) -> str: + """ + Returns "" if start and end are both None + """ + filter = "" + if event_time_filter.start and event_time_filter.end: + filter = f"{event_time_filter.field_name} BETWEEN '{event_time_filter.start}' and '{event_time_filter.end}'" + elif event_time_filter.start: + filter = f"{event_time_filter.field_name} >= '{event_time_filter.start}'" + elif event_time_filter.end: + filter = f"{event_time_filter.field_name} <= '{event_time_filter.end}'" + return filter + + BaseRelation._render_event_time_filtered = _render_event_time_filtered_inclusive # type: ignore diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 548718cf89..a68a6ed598 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -28,6 +28,7 @@ ) from sqlmesh.dbt.relation import Policy, RelationType from sqlmesh.dbt.test import TestConfig +from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils import AttributeDict from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator @@ -130,6 +131,7 @@ class BaseModelConfig(GeneralConfig): grants: t.Dict[str, t.List[str]] = {} columns: t.Dict[str, ColumnConfig] = {} quoting: t.Dict[str, t.Optional[bool]] = {} + event_time: t.Optional[str] = None version: t.Optional[int] = None latest_version: t.Optional[int] = None @@ -222,6 +224,12 @@ def relation_info(self) -> AttributeDict[str, t.Any]: else: relation_type = RelationType.Table + extras = {} + if DBT_VERSION >= (1, 9, 0) and self.event_time: + extras["event_time_filter"] = { + "field_name": self.event_time, + } + return AttributeDict( { "database": self.database, @@ -229,6 +237,7 @@ def relation_info(self) -> AttributeDict[str, t.Any]: "identifier": self.table_name, "type": relation_type.value, "quote_policy": AttributeDict(self.quoting), + **extras, } ) diff --git a/sqlmesh/dbt/source.py b/sqlmesh/dbt/source.py index 5b182c9a9a..76ee682e77 100644 --- a/sqlmesh/dbt/source.py +++ b/sqlmesh/dbt/source.py @@ -8,6 +8,7 @@ from sqlmesh.dbt.column import ColumnConfig from sqlmesh.dbt.common import GeneralConfig from sqlmesh.dbt.relation import RelationType +from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils import AttributeDict from sqlmesh.utils.errors import ConfigError @@ -46,6 +47,7 @@ class SourceConfig(GeneralConfig): external: t.Optional[t.Dict[str, t.Any]] = {} source_meta: t.Optional[t.Dict[str, t.Any]] = {} columns: t.Dict[str, ColumnConfig] = {} + event_time: t.Optional[str] = None _canonical_name: t.Optional[str] = None @@ -94,6 +96,11 @@ def relation_info(self) -> AttributeDict: if external_location: extras["external"] = external_location.replace("{name}", self.table_name) + if DBT_VERSION >= (1, 9, 0) and self.event_time: + extras["event_time_filter"] = { + "field_name": self.event_time, + } + return AttributeDict( { "database": self.database, diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index c37bd57d2e..eff3ad2b60 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -1079,7 +1079,9 @@ def test_fingerprint_jinja_macros_global_objs(model: Model, global_obj_key: str) ) fingerprint = fingerprint_from_node(model, nodes={}) model = model.copy() - model.jinja_macros.global_objs[global_obj_key] = AttributeDict({"test": "test"}) + model.jinja_macros.global_objs[global_obj_key] = AttributeDict( + {"test": AttributeDict({"test": "test"})} + ) updated_fingerprint = fingerprint_from_node(model, nodes={}) assert updated_fingerprint.data_hash != fingerprint.data_hash assert updated_fingerprint.metadata_hash == fingerprint.metadata_hash diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index 5037a69d65..dbfabd2d6d 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -447,3 +447,135 @@ def test_load_deprecated_incremental_time_column( "Using `time_column` on a model with incremental_strategy 'delete+insert' has been deprecated. Please use `incremental_by_time_range` instead in model 'main.incremental_time_range'." in caplog.text ) + + +@pytest.mark.slow +def test_load_microbatch_with_ref( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + yaml = YAML() + project_dir, model_dir = create_empty_project() + source_schema = { + "version": 2, + "sources": [ + { + "name": "my_source", + "tables": [{"name": "my_table", "config": {"event_time": "ds"}}], + } + ], + } + source_schema_file = model_dir / "source_schema.yml" + with open(source_schema_file, "w", encoding="utf-8") as f: + yaml.dump(source_schema, f) + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-01', + batch_size='day' + ) + }} + + SELECT cola, ds FROM {{ source('my_source', 'my_table') }} + """ + microbatch_model_file = model_dir / "microbatch.sql" + with open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + microbatch_two_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-05', + batch_size='day' + ) + }} + + SELECT cola, ds FROM {{ ref('microbatch') }} + """ + microbatch_two_model_file = model_dir / "microbatch_two.sql" + with open(microbatch_two_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_two_contents) + + microbatch_snapshot_fqn = '"local"."main"."microbatch"' + microbatch_two_snapshot_fqn = '"local"."main"."microbatch_two"' + context = Context(paths=project_dir) + assert ( + context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() + == 'SELECT "cola" AS "cola", "ds" AS "ds" FROM (SELECT * FROM "local"."my_source"."my_table" AS "my_table" WHERE "ds" BETWEEN \'2025-01-01 00:00:00+00:00\' AND \'2025-01-10 23:59:59.999999+00:00\') AS "_q_0"' + ) + assert ( + context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() + == 'SELECT "_q_0"."cola" AS "cola", "_q_0"."ds" AS "ds" FROM (SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch" WHERE "microbatch"."ds" <= \'2025-01-10 23:59:59.999999+00:00\' AND "microbatch"."ds" >= \'2025-01-01 00:00:00+00:00\') AS "_q_0"' + ) + + +@pytest.mark.slow +def test_load_microbatch_with_ref_no_filter( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + yaml = YAML() + project_dir, model_dir = create_empty_project() + source_schema = { + "version": 2, + "sources": [ + { + "name": "my_source", + "tables": [{"name": "my_table", "config": {"event_time": "ds"}}], + } + ], + } + source_schema_file = model_dir / "source_schema.yml" + with open(source_schema_file, "w", encoding="utf-8") as f: + yaml.dump(source_schema, f) + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-01', + batch_size='day' + ) + }} + + SELECT cola, ds FROM {{ source('my_source', 'my_table').render() }} + """ + microbatch_model_file = model_dir / "microbatch.sql" + with open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + microbatch_two_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-01', + batch_size='day' + ) + }} + + SELECT cola, ds FROM {{ ref('microbatch').render() }} + """ + microbatch_two_model_file = model_dir / "microbatch_two.sql" + with open(microbatch_two_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_two_contents) + + microbatch_snapshot_fqn = '"local"."main"."microbatch"' + microbatch_two_snapshot_fqn = '"local"."main"."microbatch_two"' + context = Context(paths=project_dir) + assert ( + context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() + == 'SELECT "cola" AS "cola", "ds" AS "ds" FROM "local"."my_source"."my_table" AS "my_table"' + ) + assert ( + context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() + == 'SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch"' + ) From 1929b4f6c2676ae2afb6f659ea64f8d25fec1c68 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Wed, 3 Sep 2025 14:02:07 -0700 Subject: [PATCH 2/2] make exclusive ts instead of between patch --- sqlmesh/core/renderer.py | 17 ++++++++++++++--- sqlmesh/dbt/__init__.py | 22 ---------------------- sqlmesh/utils/date.py | 7 +++++++ tests/dbt/test_model.py | 8 ++++---- 4 files changed, 25 insertions(+), 29 deletions(-) diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 5eea10577b..3502118e14 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -16,7 +16,14 @@ from sqlmesh.core import constants as c from sqlmesh.core import dialect as d from sqlmesh.core.macros import MacroEvaluator, RuntimeStage -from sqlmesh.utils.date import TimeLike, date_dict, make_inclusive, to_datetime +from sqlmesh.utils.date import ( + TimeLike, + date_dict, + make_inclusive, + to_datetime, + make_ts_exclusive, + to_tstz, +) from sqlmesh.utils.errors import ( ConfigError, ParsetimeAdapterCallError, @@ -214,13 +221,17 @@ def _resolve_table(table: str | exp.Table) -> str: dialect=self._dialect, identify=True, comments=False ) - all_refs = list(self._jinja_macro_registry.global_objs.get("sources", {}).values()) + list( # type: ignore + all_refs = list( + self._jinja_macro_registry.global_objs.get("sources", {}).values() # type: ignore + ) + list( self._jinja_macro_registry.global_objs.get("refs", {}).values() # type: ignore ) for ref in all_refs: if ref.event_time_filter: ref.event_time_filter["start"] = render_kwargs["start_tstz"] - ref.event_time_filter["end"] = render_kwargs["end_tstz"] + ref.event_time_filter["end"] = to_tstz( + make_ts_exclusive(render_kwargs["end_tstz"], dialect=self._dialect) + ) jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs) expressions = [] diff --git a/sqlmesh/dbt/__init__.py b/sqlmesh/dbt/__init__.py index 7c1c2ba0d4..690b1f5289 100644 --- a/sqlmesh/dbt/__init__.py +++ b/sqlmesh/dbt/__init__.py @@ -2,25 +2,3 @@ create_builtin_filters as create_builtin_filters, create_builtin_globals as create_builtin_globals, ) -from sqlmesh.dbt.util import DBT_VERSION - - -if DBT_VERSION >= (1, 9, 0): - from dbt.adapters.base.relation import BaseRelation, EventTimeFilter - - def _render_event_time_filtered_inclusive( - self: BaseRelation, event_time_filter: EventTimeFilter - ) -> str: - """ - Returns "" if start and end are both None - """ - filter = "" - if event_time_filter.start and event_time_filter.end: - filter = f"{event_time_filter.field_name} BETWEEN '{event_time_filter.start}' and '{event_time_filter.end}'" - elif event_time_filter.start: - filter = f"{event_time_filter.field_name} >= '{event_time_filter.start}'" - elif event_time_filter.end: - filter = f"{event_time_filter.field_name} <= '{event_time_filter.end}'" - return filter - - BaseRelation._render_event_time_filtered = _render_event_time_filtered_inclusive # type: ignore diff --git a/sqlmesh/utils/date.py b/sqlmesh/utils/date.py index 53a53cd62a..931cebf535 100644 --- a/sqlmesh/utils/date.py +++ b/sqlmesh/utils/date.py @@ -343,6 +343,13 @@ def make_exclusive(time: TimeLike) -> datetime: return dt +def make_ts_exclusive(time: TimeLike, dialect: DialectType) -> datetime: + ts = to_datetime(time) + if dialect == "tsql": + return to_utc_timestamp(ts) - pd.Timedelta(1, unit="ns") + return ts + timedelta(microseconds=1) + + def to_utc_timestamp(time: datetime) -> pd.Timestamp: import pandas as pd diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index dbfabd2d6d..14c042422e 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -460,7 +460,7 @@ def test_load_microbatch_with_ref( "sources": [ { "name": "my_source", - "tables": [{"name": "my_table", "config": {"event_time": "ds"}}], + "tables": [{"name": "my_table", "config": {"event_time": "ds_source"}}], } ], } @@ -479,7 +479,7 @@ def test_load_microbatch_with_ref( ) }} - SELECT cola, ds FROM {{ source('my_source', 'my_table') }} + SELECT cola, ds_source as ds FROM {{ source('my_source', 'my_table') }} """ microbatch_model_file = model_dir / "microbatch.sql" with open(microbatch_model_file, "w", encoding="utf-8") as f: @@ -507,11 +507,11 @@ def test_load_microbatch_with_ref( context = Context(paths=project_dir) assert ( context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() - == 'SELECT "cola" AS "cola", "ds" AS "ds" FROM (SELECT * FROM "local"."my_source"."my_table" AS "my_table" WHERE "ds" BETWEEN \'2025-01-01 00:00:00+00:00\' AND \'2025-01-10 23:59:59.999999+00:00\') AS "_q_0"' + == 'SELECT "cola" AS "cola", "ds_source" AS "ds" FROM (SELECT * FROM "local"."my_source"."my_table" AS "my_table" WHERE "ds_source" >= \'2025-01-01 00:00:00+00:00\' AND "ds_source" < \'2025-01-11 00:00:00+00:00\') AS "_q_0"' ) assert ( context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() - == 'SELECT "_q_0"."cola" AS "cola", "_q_0"."ds" AS "ds" FROM (SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch" WHERE "microbatch"."ds" <= \'2025-01-10 23:59:59.999999+00:00\' AND "microbatch"."ds" >= \'2025-01-01 00:00:00+00:00\') AS "_q_0"' + == 'SELECT "_q_0"."cola" AS "cola", "_q_0"."ds" AS "ds" FROM (SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch" WHERE "microbatch"."ds" < \'2025-01-11 00:00:00+00:00\' AND "microbatch"."ds" >= \'2025-01-01 00:00:00+00:00\') AS "_q_0"' )