Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,11 @@ def plan_builder(
end_override_per_model=max_interval_end_per_model,
console=self.console,
user_provided_flags=user_provided_flags,
selected_models={
dbt_name
for model in model_selector.expand_model_selections(select_models or "*")
if (dbt_name := snapshots[model].node.dbt_name)
},
explain=explain or False,
ignore_cron=ignore_cron or False,
)
Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/core/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def execute_environment_statements(
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
execution_time: t.Optional[TimeLike] = None,
selected_models: t.Optional[t.Set[str]] = None,
) -> None:
try:
rendered_expressions = [
Expand All @@ -327,6 +328,7 @@ def execute_environment_statements(
execution_time=execution_time,
environment_naming_info=environment_naming_info,
engine_adapter=adapter,
selected_models=selected_models,
)
]
except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class _Node(PydanticModel):
interval_unit_: t.Optional[IntervalUnit] = Field(alias="interval_unit", default=None)
tags: t.List[str] = []
stamp: t.Optional[str] = None
dbt_name: t.Optional[str] = None # dbt node name
_path: t.Optional[Path] = None
_data_hash: t.Optional[str] = None
_metadata_hash: t.Optional[str] = None
Expand Down
3 changes: 3 additions & 0 deletions sqlmesh/core/plan/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(
end_override_per_model: t.Optional[t.Dict[str, datetime]] = None,
console: t.Optional[PlanBuilderConsole] = None,
user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None,
selected_models: t.Optional[t.Set[str]] = None,
):
self._context_diff = context_diff
self._no_gaps = no_gaps
Expand Down Expand Up @@ -169,6 +170,7 @@ def __init__(
self._console = console or get_console()
self._choices: t.Dict[SnapshotId, SnapshotChangeCategory] = {}
self._user_provided_flags = user_provided_flags
self._selected_models = selected_models
self._explain = explain

self._start = start
Expand Down Expand Up @@ -347,6 +349,7 @@ def build(self) -> Plan:
ensure_finalized_snapshots=self._ensure_finalized_snapshots,
ignore_cron=self._ignore_cron,
user_provided_flags=self._user_provided_flags,
selected_models=self._selected_models,
)
self._latest_plan = plan
return plan
Expand Down
4 changes: 4 additions & 0 deletions sqlmesh/core/plan/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class Plan(PydanticModel, frozen=True):
execution_time_: t.Optional[TimeLike] = Field(default=None, alias="execution_time")

user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None
selected_models: t.Optional[t.Set[str]] = None
"""Models that have been selected for this plan (used for dbt selected_resources)"""

@cached_property
def start(self) -> TimeLike:
Expand Down Expand Up @@ -282,6 +284,7 @@ def to_evaluatable(self) -> EvaluatablePlan:
},
environment_statements=self.context_diff.environment_statements,
user_provided_flags=self.user_provided_flags,
selected_models=self.selected_models,
)

@cached_property
Expand Down Expand Up @@ -319,6 +322,7 @@ class EvaluatablePlan(PydanticModel):
disabled_restatement_models: t.Set[str]
environment_statements: t.Optional[t.List[EnvironmentStatements]] = None
user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None
selected_models: t.Optional[t.Set[str]] = None

def is_selected_for_backfill(self, model_fqn: str) -> bool:
return self.models_to_backfill is None or model_fqn in self.models_to_backfill
Expand Down
3 changes: 3 additions & 0 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def visit_before_all_stage(self, stage: stages.BeforeAllStage, plan: Evaluatable
start=plan.start,
end=plan.end,
execution_time=plan.execution_time,
selected_models=plan.selected_models,
)

def visit_after_all_stage(self, stage: stages.AfterAllStage, plan: EvaluatablePlan) -> None:
Expand All @@ -150,6 +151,7 @@ def visit_after_all_stage(self, stage: stages.AfterAllStage, plan: EvaluatablePl
start=plan.start,
end=plan.end,
execution_time=plan.execution_time,
selected_models=plan.selected_models,
)

def visit_create_snapshot_records_stage(
Expand Down Expand Up @@ -257,6 +259,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
allow_destructive_snapshots=plan.allow_destructive_models,
allow_additive_snapshots=plan.allow_additive_models,
selected_snapshot_ids=stage.selected_snapshot_ids,
selected_models=plan.selected_models,
)
if errors:
raise PlanError("Plan application failed.")
Expand Down
5 changes: 5 additions & 0 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def run_merged_intervals(
start: t.Optional[TimeLike] = None,
end: t.Optional[TimeLike] = None,
allow_destructive_snapshots: t.Optional[t.Set[str]] = None,
selected_models: t.Optional[t.Set[str]] = None,
allow_additive_snapshots: t.Optional[t.Set[str]] = None,
selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None,
run_environment_statements: bool = False,
Expand Down Expand Up @@ -472,6 +473,7 @@ def run_merged_intervals(
start=start,
end=end,
execution_time=execution_time,
selected_models=selected_models,
)

# We only need to create physical tables if the snapshot is not representative or if it
Expand Down Expand Up @@ -533,6 +535,7 @@ def run_node(node: SchedulingUnit) -> None:
allow_destructive_snapshots=allow_destructive_snapshots,
allow_additive_snapshots=allow_additive_snapshots,
target_table_exists=snapshot.snapshot_id not in snapshots_to_create,
selected_models=selected_models,
)

