Skip to content

Commit 850fcd1

Browse files
authored
Feat(sqlmesh_dbt): Support the --threads CLI option (#5493)
1 parent d2dd2df commit 850fcd1

File tree

6 files changed

+59
-4
lines changed

6 files changed

+59
-4
lines changed

sqlmesh/core/config/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def load_config_from_paths(
177177
dbt_profile_name=kwargs.pop("profile", None),
178178
dbt_target_name=kwargs.pop("target", None),
179179
variables=variables,
180+
threads=kwargs.pop("threads", None),
180181
)
181182
if type(dbt_python_config) != config_type:
182183
dbt_python_config = convert_config_type(dbt_python_config, config_type)

sqlmesh/dbt/loader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def sqlmesh_config(
4949
dbt_profile_name: t.Optional[str] = None,
5050
dbt_target_name: t.Optional[str] = None,
5151
variables: t.Optional[t.Dict[str, t.Any]] = None,
52+
threads: t.Optional[int] = None,
5253
register_comments: t.Optional[bool] = None,
5354
**kwargs: t.Any,
5455
) -> Config:
@@ -67,6 +68,10 @@ def sqlmesh_config(
6768
if not issubclass(loader, DbtLoader):
6869
raise ConfigError("The loader must be a DbtLoader.")
6970

71+
if threads is not None:
72+
# the to_sqlmesh() function on TargetConfig maps self.threads -> concurrent_tasks
73+
profile.target.threads = threads
74+
7075
return Config(
7176
loader=loader,
7277
model_defaults=model_defaults,

sqlmesh_dbt/cli.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
import functools
99

1010

11-
def _get_dbt_operations(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]]) -> DbtOperations:
11+
def _get_dbt_operations(
12+
ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], threads: t.Optional[int] = None
13+
) -> DbtOperations:
1214
if not isinstance(ctx.obj, functools.partial):
1315
raise ValueError(f"Unexpected click context object: {type(ctx.obj)}")
1416

15-
dbt_operations = ctx.obj(vars=vars)
17+
dbt_operations = ctx.obj(vars=vars, threads=threads)
1618

1719
if not isinstance(dbt_operations, DbtOperations):
1820
raise ValueError(f"Unexpected dbt operations type: {type(dbt_operations)}")
@@ -128,16 +130,22 @@ def dbt(
128130
@click.option(
129131
"--empty/--no-empty", default=False, help="If specified, limit input refs and sources"
130132
)
133+
@click.option(
134+
"--threads",
135+
type=int,
136+
help="Specify number of threads to use while executing models. Overrides settings in profiles.yml.",
137+
)
131138
@vars_option
132139
@click.pass_context
133140
def run(
134141
ctx: click.Context,
135142
vars: t.Optional[t.Dict[str, t.Any]],
143+
threads: t.Optional[int],
136144
env: t.Optional[str] = None,
137145
**kwargs: t.Any,
138146
) -> None:
139147
"""Compile SQL and execute against the current target database."""
140-
_get_dbt_operations(ctx, vars).run(environment=env, **kwargs)
148+
_get_dbt_operations(ctx, vars, threads).run(environment=env, **kwargs)
141149

142150

143151
@dbt.command(name="list")

sqlmesh_dbt/operations.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def create(
235235
profile: t.Optional[str] = None,
236236
target: t.Optional[str] = None,
237237
vars: t.Optional[t.Dict[str, t.Any]] = None,
238+
threads: t.Optional[int] = None,
238239
debug: bool = False,
239240
) -> DbtOperations:
240241
with Progress(transient=True) as progress:
@@ -265,7 +266,9 @@ def create(
265266

266267
sqlmesh_context = Context(
267268
paths=[project_dir],
268-
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
269+
config_loader_kwargs=dict(
270+
profile=profile, target=target, variables=vars, threads=threads
271+
),
269272
load=True,
270273
# DbtSelector selects based on dbt model fqn's rather than SQLMesh model names
271274
selector=DbtSelector,

tests/dbt/cli/test_operations.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,33 @@ def test_run_option_full_refresh_with_selector(jaffle_shop_duckdb: Path):
333333
assert not plan.empty_backfill
334334
assert not plan.skip_backfill
335335
assert plan.models_to_backfill == set(['"jaffle_shop"."main"."stg_customers"'])
336+
337+
338+
def test_create_sets_concurrent_tasks_based_on_threads(create_empty_project: EmptyProjectCreator):
339+
project_dir, _ = create_empty_project(project_name="test")
340+
341+
# add a postgres target because duckdb overrides to concurrent_tasks=1 regardless of what gets specified
342+
profiles_yml_file = project_dir / "profiles.yml"
343+
profiles_yml = yaml.load(profiles_yml_file)
344+
profiles_yml["test"]["outputs"]["postgres"] = {
345+
"type": "postgres",
346+
"host": "localhost",
347+
"port": 5432,
348+
"user": "postgres",
349+
"password": "postgres",
350+
"dbname": "test",
351+
"schema": "test",
352+
}
353+
profiles_yml_file.write_text(yaml.dump(profiles_yml))
354+
355+
operations = create(project_dir=project_dir, target="postgres")
356+
357+
assert operations.context.concurrent_tasks == 1 # 1 is the default
358+
359+
operations = create(project_dir=project_dir, threads=16, target="postgres")
360+
361+
assert operations.context.concurrent_tasks == 16
362+
assert all(
363+
g.connection and g.connection.concurrent_tasks == 16
364+
for g in operations.context.config.gateways.values()
365+
)

tests/dbt/cli/test_run.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,11 @@ def test_run_with_changes_and_full_refresh(
8383
("foo", "bar", "changed"),
8484
("baz", "bing", "changed"),
8585
]
86+
87+
88+
def test_run_with_threads(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
89+
result = invoke_cli(["run", "--threads", "4"])
90+
assert result.exit_code == 0
91+
assert not result.exception
92+
93+
assert "Model batches executed" in result.output

0 commit comments

Comments
 (0)