Skip to content

Commit 1e6cf07

Browse files
committed
PR feedback
1 parent 58ba2fc commit 1e6cf07

File tree

6 files changed

+61
-40
lines changed

6 files changed

+61
-40
lines changed

sqlmesh/core/context.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
from sqlmesh.core.reference import ReferenceGraph
9494
from sqlmesh.core.scheduler import Scheduler, CompletionStatus
9595
from sqlmesh.core.schema_loader import create_external_models_file
96-
from sqlmesh.core.selector import Selector
96+
from sqlmesh.core.selector import Selector, NativeSelector
9797
from sqlmesh.core.snapshot import (
9898
DeployabilityIndex,
9999
Snapshot,
@@ -348,8 +348,6 @@ class GenericContext(BaseContext, t.Generic[C]):
348348
load: Whether or not to automatically load all models and macros (default True).
349349
console: The rich instance used for printing out CLI command results.
350350
users: A list of users to make known to SQLMesh.
351-
dbt_mode: A flag to indicate we are running in 'dbt mode' which means that things like
352-
model selections should use the dbt names and not the native SQLMesh names
353351
"""
354352

355353
CONFIG_TYPE: t.Type[C]
@@ -370,7 +368,7 @@ def __init__(
370368
load: bool = True,
371369
users: t.Optional[t.List[User]] = None,
372370
config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
373-
dbt_mode: bool = False,
371+
selector: t.Optional[t.Type[Selector]] = None,
374372
):
375373
self.configs = (
376374
config
@@ -393,7 +391,7 @@ def __init__(
393391
self._engine_adapter: t.Optional[EngineAdapter] = None
394392
self._linters: t.Dict[str, Linter] = {}
395393
self._loaded: bool = False
396-
self._dbt_mode = dbt_mode
394+
self._selector_cls = selector or NativeSelector
397395

398396
self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))
399397

@@ -2897,15 +2895,14 @@ def _new_state_sync(self) -> StateSync:
28972895
def _new_selector(
28982896
self, models: t.Optional[UniqueKeyDict[str, Model]] = None, dag: t.Optional[DAG[str]] = None
28992897
) -> Selector:
2900-
return Selector(
2898+
return self._selector_cls(
29012899
self.state_reader,
29022900
models=models or self._models,
29032901
context_path=self.path,
29042902
dag=dag,
29052903
default_catalog=self.default_catalog,
29062904
dialect=self.default_dialect,
29072905
cache_dir=self.cache_dir,
2908-
dbt_mode=self._dbt_mode,
29092906
)
29102907

29112908
def _register_notification_targets(self) -> None:

sqlmesh/core/selector.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import typing as t
55
from pathlib import Path
66
from itertools import zip_longest
7+
import abc
78

89
from sqlglot import exp
910
from sqlglot.errors import ParseError
@@ -27,7 +28,7 @@
2728
from sqlmesh.core.state_sync import StateReader
2829

2930

30-
class Selector:
31+
class Selector(abc.ABC):
3132
def __init__(
3233
self,
3334
state_reader: StateReader,
@@ -37,7 +38,6 @@ def __init__(
3738
default_catalog: t.Optional[str] = None,
3839
dialect: t.Optional[str] = None,
3940
cache_dir: t.Optional[Path] = None,
40-
dbt_mode: bool = False,
4141
):
4242
self._state_reader = state_reader
4343
self._models = models
@@ -46,7 +46,6 @@ def __init__(
4646
self._default_catalog = default_catalog
4747
self._dialect = dialect
4848
self._git_client = GitClient(context_path)
49-
self._dbt_mode = dbt_mode
5049

5150
if dag is None:
5251
self._dag: DAG[str] = DAG()
@@ -243,26 +242,37 @@ def evaluate(node: exp.Expression) -> t.Set[str]:
243242

244243
return evaluate(node)
245244

246-
def _model_fqn(self, model: Model) -> str:
247-
if self._dbt_mode:
248-
dbt_fqn = model.dbt_fqn
249-
if dbt_fqn is None:
250-
raise SQLMeshError("Expecting dbt node information to be populated; it wasnt")
251-
return dbt_fqn
252-
return model.fqn
245+
@abc.abstractmethod
246+
def _model_name(self, model: Model) -> str:
247+
"""Given a model, return the name that a selector pattern contining wildcards should be fnmatch'd on"""
248+
pass
249+
250+
@abc.abstractmethod
251+
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
252+
"""Given a pattern, return the keys of the matching models from :all_models"""
253+
pass
254+
255+
256+
class NativeSelector(Selector):
257+
"""Implementation of selectors that matches objects based on SQLMesh native names"""
253258

254259
def _model_name(self, model: Model) -> str:
255-
if self._dbt_mode:
256-
# dbt always matches on the fqn, not the name
257-
return self._model_fqn(model)
258260
return model.name
259261

260262
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
261-
# note: all_models should be keyed by sqlmesh fqn, not dbt fqn
262-
if not self._dbt_mode:
263-
fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
264-
return {fqn} if fqn in all_models else set()
263+
fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
264+
return {fqn} if fqn in all_models else set()
265265

266+
267+
class DbtSelector(Selector):
268+
"""Implementation of selectors that matches objects based on the DBT names instead of the SQLMesh native names"""
269+
270+
def _model_name(self, model: Model) -> str:
271+
if dbt_fqn := model.dbt_fqn:
272+
return dbt_fqn
273+
raise SQLMeshError("dbt node information must be populated to use dbt selectors")
274+
275+
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
266276
# a pattern like "staging.customers" should match a model called "jaffle_shop.staging.customers"
267277
# but not a model called "jaffle_shop.customers.staging"
268278
# also a pattern like "aging" should not match "staging" so we need to consider components; not substrings

sqlmesh_dbt/operations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def create(
231231
from sqlmesh.core.console import set_console
232232
from sqlmesh_dbt.console import DbtCliConsole
233233
from sqlmesh.utils.errors import SQLMeshError
234+
from sqlmesh.core.selector import DbtSelector
234235

235236
# clear any existing handlers set up by click/rich as defaults so that once SQLMesh logging config is applied,
236237
# we dont get duplicate messages logged from things like console.log_warning()
@@ -250,8 +251,8 @@ def create(
250251
paths=[project_dir],
251252
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
252253
load=True,
253-
# dbt mode enables selectors to use dbt model fqn's rather than SQLMesh model names
254-
dbt_mode=True,
254+
# DbtSelector selects based on dbt model fqn's rather than SQLMesh model names
255+
selector=DbtSelector,
255256
)
256257

257258
dbt_loader = sqlmesh_context._loaders[0]

tests/core/test_selector.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sqlmesh.core.environment import Environment
1313
from sqlmesh.core.model import Model, SqlModel
1414
from sqlmesh.core.model.common import ParsableSql
15-
from sqlmesh.core.selector import Selector
15+
from sqlmesh.core.selector import NativeSelector
1616
from sqlmesh.core.snapshot import SnapshotChangeCategory
1717
from sqlmesh.utils import UniqueKeyDict
1818
from sqlmesh.utils.date import now_timestamp
@@ -88,7 +88,7 @@ def test_select_models(mocker: MockerFixture, make_snapshot, default_catalog: t.
8888
local_models[modified_model_v2.fqn] = modified_model_v2.copy(
8989
update={"mapping_schema": added_model_schema}
9090
)
91-
selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog)
91+
selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog)
9292

9393
_assert_models_equal(
9494
selector.select_models(["db.added_model"], env_name),
@@ -243,7 +243,7 @@ def test_select_models_expired_environment(mocker: MockerFixture, make_snapshot)
243243

244244
local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
245245
local_models[modified_model_v2.fqn] = modified_model_v2
246-
selector = Selector(state_reader_mock, local_models)
246+
selector = NativeSelector(state_reader_mock, local_models)
247247

248248
_assert_models_equal(
249249
selector.select_models(["*.modified_model"], env_name, fallback_env_name="prod"),
@@ -305,7 +305,7 @@ def test_select_change_schema(mocker: MockerFixture, make_snapshot):
305305
local_child = child.copy(update={"mapping_schema": {'"db"': {'"parent"': {"b": "INT"}}}})
306306
local_models[local_child.fqn] = local_child
307307

308-
selector = Selector(state_reader_mock, local_models)
308+
selector = NativeSelector(state_reader_mock, local_models)
309309

310310
selected = selector.select_models(["db.parent"], env_name)
311311
assert selected[local_child.fqn].render_query() != child.render_query()
@@ -339,7 +339,7 @@ def test_select_models_missing_env(mocker: MockerFixture, make_snapshot):
339339
local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
340340
local_models[model.fqn] = model
341341

342-
selector = Selector(state_reader_mock, local_models)
342+
selector = NativeSelector(state_reader_mock, local_models)
343343

344344
assert selector.select_models([model.name], "missing_env").keys() == {model.fqn}
345345
assert not selector.select_models(["missing"], "missing_env")
@@ -563,7 +563,7 @@ def test_expand_model_selections(
563563
)
564564
models[model.fqn] = model
565565

566-
selector = Selector(mocker.Mock(), models)
566+
selector = NativeSelector(mocker.Mock(), models)
567567
assert selector.expand_model_selections(selections) == output
568568

569569

@@ -576,7 +576,7 @@ def test_model_selection_normalized(mocker: MockerFixture, make_snapshot):
576576
dialect="bigquery",
577577
)
578578
models[model.fqn] = model
579-
selector = Selector(mocker.Mock(), models, dialect="bigquery")
579+
selector = NativeSelector(mocker.Mock(), models, dialect="bigquery")
580580
assert selector.expand_model_selections(["db.test_Model"]) == {'"db"."test_Model"'}
581581

582582

@@ -624,7 +624,7 @@ def test_expand_git_selection(
624624
git_client_mock.list_uncommitted_changed_files.return_value = []
625625
git_client_mock.list_committed_changed_files.return_value = [model_a._path, model_c._path]
626626

627-
selector = Selector(mocker.Mock(), models)
627+
selector = NativeSelector(mocker.Mock(), models)
628628
selector._git_client = git_client_mock
629629

630630
assert selector.expand_model_selections(expressions) == expected_fqns
@@ -658,7 +658,7 @@ def test_select_models_with_external_parent(mocker: MockerFixture):
658658
local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
659659
local_models[added_model.fqn] = added_model
660660

661-
selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog)
661+
selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog)
662662

663663
expanded_selections = selector.expand_model_selections(["+*added_model*"])
664664
assert expanded_selections == {added_model.fqn}
@@ -699,7 +699,7 @@ def test_select_models_local_tags_take_precedence_over_remote(
699699
local_models[local_existing.fqn] = local_existing
700700
local_models[local_new.fqn] = local_new
701701

702-
selector = Selector(state_reader_mock, local_models)
702+
selector = NativeSelector(state_reader_mock, local_models)
703703

704704
selected = selector.select_models(["tag:a"], env_name)
705705

tests/dbt/cli/test_selectors.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing as t
22
import pytest
33
from sqlmesh_dbt import selectors
4+
from sqlmesh.core.selector import DbtSelector
45
from sqlmesh.core.context import Context
56
from pathlib import Path
67

@@ -112,6 +113,7 @@ def test_split_unions_and_intersections(
112113
["customers+", "stg_orders"],
113114
{'"jaffle_shop"."main"."customers"', '"jaffle_shop"."main"."stg_orders"'},
114115
),
116+
(["*.staging.stg_c*"], {'"jaffle_shop"."main"."stg_customers"'}),
115117
(["tag:agg"], {'"jaffle_shop"."main"."agg_orders"'}),
116118
(
117119
["staging.stg_customers", "tag:agg"],
@@ -137,6 +139,16 @@ def test_split_unions_and_intersections(
137139
'"jaffle_shop"."main"."agg_orders"',
138140
},
139141
),
142+
(
143+
["tag:b*"],
144+
set(),
145+
),
146+
(
147+
["tag:a*"],
148+
{
149+
'"jaffle_shop"."main"."agg_orders"',
150+
},
151+
),
140152
],
141153
)
142154
def test_select_by_dbt_names(
@@ -155,7 +167,7 @@ def test_select_by_dbt_names(
155167
assert '"jaffle_shop"."main"."agg_orders"' in ctx.models
156168

157169
selector = ctx._new_selector()
158-
assert selector._dbt_mode
170+
assert isinstance(selector, DbtSelector)
159171

160172
sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=[])
161173
assert sqlmesh_selector
@@ -205,7 +217,7 @@ def test_exclude_by_dbt_names(
205217
assert '"jaffle_shop"."main"."agg_orders"' in ctx.models
206218

207219
selector = ctx._new_selector()
208-
assert selector._dbt_mode
220+
assert isinstance(selector, DbtSelector)
209221

210222
sqlmesh_selector = selectors.to_sqlmesh(dbt_select=[], dbt_exclude=dbt_exclude)
211223
assert sqlmesh_selector
@@ -251,7 +263,7 @@ def test_selection_and_exclusion_by_dbt_names(
251263
assert '"jaffle_shop"."main"."agg_orders"' in ctx.models
252264

253265
selector = ctx._new_selector()
254-
assert selector._dbt_mode
266+
assert isinstance(selector, DbtSelector)
255267

256268
sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=dbt_exclude)
257269
assert sqlmesh_selector

tests/dbt/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from sqlmesh.core.context import Context
10+
from sqlmesh.core.selector import DbtSelector
1011
from sqlmesh.dbt.context import DbtContext
1112
from sqlmesh.dbt.project import Project
1213
from sqlmesh.dbt.target import PostgresConfig
@@ -99,7 +100,7 @@ def jaffle_shop_duckdb(copy_to_temp_path: t.Callable[..., t.List[Path]]) -> t.It
99100
@pytest.fixture
100101
def jaffle_shop_duckdb_context(jaffle_shop_duckdb: Path) -> Context:
101102
init_project_if_required(jaffle_shop_duckdb)
102-
return Context(paths=[jaffle_shop_duckdb], dbt_mode=True)
103+
return Context(paths=[jaffle_shop_duckdb], selector=DbtSelector)
103104

104105

105106
@pytest.fixture()

0 commit comments

Comments
 (0)