evaluation_duration_ms = now_timestamp() - execution_start_ts
Expand Down Expand Up @@ -602,6 +605,7 @@ def run_node(node: SchedulingUnit) -> None:
start=start,
end=end,
execution_time=execution_time,
selected_models=selected_models,
)

self.state_sync.recycle()
Expand Down Expand Up @@ -808,6 +812,7 @@ def _run_or_audit(
run_environment_statements=run_environment_statements,
audit_only=audit_only,
auto_restatement_triggers=auto_restatement_triggers,
selected_models={s.node.dbt_name for s in merged_intervals if s.node.dbt_name},
)

return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/dbt/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ def create_builtin_globals(
"run_query": sql_execution.run_query,
"statement": sql_execution.statement,
"graph": adapter.graph,
"selected_resources": list(jinja_globals.get("selected_models") or []),
}
)

Expand Down
1 change: 1 addition & 0 deletions sqlmesh/dbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,7 @@ def to_sqlmesh(
extract_dependencies_from_query=False,
allow_partials=allow_partials,
virtual_environment_mode=virtual_environment_mode,
dbt_name=self.node_name,
**optional_kwargs,
**model_kwargs,
)
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/dbt/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def to_sqlmesh(
audit_definitions=audit_definitions,
virtual_environment_mode=virtual_environment_mode,
start=self.start or context.sqlmesh_config.model_defaults.start,
dbt_name=self.node_name,
**kwargs,
)

Expand Down
9 changes: 9 additions & 0 deletions sqlmesh/migrations/v0097_add_dbt_name_in_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Add 'dbt_name' property to node definition."""


def migrate_schemas(state_sync, **kwargs): # type: ignore
    """No-op schema migration.

    'dbt_name' is added as an optional node field with a None default,
    so presumably no state-table schema change is required — confirm
    against how other field-only migrations in this package are written.
    """
    pass


def migrate_rows(state_sync, **kwargs): # type: ignore
    """No-op row migration.

    Existing serialized nodes simply lack the new optional 'dbt_name'
    key and will deserialize with its None default, so no row rewrite
    appears necessary — TODO confirm with the state-sync loader.
    """
    pass
79 changes: 78 additions & 1 deletion tests/dbt/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_load_microbatch_all_defined(
concurrent_batches=true
)
}}

SELECT 1 as cola, '2025-01-01' as ds
"""
microbatch_model_file = model_dir / "microbatch.sql"
Expand Down Expand Up @@ -633,3 +633,80 @@ def test_dbt_jinja_macro_undefined_variable_error(create_empty_project):
assert "Failed to update model schemas" in error_message
assert "Could not render jinja for" in error_message
assert "Undefined macro/variable: 'columns' in macro: 'select_columns'" in error_message


@pytest.mark.slow
def test_node_name_populated_for_dbt_models(dbt_dummy_postgres_config: PostgresConfig) -> None:
    """The dbt node identifier should survive conversion to a SQLMesh model."""
    dbt_context = DbtContext()
    dbt_context.project_name = "test_project"
    dbt_context.target = dbt_dummy_postgres_config

    config = ModelConfig(
        name="test_model",
        package_name="test_package",
        sql="SELECT 1 as id",
        database="test_db",
        schema_="test_schema",
        alias="test_model",
    )

    # After conversion, dbt_name must carry the canonical dbt node name.
    converted = config.to_sqlmesh(dbt_context)
    assert converted.dbt_name == "model.test_package.test_model"


