Skip to content

Commit f5ce33c

Browse files
committed
Feat(sqlmesh_dbt): Select based on dbt name, not sqlmesh name
1 parent 34dc9fd commit f5ce33c

File tree

8 files changed

+259
-17
lines changed

8 files changed

+259
-17
lines changed

sqlmesh/core/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ class GenericContext(BaseContext, t.Generic[C]):
348348
load: Whether or not to automatically load all models and macros (default True).
349349
console: The rich instance used for printing out CLI command results.
350350
users: A list of users to make known to SQLMesh.
351+
dbt_mode: A flag to indicate we are running in 'dbt mode' which means that things like
352+
model selections should use the dbt names and not the native SQLMesh names
351353
"""
352354

353355
CONFIG_TYPE: t.Type[C]
@@ -368,6 +370,7 @@ def __init__(
368370
load: bool = True,
369371
users: t.Optional[t.List[User]] = None,
370372
config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
373+
dbt_mode: bool = False,
371374
):
372375
self.configs = (
373376
config
@@ -390,6 +393,7 @@ def __init__(
390393
self._engine_adapter: t.Optional[EngineAdapter] = None
391394
self._linters: t.Dict[str, Linter] = {}
392395
self._loaded: bool = False
396+
self._dbt_mode = dbt_mode
393397

394398
self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))
395399

@@ -2901,6 +2905,7 @@ def _new_selector(
29012905
default_catalog=self.default_catalog,
29022906
dialect=self.default_dialect,
29032907
cache_dir=self.cache_dir,
2908+
dbt_mode=self._dbt_mode,
29042909
)
29052910

29062911
def _register_notification_targets(self) -> None:

sqlmesh/core/selector.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import fnmatch
44
import typing as t
55
from pathlib import Path
6+
from itertools import zip_longest
67

78
from sqlglot import exp
89
from sqlglot.errors import ParseError
@@ -36,6 +37,7 @@ def __init__(
3637
default_catalog: t.Optional[str] = None,
3738
dialect: t.Optional[str] = None,
3839
cache_dir: t.Optional[Path] = None,
40+
dbt_mode: bool = False,
3941
):
4042
self._state_reader = state_reader
4143
self._models = models
@@ -44,6 +46,7 @@ def __init__(
4446
self._default_catalog = default_catalog
4547
self._dialect = dialect
4648
self._git_client = GitClient(context_path)
49+
self._dbt_mode = dbt_mode
4750

4851
if dag is None:
4952
self._dag: DAG[str] = DAG()
@@ -167,13 +170,13 @@ def get_model(fqn: str) -> t.Optional[Model]:
167170
def expand_model_selections(
168171
self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Model]] = None
169172
) -> t.Set[str]:
170-
"""Expands a set of model selections into a set of model names.
173+
"""Expands a set of model selections into a set of model fqns that can be looked up in the Context.
171174
172175
Args:
173176
model_selections: A set of model selections.
174177
175178
Returns:
176-
A set of model names.
179+
A set of model fqns.
177180
"""
178181

179182
node = parse(" | ".join(f"({s})" for s in model_selections))
@@ -194,10 +197,9 @@ def evaluate(node: exp.Expression) -> t.Set[str]:
194197
return {
195198
fqn
196199
for fqn, model in all_models.items()
197-
if fnmatch.fnmatchcase(model.name, node.this)
200+
if fnmatch.fnmatchcase(self._model_name(model), node.this)
198201
}
199-
fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
200-
return {fqn} if fqn in all_models else set()
202+
return self._pattern_to_model_fqns(pattern, all_models)
201203
if isinstance(node, exp.And):
202204
return evaluate(node.left) & evaluate(node.right)
203205
if isinstance(node, exp.Or):
@@ -241,6 +243,59 @@ def evaluate(node: exp.Expression) -> t.Set[str]:
241243

242244
return evaluate(node)
243245

