From 1c954141b7b5051b11c8931cb3161e6e9b5203fd Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Mon, 18 Aug 2025 20:53:30 +0300 Subject: [PATCH 1/5] Feat(dbt): Add support for selected resources context variable --- sqlmesh/core/context.py | 9 ++++ sqlmesh/dbt/builtin.py | 31 ++++++++++++++ tests/dbt/test_transformation.py | 71 ++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 78a391d12f..74b2584d4d 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -117,6 +117,7 @@ run_tests, ) from sqlmesh.core.user import User +from sqlmesh.dbt.builtin import set_selected_resources from sqlmesh.utils import UniqueKeyDict, Verbosity from sqlmesh.utils.concurrency import concurrent_apply_to_values from sqlmesh.utils.dag import DAG @@ -1583,6 +1584,11 @@ def plan_builder( "Selector did not return any models. Please check your model selection and try again." ) + if self._project_type != c.NATIVE: + set_selected_resources( + models=model_selector.expand_model_selections(select_models or "*") + ) + snapshots = self._snapshots(models_override) context_diff = self._context_diff( environment or c.PROD, @@ -2482,6 +2488,9 @@ def _run( select_models, no_auto_upstream, snapshots.values() ) + if self._project_type != c.NATIVE: + set_selected_resources(models=select_models or set([s.name for s in snapshots.keys()])) + completion_status = scheduler.run( environment, start=start, diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 0503f1dc92..574748f141 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -545,6 +545,7 @@ def create_builtin_globals( "run_query": sql_execution.run_query, "statement": sql_execution.statement, "graph": adapter.graph, + "selected_resources": get_selected_resources(), } ) @@ -572,3 +573,33 @@ def _relation_info_to_relation( } ) return relation_type.create(**relation_info, quote_policy=quote_policy) + + +_selected_resources: t.List[str] = [] + + +def set_selected_resources( + models: t.Optional[t.Set[str]] = None, +) -> None: + global _selected_resources + resources = [] + + if models: + for model in models: + resources.append(dbt_model_id(model)) + + _selected_resources = sorted(resources) + + +def dbt_model_id(sqlmesh_model_name: str) -> str: + parts = [part.strip('"') for part in sqlmesh_model_name.split(".")] + return f"model.{parts[0]}.{parts[-1]}" + + +def get_selected_resources() -> t.List[str]: + return _selected_resources + + +def clear_selected_resources() -> None: + global _selected_resources + _selected_resources = [] diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 6779e196df..7ae4443009 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -45,6 +45,12 @@ from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json from sqlmesh.dbt.builtin import _relation_info_to_relation, Config from sqlmesh.dbt.common import Dependencies +from sqlmesh.dbt.builtin import ( + _relation_info_to_relation, + dbt_model_id, + clear_selected_resources, + get_selected_resources, +) from sqlmesh.dbt.column import ( ColumnConfig, column_descriptions_to_sqlmesh, @@ -2375,3 +2381,68 @@ def test_dynamic_var_names_in_macro(sushi_test_project: Project): ) converted_model = model_config.to_sqlmesh(context) assert "dynamic_test_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore + + +def test_selected_resources_with_selectors(): + sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) + + # A plan with a specific model selection + clear_selected_resources() + sushi_context.plan_builder(select_models=["sushi.customers"]) + + selected = get_selected_resources() + assert "model.memory.customers" in selected + assert len(selected) == 1 + + # Plan without model selections + clear_selected_resources() + sushi_context.plan_builder() + selected = get_selected_resources() + assert sorted( + [ + "model.memory.customer_revenue_by_day", + "model.memory.customers", + "model.memory.items", + "model.memory.items_check_snapshot", + "model.memory.items_no_hard_delete_snapshot", + "model.memory.items_snapshot", + "model.memory.order_items", + "model.memory.orders", + "model.memory.simple_model_a", + "model.memory.simple_model_b", + "model.memory.top_waiters", + "model.memory.waiter_as_customer_by_day", + "model.memory.waiter_names", + "model.memory.waiter_revenue_by_day_v1", + "model.memory.waiter_revenue_by_day_v2", + "model.memory.waiters", + ] + ) == sorted(selected) + + # Test with downstream models as well + clear_selected_resources() + sushi_context.plan_builder(select_models=["sushi.customers+"]) + selected = get_selected_resources() + assert sorted(["model.memory.customers", "model.memory.waiter_as_customer_by_day"]) == sorted( + selected + ) + + # Test wildcard selection + clear_selected_resources() + sushi_context.plan_builder(select_models=["sushi.waiter_*"]) + selected = get_selected_resources() + assert sorted( + [ + "model.memory.waiter_as_customer_by_day", + "model.memory.waiter_names", + "model.memory.waiter_revenue_by_day_v1", + "model.memory.waiter_revenue_by_day_v2", + ] + ) == sorted(selected) + clear_selected_resources() + + +def test_dbt_model_id_conversion(): + assert dbt_model_id("jaffle_shop.main.customers") == "model.jaffle_shop.customers" + assert dbt_model_id("jaffle_shop.main.orders") == "model.jaffle_shop.orders" + assert dbt_model_id('"jaffle_shop"."customers"') == "model.jaffle_shop.customers" From 8a9471cb7ba0de80be9ab69e29dd558f38ba0dd7 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Thu, 28 Aug 2025 20:45:11 +0300 Subject: [PATCH 2/5] pr feedback --- sqlmesh/core/context.py | 10 +-- sqlmesh/core/environment.py | 2 + sqlmesh/core/plan/builder.py | 3 + sqlmesh/core/plan/definition.py | 4 + sqlmesh/core/plan/evaluator.py | 3 + sqlmesh/core/scheduler.py | 5 ++ sqlmesh/dbt/adapter.py | 8 ++ sqlmesh/dbt/builtin.py | 32 +------ tests/dbt/test_transformation.py | 149 ++++++++++++++++++------------- 9 files changed, 116 insertions(+), 100 deletions(-) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 74b2584d4d..36fb1b632d 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -117,7 +117,6 @@ run_tests, ) from sqlmesh.core.user import User -from sqlmesh.dbt.builtin import set_selected_resources from sqlmesh.utils import UniqueKeyDict, Verbosity from sqlmesh.utils.concurrency import concurrent_apply_to_values from sqlmesh.utils.dag import DAG @@ -1584,11 +1583,6 @@ def plan_builder( "Selector did not return any models. Please check your model selection and try again." ) - if self._project_type != c.NATIVE: - set_selected_resources( - models=model_selector.expand_model_selections(select_models or "*") - ) - snapshots = self._snapshots(models_override) context_diff = self._context_diff( environment or c.PROD, @@ -1683,6 +1677,7 @@ def plan_builder( end_override_per_model=max_interval_end_per_model, console=self.console, user_provided_flags=user_provided_flags, + selected_models=model_selector.expand_model_selections(select_models or "*"), explain=explain or False, ignore_cron=ignore_cron or False, ) @@ -2488,9 +2483,6 @@ def _run( select_models, no_auto_upstream, snapshots.values() ) - if self._project_type != c.NATIVE: - set_selected_resources(models=select_models or set([s.name for s in snapshots.keys()])) - completion_status = scheduler.run( environment, start=start, diff --git a/sqlmesh/core/environment.py b/sqlmesh/core/environment.py index 2a0d4f115d..4a1f417468 100644 --- a/sqlmesh/core/environment.py +++ b/sqlmesh/core/environment.py @@ -312,6 +312,7 @@ def execute_environment_statements( start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, + selected_models: t.Optional[t.Set[str]] = None, ) -> None: try: rendered_expressions = [ @@ -327,6 +328,7 @@ def execute_environment_statements( execution_time=execution_time, environment_naming_info=environment_naming_info, engine_adapter=adapter, + selected_models=selected_models, ) ] except Exception as e: diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index a48812d16c..a84b3b60dc 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -129,6 +129,7 @@ def __init__( end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, console: t.Optional[PlanBuilderConsole] = None, user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None, + selected_models: t.Optional[t.Set[str]] = None, ): self._context_diff = context_diff self._no_gaps = no_gaps @@ -169,6 +170,7 @@ def __init__( self._console = console or get_console() self._choices: t.Dict[SnapshotId, SnapshotChangeCategory] = {} self._user_provided_flags = user_provided_flags + self._selected_models = selected_models self._explain = explain self._start = start @@ -347,6 +349,7 @@ def build(self) -> Plan: ensure_finalized_snapshots=self._ensure_finalized_snapshots, ignore_cron=self._ignore_cron, user_provided_flags=self._user_provided_flags, + selected_models=self._selected_models, ) self._latest_plan = plan return plan diff --git a/sqlmesh/core/plan/definition.py b/sqlmesh/core/plan/definition.py index 2f3ddb5990..d3fe0ef36b 100644 --- a/sqlmesh/core/plan/definition.py +++ b/sqlmesh/core/plan/definition.py @@ -70,6 +70,8 @@ class Plan(PydanticModel, frozen=True): execution_time_: t.Optional[TimeLike] = Field(default=None, alias="execution_time") user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None + selected_models: t.Optional[t.Set[str]] = None + """Models that have been selected for this plan (used for dbt selected_resouces)""" @cached_property def start(self) -> TimeLike: @@ -282,6 +284,7 @@ def to_evaluatable(self) -> EvaluatablePlan: }, environment_statements=self.context_diff.environment_statements, user_provided_flags=self.user_provided_flags, + selected_models=self.selected_models, ) @cached_property @@ -319,6 +322,7 @@ class EvaluatablePlan(PydanticModel): disabled_restatement_models: t.Set[str] environment_statements: t.Optional[t.List[EnvironmentStatements]] = None user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None + selected_models: t.Optional[t.Set[str]] = None def is_selected_for_backfill(self, model_fqn: str) -> bool: return self.models_to_backfill is None or model_fqn in self.models_to_backfill diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 298d18a042..03b0b64016 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -137,6 +137,7 @@ def visit_before_all_stage(self, stage: stages.BeforeAllStage, plan: Evaluatable start=plan.start, end=plan.end, execution_time=plan.execution_time, + selected_models=plan.selected_models, ) def visit_after_all_stage(self, stage: stages.AfterAllStage, plan: EvaluatablePlan) -> None: @@ -150,6 +151,7 @@ def visit_after_all_stage(self, stage: stages.AfterAllStage, plan: EvaluatablePl start=plan.start, end=plan.end, execution_time=plan.execution_time, + selected_models=plan.selected_models, ) def visit_create_snapshot_records_stage( @@ -257,6 +259,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla allow_destructive_snapshots=plan.allow_destructive_models, allow_additive_snapshots=plan.allow_additive_models, selected_snapshot_ids=stage.selected_snapshot_ids, + selected_models=plan.selected_models, ) if errors: raise PlanError("Plan application failed.") diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 44d6b14c10..411ddbf5b5 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -416,6 +416,7 @@ def run_merged_intervals( start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, allow_destructive_snapshots: t.Optional[t.Set[str]] = None, + selected_models: t.Optional[t.Set[str]] = None, allow_additive_snapshots: t.Optional[t.Set[str]] = None, selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, run_environment_statements: bool = False, @@ -472,6 +473,7 @@ def run_merged_intervals( start=start, end=end, execution_time=execution_time, + selected_models=selected_models, ) # We only need to create physical tables if the snapshot is not representative or if it @@ -533,6 +535,7 @@ def run_node(node: SchedulingUnit) -> None: allow_destructive_snapshots=allow_destructive_snapshots, allow_additive_snapshots=allow_additive_snapshots, target_table_exists=snapshot.snapshot_id not in snapshots_to_create, + selected_models=selected_models, ) evaluation_duration_ms = now_timestamp() - execution_start_ts @@ -602,6 +605,7 @@ def run_node(node: SchedulingUnit) -> None: start=start, end=end, execution_time=execution_time, + selected_models=selected_models, ) self.state_sync.recycle() @@ -808,6 +812,7 @@ def _run_or_audit( run_environment_statements=run_environment_statements, audit_only=audit_only, auto_restatement_triggers=auto_restatement_triggers, + selected_models=selected_snapshots or {s.name for s in merged_intervals}, ) return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 236d4cee6b..05bfab7471 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -181,6 +181,10 @@ def graph(self) -> t.Any: } ) + @property + def selected_resources(self) -> t.List[str]: + return [] + class ParsetimeAdapter(BaseAdapter): def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]: @@ -501,3 +505,7 @@ def _normalize(self, input_table: exp.Table) -> exp.Table: normalized_table.set("db", normalized_table.this) normalized_table.set("this", None) return normalized_table + + def _dbt_model_id(self, sqlmesh_model_name: str) -> str: + parts = [part.strip('"') for part in sqlmesh_model_name.split(".")] + return f"model.{parts[0]}.{parts[-1]}" diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 574748f141..8c36550e0b 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -545,7 +545,7 @@ def create_builtin_globals( "run_query": sql_execution.run_query, "statement": sql_execution.statement, "graph": adapter.graph, - "selected_resources": get_selected_resources(), + "selected_resources": adapter.selected_resources, } ) @@ -573,33 +573,3 @@ def _relation_info_to_relation( } ) return relation_type.create(**relation_info, quote_policy=quote_policy) - - -_selected_resources: t.List[str] = [] - - -def set_selected_resources( - models: t.Optional[t.Set[str]] = None, -) -> None: - global _selected_resources - resources = [] - - if models: - for model in models: - resources.append(dbt_model_id(model)) - - _selected_resources = sorted(resources) - - -def dbt_model_id(sqlmesh_model_name: str) -> str: - parts = [part.strip('"') for part in sqlmesh_model_name.split(".")] - return f"model.{parts[0]}.{parts[-1]}" - - -def get_selected_resources() -> t.List[str]: - return _selected_resources - - -def clear_selected_resources() -> None: - global _selected_resources - _selected_resources = [] diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 7ae4443009..b8a07d9439 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -5,6 +5,8 @@ import typing as t from pathlib import Path from unittest.mock import patch +from sqlmesh.dbt.adapter import RuntimeAdapter +from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.dbt.util import DBT_VERSION @@ -45,12 +47,7 @@ from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json from sqlmesh.dbt.builtin import _relation_info_to_relation, Config from sqlmesh.dbt.common import Dependencies -from sqlmesh.dbt.builtin import ( - _relation_info_to_relation, - dbt_model_id, - clear_selected_resources, - get_selected_resources, -) +from sqlmesh.dbt.builtin import _relation_info_to_relation from sqlmesh.dbt.column import ( ColumnConfig, column_descriptions_to_sqlmesh, @@ -2387,62 +2384,94 @@ def test_selected_resources_with_selectors(): sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) # A plan with a specific model selection - clear_selected_resources() - sushi_context.plan_builder(select_models=["sushi.customers"]) - - selected = get_selected_resources() - assert "model.memory.customers" in selected - assert len(selected) == 1 - - # Plan without model selections - clear_selected_resources() - sushi_context.plan_builder() - selected = get_selected_resources() - assert sorted( - [ - "model.memory.customer_revenue_by_day", - "model.memory.customers", - "model.memory.items", - "model.memory.items_check_snapshot", - "model.memory.items_no_hard_delete_snapshot", - "model.memory.items_snapshot", - "model.memory.order_items", - "model.memory.orders", - "model.memory.simple_model_a", - "model.memory.simple_model_b", - "model.memory.top_waiters", - "model.memory.waiter_as_customer_by_day", - "model.memory.waiter_names", - "model.memory.waiter_revenue_by_day_v1", - "model.memory.waiter_revenue_by_day_v2", - "model.memory.waiters", - ] - ) == sorted(selected) + plan_builder = sushi_context.plan_builder(select_models=["sushi.customers"]) + plan = plan_builder.build() + assert len(plan.selected_models) == 1 + selected_model = list(plan.selected_models)[0] + assert "customers" in selected_model + + # Plan without model selections should include all models + plan_builder = sushi_context.plan_builder() + plan = plan_builder.build() + assert plan.selected_models is not None + assert len(plan.selected_models) > 10 + + # with downstream models should select customers and at least one downstream model + plan_builder = sushi_context.plan_builder(select_models=["sushi.customers+"]) + plan = plan_builder.build() + assert plan.selected_models is not None + assert len(plan.selected_models) >= 2 + assert any("customers" in model for model in plan.selected_models) + + # Test wildcard selection + plan_builder = sushi_context.plan_builder(select_models=["sushi.waiter_*"]) + plan = plan_builder.build() + assert plan.selected_models is not None + assert len(plan.selected_models) >= 4 + assert all("waiter" in model for model in plan.selected_models) - # Test with downstream models as well - clear_selected_resources() - sushi_context.plan_builder(select_models=["sushi.customers+"]) - selected = get_selected_resources() - assert sorted(["model.memory.customers", "model.memory.waiter_as_customer_by_day"]) == sorted( - selected + +@pytest.mark.xdist_group("dbt_manifest") +def test_selected_resources_context_variable( + sushi_test_project: Project, sushi_test_dbt_context: Context +): + context = sushi_test_project.context + + # should be empty list during parse time + direct_access = context.render("{{ selected_resources }}") + assert direct_access == "[]" + + # selected_resources is iterable and count items + test_jinja = """ + {%- set resources = [] -%} + {%- for resource in selected_resources -%} + {%- do resources.append(resource) -%} + {%- endfor -%} + {{ resources | length }} + """ + result = context.render(test_jinja) + assert result.strip() == "0" + + # selected_resources in conditions + test_condition = """ + {%- if selected_resources -%} + has_resources + {%- else -%} + no_resources + {%- endif -%} + """ + result = context.render(test_condition) + assert result.strip() == "no_resources" + + # Test 4: Test with runtime adapter (simulating runtime execution) + runtime_adapter = RuntimeAdapter( + engine_adapter=sushi_test_dbt_context.engine_adapter, + jinja_macros=JinjaMacroRegistry(), + jinja_globals={ + "selected_models": { + '"jaffle_shop"."main"."customers"', + '"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."items"', + }, + }, ) - # Test wildcard selection - clear_selected_resources() - sushi_context.plan_builder(select_models=["sushi.waiter_*"]) - selected = get_selected_resources() - assert sorted( - [ - "model.memory.waiter_as_customer_by_day", - "model.memory.waiter_names", - "model.memory.waiter_revenue_by_day_v1", - "model.memory.waiter_revenue_by_day_v2", - ] - ) == sorted(selected) - clear_selected_resources() + # it should return correct selected resources in dbt format + selected_resources = runtime_adapter.selected_resources + assert len(selected_resources) == 3 + assert "model.jaffle_shop.customers" in selected_resources + assert "model.jaffle_shop.orders" in selected_resources + assert "model.jaffle_shop.items" in selected_resources + + # check the jinja macros rendering + result = context.render("{{ selected_resources }}", selected_resources=selected_resources) + assert ( + result + == "['model.jaffle_shop.customers', 'model.jaffle_shop.items', 'model.jaffle_shop.orders']" + ) + result = context.render(test_jinja, selected_resources=selected_resources) + assert result.strip() == "3" -def test_dbt_model_id_conversion(): - assert dbt_model_id("jaffle_shop.main.customers") == "model.jaffle_shop.customers" - assert dbt_model_id("jaffle_shop.main.orders") == "model.jaffle_shop.orders" - assert dbt_model_id('"jaffle_shop"."customers"') == "model.jaffle_shop.customers" + result = context.render(test_condition, selected_resources=selected_resources) + assert result.strip() == "has_resources" From e42ad825b6f096a19e10277e47bdf08a9fc85328 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Fri, 29 Aug 2025 12:03:13 +0300 Subject: [PATCH 3/5] fix comments --- sqlmesh/core/plan/definition.py | 2 +- sqlmesh/dbt/adapter.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sqlmesh/core/plan/definition.py b/sqlmesh/core/plan/definition.py index d3fe0ef36b..aaf6ec5dc0 100644 --- a/sqlmesh/core/plan/definition.py +++ b/sqlmesh/core/plan/definition.py @@ -71,7 +71,7 @@ class Plan(PydanticModel, frozen=True): user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None selected_models: t.Optional[t.Set[str]] = None - """Models that have been selected for this plan (used for dbt selected_resouces)""" + """Models that have been selected for this plan (used for dbt selected_resources)""" @cached_property def start(self) -> TimeLike: diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 05bfab7471..1e49a168df 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -507,5 +507,6 @@ def _normalize(self, input_table: exp.Table) -> exp.Table: return normalized_table def _dbt_model_id(self, sqlmesh_model_name: str) -> str: + # Model prefix is needed to correspond to the key in the nodes within the dbt context variable parts = [part.strip('"') for part in sqlmesh_model_name.split(".")] return f"model.{parts[0]}.{parts[-1]}" From f5c48f0d206837bcd1eddf1381c6cf196ba613ba Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Mon, 8 Sep 2025 20:03:45 +0300 Subject: [PATCH 4/5] add the dbt node name --- sqlmesh/core/context.py | 6 +- sqlmesh/core/node.py | 1 + sqlmesh/core/scheduler.py | 2 +- sqlmesh/dbt/adapter.py | 9 --- sqlmesh/dbt/builtin.py | 2 +- sqlmesh/dbt/model.py | 1 + sqlmesh/dbt/seed.py | 1 + sqlmesh/migrations/v0097_add_node_name.py | 9 +++ tests/dbt/test_model.py | 79 ++++++++++++++++++++++- tests/dbt/test_transformation.py | 34 +++------- 10 files changed, 105 insertions(+), 39 deletions(-) create mode 100644 sqlmesh/migrations/v0097_add_node_name.py diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 36fb1b632d..e0f58994e1 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -1677,7 +1677,11 @@ def plan_builder( end_override_per_model=max_interval_end_per_model, console=self.console, user_provided_flags=user_provided_flags, - selected_models=model_selector.expand_model_selections(select_models or "*"), + selected_models={ + node_name + for model in model_selector.expand_model_selections(select_models or "*") + if (node_name := snapshots[model].node.node_name) + }, explain=explain or False, ignore_cron=ignore_cron or False, ) diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py index ea2264f7fa..da5d3c04ed 100644 --- a/sqlmesh/core/node.py +++ b/sqlmesh/core/node.py @@ -199,6 +199,7 @@ class _Node(PydanticModel): interval_unit_: t.Optional[IntervalUnit] = Field(alias="interval_unit", default=None) tags: t.List[str] = [] stamp: t.Optional[str] = None + node_name: t.Optional[str] = None # dbt node name _path: t.Optional[Path] = None _data_hash: t.Optional[str] = None _metadata_hash: t.Optional[str] = None diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 411ddbf5b5..b9ddbe1462 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -812,7 +812,7 @@ def _run_or_audit( run_environment_statements=run_environment_statements, audit_only=audit_only, auto_restatement_triggers=auto_restatement_triggers, - selected_models=selected_snapshots or {s.name for s in merged_intervals}, + selected_models={s.node.node_name for s in merged_intervals if s.node.node_name}, ) return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 1e49a168df..236d4cee6b 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -181,10 +181,6 @@ def graph(self) -> t.Any: } ) - @property - def selected_resources(self) -> t.List[str]: - return [] - class ParsetimeAdapter(BaseAdapter): def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]: @@ -505,8 +501,3 @@ def _normalize(self, input_table: exp.Table) -> exp.Table: normalized_table.set("db", normalized_table.this) normalized_table.set("this", None) return normalized_table - - def _dbt_model_id(self, sqlmesh_model_name: str) -> str: - # Model prefix is needed to correspond to the key in the nodes within the dbt context variable - parts = [part.strip('"') for part in sqlmesh_model_name.split(".")] - return f"model.{parts[0]}.{parts[-1]}" diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 8c36550e0b..e284c11797 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -545,7 +545,7 @@ def create_builtin_globals( "run_query": sql_execution.run_query, "statement": sql_execution.statement, "graph": adapter.graph, - "selected_resources": adapter.selected_resources, + "selected_resources": list(jinja_globals.get("selected_models") or []), } ) diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index a4ebf93ae5..ad43b5a183 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -689,6 +689,7 @@ def to_sqlmesh( extract_dependencies_from_query=False, allow_partials=allow_partials, virtual_environment_mode=virtual_environment_mode, + node_name=self.node_name, **optional_kwargs, **model_kwargs, ) diff --git a/sqlmesh/dbt/seed.py b/sqlmesh/dbt/seed.py index 38cd635d91..fee676b2e2 100644 --- a/sqlmesh/dbt/seed.py +++ b/sqlmesh/dbt/seed.py @@ -92,6 +92,7 @@ def to_sqlmesh( audit_definitions=audit_definitions, virtual_environment_mode=virtual_environment_mode, start=self.start or context.sqlmesh_config.model_defaults.start, + node_name=self.node_name, **kwargs, ) diff --git a/sqlmesh/migrations/v0097_add_node_name.py b/sqlmesh/migrations/v0097_add_node_name.py new file mode 100644 index 0000000000..dc73db0801 --- /dev/null +++ b/sqlmesh/migrations/v0097_add_node_name.py @@ -0,0 +1,9 @@ +"""Add 'node_name' property to node definition.""" + + +def migrate_schemas(state_sync, **kwargs): # type: ignore + pass + + +def migrate_rows(state_sync, **kwargs): # type: ignore + pass diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index bfc18144ef..2e816d6f8f 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -201,7 +201,7 @@ def test_load_microbatch_all_defined( concurrent_batches=true ) }} - + SELECT 1 as cola, '2025-01-01' as ds """ microbatch_model_file = model_dir / "microbatch.sql" @@ -633,3 +633,80 @@ def test_dbt_jinja_macro_undefined_variable_error(create_empty_project): assert "Failed to update model schemas" in error_message assert "Could not render jinja for" in error_message assert "Undefined macro/variable: 'columns' in macro: 'select_columns'" in error_message + + +@pytest.mark.slow +def test_node_name_populated_for_dbt_models(dbt_dummy_postgres_config: PostgresConfig) -> None: + model_config = ModelConfig( + name="test_model", + package_name="test_package", + sql="SELECT 1 as id", + database="test_db", + schema_="test_schema", + alias="test_model", + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = dbt_dummy_postgres_config + + # check after convert to SQLMesh model that node_name is populated correctly + sqlmesh_model = model_config.to_sqlmesh(context) + assert sqlmesh_model.node_name == "model.test_package.test_model" + + +@pytest.mark.slow +def test_load_model_dbt_node_name(tmp_path: Path) -> None: + yaml = YAML() + dbt_project_dir = tmp_path / "dbt" + dbt_project_dir.mkdir() + dbt_model_dir = dbt_project_dir / "models" + dbt_model_dir.mkdir() + + model_contents = "SELECT 1 as id, 'test' as name" + model_file = dbt_model_dir / "simple_model.sql" + with open(model_file, "w", encoding="utf-8") as f: + f.write(model_contents) + + dbt_project_config = { + "name": "test_project", + "version": "1.0.0", + "config-version": 2, + "profile": "test", + "model-paths": ["models"], + } + dbt_project_file = dbt_project_dir / "dbt_project.yml" + with open(dbt_project_file, "w", encoding="utf-8") as f: + yaml.dump(dbt_project_config, f) + + sqlmesh_config = { + "model_defaults": { + "start": "2025-01-01", + } + } + sqlmesh_config_file = dbt_project_dir / "sqlmesh.yaml" + with open(sqlmesh_config_file, "w", encoding="utf-8") as f: + yaml.dump(sqlmesh_config, f) + + dbt_data_dir = tmp_path / "dbt_data" + dbt_data_dir.mkdir() + dbt_data_file = dbt_data_dir / "local.db" + dbt_profile_config = { + "test": { + "outputs": {"duckdb": {"type": "duckdb", "path": str(dbt_data_file)}}, + "target": "duckdb", + } + } + db_profile_file = dbt_project_dir / "profiles.yml" + with open(db_profile_file, "w", encoding="utf-8") as f: + yaml.dump(dbt_profile_config, f) + + context = Context(paths=dbt_project_dir) + + # find the model by its sqlmesh fully qualified name + model_fqn = '"local"."main"."simple_model"' + assert model_fqn in context.snapshots + + # Verify that node_name is the equivalent dbt one + model = context.snapshots[model_fqn].model + assert model.node_name == "model.test_project.simple_model" diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index b8a07d9439..551c6cc16f 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -5,8 +5,6 @@ import typing as t from pathlib import Path from unittest.mock import patch -from sqlmesh.dbt.adapter import RuntimeAdapter -from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.dbt.util import DBT_VERSION @@ -2417,7 +2415,7 @@ def test_selected_resources_context_variable( ): context = sushi_test_project.context - # should be empty list during parse time + # empty selected resources direct_access = context.render("{{ selected_resources }}") assert direct_access == "[]" @@ -2443,32 +2441,16 @@ def test_selected_resources_context_variable( result = context.render(test_condition) assert result.strip() == "no_resources" - # Test 4: Test with runtime adapter (simulating runtime execution) - runtime_adapter = RuntimeAdapter( - engine_adapter=sushi_test_dbt_context.engine_adapter, - jinja_macros=JinjaMacroRegistry(), - jinja_globals={ - "selected_models": { - '"jaffle_shop"."main"."customers"', - '"jaffle_shop"."main"."orders"', - '"jaffle_shop"."main"."items"', - }, - }, - ) - - # it should return correct selected resources in dbt format - selected_resources = runtime_adapter.selected_resources - assert len(selected_resources) == 3 - assert "model.jaffle_shop.customers" in selected_resources - assert "model.jaffle_shop.orders" in selected_resources - assert "model.jaffle_shop.items" in selected_resources + # selected resources in dbt format + selected_resources = [ + "model.jaffle_shop.customers", + "model.jaffle_shop.items", + "model.jaffle_shop.orders", + ] # check the jinja macros rendering result = context.render("{{ selected_resources }}", selected_resources=selected_resources) - assert ( - result - == "['model.jaffle_shop.customers', 'model.jaffle_shop.items', 'model.jaffle_shop.orders']" - ) + assert result == selected_resources.__repr__() result = context.render(test_jinja, selected_resources=selected_resources) assert result.strip() == "3" From d1b2940695fa38775f32acd943da9892f47ae401 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:18:40 +0300 Subject: [PATCH 5/5] revise the property name to dbt_name --- sqlmesh/core/context.py | 4 ++-- sqlmesh/core/node.py | 2 +- sqlmesh/core/scheduler.py | 2 +- sqlmesh/dbt/model.py | 2 +- sqlmesh/dbt/seed.py | 2 +- .../{v0097_add_node_name.py => v0097_add_dbt_name_in_node.py} | 2 +- tests/dbt/test_model.py | 4 ++-- 7 files changed, 9 insertions(+), 9 deletions(-) rename sqlmesh/migrations/{v0097_add_node_name.py => v0097_add_dbt_name_in_node.py} (72%) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index e0f58994e1..0339f6506c 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -1678,9 +1678,9 @@ def plan_builder( console=self.console, user_provided_flags=user_provided_flags, selected_models={ - node_name + dbt_name for model in model_selector.expand_model_selections(select_models or "*") - if (node_name := snapshots[model].node.node_name) + if (dbt_name := snapshots[model].node.dbt_name) }, explain=explain or False, ignore_cron=ignore_cron or False, diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py index da5d3c04ed..b04a59a39f 100644 --- a/sqlmesh/core/node.py +++ b/sqlmesh/core/node.py @@ -199,7 +199,7 @@ class _Node(PydanticModel): interval_unit_: t.Optional[IntervalUnit] = Field(alias="interval_unit", default=None) tags: t.List[str] = [] stamp: t.Optional[str] = None - node_name: t.Optional[str] = None # dbt node name + dbt_name: t.Optional[str] = None # dbt node name _path: t.Optional[Path] = None _data_hash: t.Optional[str] = None _metadata_hash: t.Optional[str] = None diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index b9ddbe1462..ec204927d4 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -812,7 +812,7 @@ def _run_or_audit( run_environment_statements=run_environment_statements, audit_only=audit_only, auto_restatement_triggers=auto_restatement_triggers, - selected_models={s.node.node_name for s in merged_intervals if s.node.node_name}, + selected_models={s.node.dbt_name for s in merged_intervals if s.node.dbt_name}, ) return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index ad43b5a183..3d5da1beaa 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -689,7 +689,7 @@ def to_sqlmesh( extract_dependencies_from_query=False, allow_partials=allow_partials, virtual_environment_mode=virtual_environment_mode, - node_name=self.node_name, + dbt_name=self.node_name, **optional_kwargs, **model_kwargs, ) diff --git a/sqlmesh/dbt/seed.py b/sqlmesh/dbt/seed.py index fee676b2e2..d6ecc768f9 100644 --- a/sqlmesh/dbt/seed.py +++ b/sqlmesh/dbt/seed.py @@ -92,7 +92,7 @@ def to_sqlmesh( audit_definitions=audit_definitions, virtual_environment_mode=virtual_environment_mode, start=self.start or context.sqlmesh_config.model_defaults.start, - node_name=self.node_name, + dbt_name=self.node_name, **kwargs, ) diff --git a/sqlmesh/migrations/v0097_add_node_name.py b/sqlmesh/migrations/v0097_add_dbt_name_in_node.py similarity index 72% rename from sqlmesh/migrations/v0097_add_node_name.py rename to sqlmesh/migrations/v0097_add_dbt_name_in_node.py index dc73db0801..f8909e4430 100644 --- a/sqlmesh/migrations/v0097_add_node_name.py +++ b/sqlmesh/migrations/v0097_add_dbt_name_in_node.py @@ -1,4 +1,4 @@ -"""Add 'node_name' property to node definition.""" +"""Add 'dbt_name' property to node definition.""" def migrate_schemas(state_sync, **kwargs): # type: ignore diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index 2e816d6f8f..d3103d3681 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -652,7 +652,7 @@ def test_node_name_populated_for_dbt_models(dbt_dummy_postgres_config: PostgresC # check after convert to SQLMesh model that node_name is populated correctly sqlmesh_model = model_config.to_sqlmesh(context) - assert sqlmesh_model.node_name == "model.test_package.test_model" + assert sqlmesh_model.dbt_name == "model.test_package.test_model" @pytest.mark.slow @@ -709,4 +709,4 @@ def test_load_model_dbt_node_name(tmp_path: Path) -> None: # Verify that node_name is the equivalent dbt one model = context.snapshots[model_fqn].model - assert model.node_name == "model.test_project.simple_model" + assert model.dbt_name == "model.test_project.simple_model"