Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
from sqlmesh.core.reference import ReferenceGraph
from sqlmesh.core.scheduler import Scheduler, CompletionStatus
from sqlmesh.core.schema_loader import create_external_models_file
from sqlmesh.core.selector import Selector
from sqlmesh.core.selector import Selector, NativeSelector
from sqlmesh.core.snapshot import (
DeployabilityIndex,
Snapshot,
Expand Down Expand Up @@ -368,6 +368,7 @@ def __init__(
load: bool = True,
users: t.Optional[t.List[User]] = None,
config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
selector: t.Optional[t.Type[Selector]] = None,
):
self.configs = (
config
Expand All @@ -390,6 +391,7 @@ def __init__(
self._engine_adapter: t.Optional[EngineAdapter] = None
self._linters: t.Dict[str, Linter] = {}
self._loaded: bool = False
self._selector_cls = selector or NativeSelector

self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))

Expand Down Expand Up @@ -2893,7 +2895,7 @@ def _new_state_sync(self) -> StateSync:
def _new_selector(
self, models: t.Optional[UniqueKeyDict[str, Model]] = None, dag: t.Optional[DAG[str]] = None
) -> Selector:
return Selector(
return self._selector_cls(
self.state_reader,
models=models or self._models,
context_path=self.path,
Expand Down
77 changes: 71 additions & 6 deletions sqlmesh/core/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import fnmatch
import typing as t
from pathlib import Path
from itertools import zip_longest
import abc

from sqlglot import exp
from sqlglot.errors import ParseError
Expand All @@ -26,7 +28,7 @@
from sqlmesh.core.state_sync import StateReader


class Selector:
class Selector(abc.ABC):
def __init__(
self,
state_reader: StateReader,
Expand Down Expand Up @@ -167,13 +169,13 @@ def get_model(fqn: str) -> t.Optional[Model]:
def expand_model_selections(
self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Model]] = None
) -> t.Set[str]:
"""Expands a set of model selections into a set of model names.
"""Expands a set of model selections into a set of model fqns that can be looked up in the Context.

Args:
model_selections: A set of model selections.

Returns:
A set of model names.
A set of model fqns.
"""

node = parse(" | ".join(f"({s})" for s in model_selections))
Expand All @@ -194,10 +196,9 @@ def evaluate(node: exp.Expression) -> t.Set[str]:
return {
fqn
for fqn, model in all_models.items()
if fnmatch.fnmatchcase(model.name, node.this)
if fnmatch.fnmatchcase(self._model_name(model), node.this)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the wildcard matching happens outside _pattern_to_model_fqns which seems straightforward and should work, but could you add a test for this as well to ensure nothing breaks in the future if someone modifies this? unless if I missed it and there is a wildcard test for dbt model names already

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well spotted, I added some tests to exercise this codepath based on the dbt docs.

Note that the patterns currently need a * in them to trigger this branch (even though ?, [ and ] could be used as well) so this will need more work in future if we find people using these in the wild

}
fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
return {fqn} if fqn in all_models else set()
return self._pattern_to_model_fqns(pattern, all_models)
if isinstance(node, exp.And):
return evaluate(node.left) & evaluate(node.right)
if isinstance(node, exp.Or):
Expand Down Expand Up @@ -241,6 +242,70 @@ def evaluate(node: exp.Expression) -> t.Set[str]:

return evaluate(node)

@abc.abstractmethod
def _model_name(self, model: Model) -> str:
"""Given a model, return the name that a selector pattern contining wildcards should be fnmatch'd on"""
pass

@abc.abstractmethod
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
"""Given a pattern, return the keys of the matching models from :all_models"""
pass


class NativeSelector(Selector):
"""Implementation of selectors that matches objects based on SQLMesh native names"""

def _model_name(self, model: Model) -> str:
return model.name

def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
return {fqn} if fqn in all_models else set()


class DbtSelector(Selector):
"""Implementation of selectors that matches objects based on the DBT names instead of the SQLMesh native names"""

def _model_name(self, model: Model) -> str:
if dbt_fqn := model.dbt_fqn:
return dbt_fqn
raise SQLMeshError("dbt node information must be populated to use dbt selectors")

def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
# a pattern like "staging.customers" should match a model called "jaffle_shop.staging.customers"
# but not a model called "jaffle_shop.customers.staging"
# also a pattern like "aging" should not match "staging" so we need to consider components; not substrings
pattern_components = pattern.split(".")
first_pattern_component = pattern_components[0]
matches = set()
for fqn, model in all_models.items():
if not model.dbt_fqn:
continue

dbt_fqn_components = model.dbt_fqn.split(".")
try:
starting_idx = dbt_fqn_components.index(first_pattern_component)
except ValueError:
continue
for pattern_component, fqn_component in zip_longest(
pattern_components, dbt_fqn_components[starting_idx:]
):
if pattern_component and not fqn_component:
# the pattern still goes but we have run out of fqn components to match; no match
break
if fqn_component and not pattern_component:
# all elements of the pattern have matched elements of the fqn; match
matches.add(fqn)
break
if pattern_component != fqn_component:
# the pattern explicitly doesnt match a component; no match
break
else:
# called if no explicit break, indicating all components of the pattern matched all components of the fqn
matches.add(fqn)
return matches


class SelectorDialect(Dialect):
IDENTIFIERS_CAN_START_WITH_DIGIT = True
Expand Down
5 changes: 4 additions & 1 deletion sqlmesh_dbt/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _plan_builder_options(
options.update(
dict(
# Add every selected model as a restatement to force them to get repopulated from scratch
restate_models=list(self.context.models)
restate_models=[m.dbt_fqn for m in self.context.models.values() if m.dbt_fqn]
if not select_models
else select_models,
# by default in SQLMesh, restatements only operate on what has been committed to state.
Expand Down Expand Up @@ -231,6 +231,7 @@ def create(
from sqlmesh.core.console import set_console
from sqlmesh_dbt.console import DbtCliConsole
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.core.selector import DbtSelector

# clear any existing handlers set up by click/rich as defaults so that once SQLMesh logging config is applied,
# we dont get duplicate messages logged from things like console.log_warning()
Expand All @@ -250,6 +251,8 @@ def create(
paths=[project_dir],
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
load=True,
# DbtSelector selects based on dbt model fqn's rather than SQLMesh model names
selector=DbtSelector,
)

dbt_loader = sqlmesh_context._loaders[0]
Expand Down
20 changes: 10 additions & 10 deletions tests/core/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlmesh.core.environment import Environment
from sqlmesh.core.model import Model, SqlModel
from sqlmesh.core.model.common import ParsableSql
from sqlmesh.core.selector import Selector
from sqlmesh.core.selector import NativeSelector
from sqlmesh.core.snapshot import SnapshotChangeCategory
from sqlmesh.utils import UniqueKeyDict
from sqlmesh.utils.date import now_timestamp
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_select_models(mocker: MockerFixture, make_snapshot, default_catalog: t.
local_models[modified_model_v2.fqn] = modified_model_v2.copy(
update={"mapping_schema": added_model_schema}
)
selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog)
selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog)

_assert_models_equal(
selector.select_models(["db.added_model"], env_name),
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_select_models_expired_environment(mocker: MockerFixture, make_snapshot)

local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
local_models[modified_model_v2.fqn] = modified_model_v2
selector = Selector(state_reader_mock, local_models)
selector = NativeSelector(state_reader_mock, local_models)

_assert_models_equal(
selector.select_models(["*.modified_model"], env_name, fallback_env_name="prod"),
Expand Down Expand Up @@ -305,7 +305,7 @@ def test_select_change_schema(mocker: MockerFixture, make_snapshot):
local_child = child.copy(update={"mapping_schema": {'"db"': {'"parent"': {"b": "INT"}}}})
local_models[local_child.fqn] = local_child

selector = Selector(state_reader_mock, local_models)
selector = NativeSelector(state_reader_mock, local_models)

selected = selector.select_models(["db.parent"], env_name)
assert selected[local_child.fqn].render_query() != child.render_query()
Expand Down Expand Up @@ -339,7 +339,7 @@ def test_select_models_missing_env(mocker: MockerFixture, make_snapshot):
local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
local_models[model.fqn] = model

selector = Selector(state_reader_mock, local_models)
selector = NativeSelector(state_reader_mock, local_models)

assert selector.select_models([model.name], "missing_env").keys() == {model.fqn}
assert not selector.select_models(["missing"], "missing_env")
Expand Down Expand Up @@ -563,7 +563,7 @@ def test_expand_model_selections(
)
models[model.fqn] = model

selector = Selector(mocker.Mock(), models)
selector = NativeSelector(mocker.Mock(), models)
assert selector.expand_model_selections(selections) == output


Expand All @@ -576,7 +576,7 @@ def test_model_selection_normalized(mocker: MockerFixture, make_snapshot):
dialect="bigquery",
)
models[model.fqn] = model
selector = Selector(mocker.Mock(), models, dialect="bigquery")
selector = NativeSelector(mocker.Mock(), models, dialect="bigquery")
assert selector.expand_model_selections(["db.test_Model"]) == {'"db"."test_Model"'}


Expand Down Expand Up @@ -624,7 +624,7 @@ def test_expand_git_selection(
git_client_mock.list_uncommitted_changed_files.return_value = []
git_client_mock.list_committed_changed_files.return_value = [model_a._path, model_c._path]

selector = Selector(mocker.Mock(), models)
selector = NativeSelector(mocker.Mock(), models)
selector._git_client = git_client_mock

assert selector.expand_model_selections(expressions) == expected_fqns
Expand Down Expand Up @@ -658,7 +658,7 @@ def test_select_models_with_external_parent(mocker: MockerFixture):
local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
local_models[added_model.fqn] = added_model

selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog)
selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog)

expanded_selections = selector.expand_model_selections(["+*added_model*"])
assert expanded_selections == {added_model.fqn}
Expand Down Expand Up @@ -699,7 +699,7 @@ def test_select_models_local_tags_take_precedence_over_remote(
local_models[local_existing.fqn] = local_existing
local_models[local_new.fqn] = local_new

selector = Selector(state_reader_mock, local_models)
selector = NativeSelector(state_reader_mock, local_models)

selected = selector.select_models(["tag:a"], env_name)

Expand Down
8 changes: 4 additions & 4 deletions tests/dbt/cli/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):


def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
result = invoke_cli(["list", "--select", "main.raw_customers+"])
result = invoke_cli(["list", "--select", "raw_customers+"])

assert result.exit_code == 0
assert not result.exception
Expand All @@ -34,7 +34,7 @@ def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Resul

def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
# single exclude
result = invoke_cli(["list", "--select", "main.raw_customers+", "--exclude", "main.orders"])
result = invoke_cli(["list", "--select", "raw_customers+", "--exclude", "orders"])

assert result.exit_code == 0
assert not result.exception
Expand All @@ -49,8 +49,8 @@ def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..

# multiple exclude
for args in (
["--select", "main.stg_orders+", "--exclude", "main.customers", "--exclude", "main.orders"],
["--select", "main.stg_orders+", "--exclude", "main.customers main.orders"],
["--select", "stg_orders+", "--exclude", "customers", "--exclude", "orders"],
["--select", "stg_orders+", "--exclude", "customers orders"],
):
result = invoke_cli(["list", *args])
assert result.exit_code == 0
Expand Down
10 changes: 5 additions & 5 deletions tests/dbt/cli/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
assert plan.selected_models_to_backfill is None
assert {s.name for s in plan.snapshots} == {k for k in operations.context.snapshots}

plan = operations.run(select=["main.stg_orders+"])
plan = operations.run(select=["stg_orders+"])
assert plan.environment.name == "prod"
assert console.no_prompts is True
assert console.no_diff is True
Expand All @@ -155,7 +155,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
plan.selected_models_to_backfill | {standalone_audit_name}
)

plan = operations.run(select=["main.stg_orders+"], exclude=["main.customers"])
plan = operations.run(select=["stg_orders+"], exclude=["customers"])
assert plan.environment.name == "prod"
assert console.no_prompts is True
assert console.no_diff is True
Expand All @@ -171,7 +171,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
plan.selected_models_to_backfill | {standalone_audit_name}
)

plan = operations.run(exclude=["main.customers"])
plan = operations.run(exclude=["customers"])
assert plan.environment.name == "prod"
assert console.no_prompts is True
assert console.no_diff is True
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_run_option_mapping_dev(jaffle_shop_duckdb: Path):
assert plan.skip_backfill is True
assert plan.selected_models_to_backfill == {'"jaffle_shop"."main"."new_model"'}

plan = operations.run(environment="dev", select=["main.stg_orders+"])
plan = operations.run(environment="dev", select=["stg_orders+"])
assert plan.environment.name == "dev"
assert console.no_prompts is True
assert console.no_diff is True
Expand Down Expand Up @@ -325,7 +325,7 @@ def test_run_option_full_refresh_with_selector(jaffle_shop_duckdb: Path):
console = PlanCapturingConsole()
operations.context.console = console

plan = operations.run(select=["main.stg_customers"], full_refresh=True)
plan = operations.run(select=["stg_customers"], full_refresh=True)
assert len(plan.restatements) == 1
assert list(plan.restatements)[0].name == '"jaffle_shop"."main"."stg_customers"'

Expand Down
2 changes: 1 addition & 1 deletion tests/dbt/cli/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_run_with_selectors(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[...
assert result.exit_code == 0
assert "main.orders" in result.output

result = invoke_cli(["run", "--select", "main.raw_customers+", "--exclude", "main.orders"])
result = invoke_cli(["run", "--select", "raw_customers+", "--exclude", "orders"])

assert result.exit_code == 0
assert not result.exception
Expand Down
Loading