246+
def _model_fqn(self, model: Model) -> str:
    """Return the fully qualified name used to identify `model` in selections.

    Args:
        model: The model whose identifying fqn is wanted.

    Returns:
        `model.fqn` when the selector is not in dbt mode, otherwise `model.dbt_fqn`.

    Raises:
        SQLMeshError: If the selector is in dbt mode and `model.dbt_fqn` is None.
    """
    if not self._dbt_mode:
        return model.fqn

    resolved = model.dbt_fqn
    if resolved is not None:
        return resolved

    # dbt mode relies on every model carrying its dbt fqn; a missing one cannot be selected on
    raise SQLMeshError("Expecting dbt node information to be populated; it wasnt")
253+
254+
def _model_name(self, model: Model) -> str:
    """Return the name that selection patterns are matched against for `model`.

    Args:
        model: The model being tested against a selection pattern.

    Returns:
        The result of `self._model_fqn(model)` in dbt mode, otherwise `model.name`.
    """
    # dbt always matches on the fqn, not the name
    return self._model_fqn(model) if self._dbt_mode else model.name
259+
260+
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Model]) -> t.Set[str]:
    """Resolve a non-wildcard selection pattern to the SQLMesh fqns it selects.

    Outside dbt mode the pattern is normalized into a single SQLMesh fqn and is
    selected only if that fqn is a key of `all_models`.

    In dbt mode the pattern is split on "." and compared against each model's
    dbt fqn component by component. A model is selected when the pattern's
    components appear as a contiguous run anywhere in the dbt fqn.

    Args:
        pattern: The selection pattern, without wildcards.
        all_models: The candidate models. Should be keyed by SQLMesh fqn, not dbt fqn.

    Returns:
        The keys of `all_models` whose models are selected by `pattern`.
    """
    if not self._dbt_mode:
        fqn = normalize_model_name(pattern, self._default_catalog, self._dialect)
        return {fqn} if fqn in all_models else set()

    # a pattern like "staging.customers" should match a model called "jaffle_shop.staging.customers"
    # but not a model called "jaffle_shop.customers.staging"
    # also a pattern like "aging" should not match "staging" so we need to consider components; not substrings
    pattern_components = pattern.split(".")
    pattern_len = len(pattern_components)

    matches: t.Set[str] = set()
    for fqn, model in all_models.items():
        if not model.dbt_fqn:
            continue

        dbt_fqn_components = model.dbt_fqn.split(".")

        # Try every possible starting position, not just the first occurrence of the
        # leading pattern component: "jaffle_shop.customers" has to match
        # "jaffle_shop.jaffle_shop.customers", where only the second occurrence lines up.
        last_start = len(dbt_fqn_components) - pattern_len
        if any(
            dbt_fqn_components[start : start + pattern_len] == pattern_components
            for start in range(last_start + 1)
        ):
            matches.add(fqn)

    return matches
298+
244299

245300
class SelectorDialect(Dialect):
246301
IDENTIFIERS_CAN_START_WITH_DIGIT = True