@pytest.mark.slow
def test_load_model_dbt_node_name(tmp_path: Path) -> None:
    """Loading a minimal dbt project should populate dbt_name on its model."""
    yaml = YAML()

    project_dir = tmp_path / "dbt"
    models_dir = project_dir / "models"
    models_dir.mkdir(parents=True)

    # One trivial model file.
    (models_dir / "simple_model.sql").write_text(
        "SELECT 1 as id, 'test' as name", encoding="utf-8"
    )

    # Minimal dbt project definition.
    project_config = {
        "name": "test_project",
        "version": "1.0.0",
        "config-version": 2,
        "profile": "test",
        "model-paths": ["models"],
    }
    with (project_dir / "dbt_project.yml").open("w", encoding="utf-8") as handle:
        yaml.dump(project_config, handle)

    # Minimal SQLMesh overlay config.
    with (project_dir / "sqlmesh.yaml").open("w", encoding="utf-8") as handle:
        yaml.dump({"model_defaults": {"start": "2025-01-01"}}, handle)

    # DuckDB-backed dbt profile pointing at a throwaway database file.
    data_dir = tmp_path / "dbt_data"
    data_dir.mkdir()
    profile_config = {
        "test": {
            "outputs": {"duckdb": {"type": "duckdb", "path": str(data_dir / "local.db")}},
            "target": "duckdb",
        }
    }
    with (project_dir / "profiles.yml").open("w", encoding="utf-8") as handle:
        yaml.dump(profile_config, handle)

    context = Context(paths=project_dir)

    # Locate the model by its fully qualified SQLMesh name.
    model_fqn = '"local"."main"."simple_model"'
    assert model_fqn in context.snapshots

    # Its dbt_name must be the equivalent dbt node name.
    model = context.snapshots[model_fqn].model
    assert model.dbt_name == "model.test_project.simple_model"
82 changes: 82 additions & 0 deletions tests/dbt/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json
from sqlmesh.dbt.builtin import Config, _relation_info_to_relation
from sqlmesh.dbt.common import Dependencies
from sqlmesh.dbt.column import (
ColumnConfig,
column_descriptions_to_sqlmesh,
Expand Down Expand Up @@ -2375,3 +2376,84 @@ def test_dynamic_var_names_in_macro(sushi_test_project: Project):
)
converted_model = model_config.to_sqlmesh(context)
assert "dynamic_test_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore


def test_selected_resources_with_selectors():
    """Plan.selected_models should reflect the `select_models` expressions."""
    sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"])

    # A specific selection resolves to exactly that model.
    # Guard against None before len() — selected_models is Optional, and the
    # other sections of this test already check it; keep them consistent.
    plan = sushi_context.plan_builder(select_models=["sushi.customers"]).build()
    assert plan.selected_models is not None
    assert len(plan.selected_models) == 1
    selected_model = next(iter(plan.selected_models))
    assert "customers" in selected_model

    # A plan without an explicit selection includes all models.
    plan = sushi_context.plan_builder().build()
    assert plan.selected_models is not None
    assert len(plan.selected_models) > 10

    # The downstream selector ("+") picks customers plus at least one dependant.
    plan = sushi_context.plan_builder(select_models=["sushi.customers+"]).build()
    assert plan.selected_models is not None
    assert len(plan.selected_models) >= 2
    assert any("customers" in model for model in plan.selected_models)

    # Wildcard selection matches every waiter_* model.
    plan = sushi_context.plan_builder(select_models=["sushi.waiter_*"]).build()
    assert plan.selected_models is not None
    assert len(plan.selected_models) >= 4
    assert all("waiter" in model for model in plan.selected_models)


@pytest.mark.xdist_group("dbt_manifest")
def test_selected_resources_context_variable(
    sushi_test_project: Project, sushi_test_dbt_context: Context
):
    """The dbt `selected_resources` jinja variable should render from jinja globals."""
    context = sushi_test_project.context

    # With no selection, selected_resources renders as an empty list.
    direct_access = context.render("{{ selected_resources }}")
    assert direct_access == "[]"

    # selected_resources is iterable; counting its items yields zero.
    test_jinja = """
    {%- set resources = [] -%}
    {%- for resource in selected_resources -%}
    {%- do resources.append(resource) -%}
    {%- endfor -%}
    {{ resources | length }}
    """
    result = context.render(test_jinja)
    assert result.strip() == "0"

    # selected_resources is usable in boolean conditions.
    test_condition = """
    {%- if selected_resources -%}
    has_resources
    {%- else -%}
    no_resources
    {%- endif -%}
    """
    result = context.render(test_condition)
    assert result.strip() == "no_resources"

    # Selected resources in dbt's node-name format.
    selected_resources = [
        "model.jaffle_shop.customers",
        "model.jaffle_shop.items",
        "model.jaffle_shop.orders",
    ]

    # Rendering echoes the list — use the repr() builtin, not a dunder call.
    result = context.render("{{ selected_resources }}", selected_resources=selected_resources)
    assert result == repr(selected_resources)

    result = context.render(test_jinja, selected_resources=selected_resources)
    assert result.strip() == "3"

    result = context.render(test_condition, selected_resources=selected_resources)
    assert result.strip() == "has_resources"