sqlmesh_dbt/operations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _plan_builder_options(
185185
options.update(
186186
dict(
187187
# Add every selected model as a restatement to force them to get repopulated from scratch
188-
restate_models=list(self.context.models)
188+
restate_models=[m.dbt_fqn for m in self.context.models.values() if m.dbt_fqn]
189189
if not select_models
190190
else select_models,
191191
# by default in SQLMesh, restatements only operate on what has been committed to state.
@@ -250,6 +250,8 @@ def create(
250250
paths=[project_dir],
251251
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
252252
load=True,
253+
# dbt mode enables selectors to use dbt model fqn's rather than SQLMesh model names
254+
dbt_mode=True,
253255
)
254256

255257
dbt_loader = sqlmesh_context._loaders[0]

tests/dbt/cli/test_list.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
1919

2020

2121
def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
22-
result = invoke_cli(["list", "--select", "main.raw_customers+"])
22+
result = invoke_cli(["list", "--select", "raw_customers+"])
2323

2424
assert result.exit_code == 0
2525
assert not result.exception
@@ -34,7 +34,7 @@ def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Resul
3434

3535
def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
3636
# single exclude
37-
result = invoke_cli(["list", "--select", "main.raw_customers+", "--exclude", "main.orders"])
37+
result = invoke_cli(["list", "--select", "raw_customers+", "--exclude", "orders"])
3838

3939
assert result.exit_code == 0
4040
assert not result.exception
@@ -49,8 +49,8 @@ def test_list_select_exclude(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..
4949

5050
# multiple exclude
5151
for args in (
52-
["--select", "main.stg_orders+", "--exclude", "main.customers", "--exclude", "main.orders"],
53-
["--select", "main.stg_orders+", "--exclude", "main.customers main.orders"],
52+
["--select", "stg_orders+", "--exclude", "customers", "--exclude", "orders"],
53+
["--select", "stg_orders+", "--exclude", "customers orders"],
5454
):
5555
result = invoke_cli(["list", *args])
5656
assert result.exit_code == 0

tests/dbt/cli/test_operations.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
138138
assert plan.selected_models_to_backfill is None
139139
assert {s.name for s in plan.snapshots} == {k for k in operations.context.snapshots}
140140

141-
plan = operations.run(select=["main.stg_orders+"])
141+
plan = operations.run(select=["stg_orders+"])
142142
assert plan.environment.name == "prod"
143143
assert console.no_prompts is True
144144
assert console.no_diff is True
@@ -155,7 +155,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
155155
plan.selected_models_to_backfill | {standalone_audit_name}
156156
)
157157

158-
plan = operations.run(select=["main.stg_orders+"], exclude=["main.customers"])
158+
plan = operations.run(select=["stg_orders+"], exclude=["customers"])
159159
assert plan.environment.name == "prod"
160160
assert console.no_prompts is True
161161
assert console.no_diff is True
@@ -171,7 +171,7 @@ def test_run_option_mapping(jaffle_shop_duckdb: Path):
171171
plan.selected_models_to_backfill | {standalone_audit_name}
172172
)
173173

174-
plan = operations.run(exclude=["main.customers"])
174+
plan = operations.run(exclude=["customers"])
175175
assert plan.environment.name == "prod"
176176
assert console.no_prompts is True
177177
assert console.no_diff is True
@@ -238,7 +238,7 @@ def test_run_option_mapping_dev(jaffle_shop_duckdb: Path):
238238
assert plan.skip_backfill is True
239239
assert plan.selected_models_to_backfill == {'"jaffle_shop"."main"."new_model"'}
240240

241-
plan = operations.run(environment="dev", select=["main.stg_orders+"])
241+
plan = operations.run(environment="dev", select=["stg_orders+"])
242242
assert plan.environment.name == "dev"
243243
assert console.no_prompts is True
244244
assert console.no_diff is True
@@ -325,7 +325,7 @@ def test_run_option_full_refresh_with_selector(jaffle_shop_duckdb: Path):
325325
console = PlanCapturingConsole()
326326
operations.context.console = console
327327

328-
plan = operations.run(select=["main.stg_customers"], full_refresh=True)
328+
plan = operations.run(select=["stg_customers"], full_refresh=True)
329329
assert len(plan.restatements) == 1
330330
assert list(plan.restatements)[0].name == '"jaffle_shop"."main"."stg_customers"'
331331

tests/dbt/cli/test_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_run_with_selectors(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[...
2727
assert result.exit_code == 0
2828
assert "main.orders" in result.output
2929

30-
result = invoke_cli(["run", "--select", "main.raw_customers+", "--exclude", "main.orders"])
30+
result = invoke_cli(["run", "--select", "raw_customers+", "--exclude", "orders"])
3131

3232
assert result.exit_code == 0
3333
assert not result.exception

0 commit comments

Comments
 (0)