Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmesh/core/config/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def load_config_from_paths(
dbt_profile_name=kwargs.pop("profile", None),
dbt_target_name=kwargs.pop("target", None),
variables=variables,
threads=kwargs.pop("threads", None),
)
if type(dbt_python_config) != config_type:
dbt_python_config = convert_config_type(dbt_python_config, config_type)
Expand Down
5 changes: 5 additions & 0 deletions sqlmesh/dbt/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def sqlmesh_config(
dbt_profile_name: t.Optional[str] = None,
dbt_target_name: t.Optional[str] = None,
variables: t.Optional[t.Dict[str, t.Any]] = None,
threads: t.Optional[int] = None,
register_comments: t.Optional[bool] = None,
**kwargs: t.Any,
) -> Config:
Expand All @@ -67,6 +68,10 @@ def sqlmesh_config(
if not issubclass(loader, DbtLoader):
raise ConfigError("The loader must be a DbtLoader.")

if threads is not None:
# the to_sqlmesh() function on TargetConfig maps self.threads -> concurrent_tasks
profile.target.threads = threads

return Config(
loader=loader,
model_defaults=model_defaults,
Expand Down
14 changes: 11 additions & 3 deletions sqlmesh_dbt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import functools


def _get_dbt_operations(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]]) -> DbtOperations:
def _get_dbt_operations(
ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], threads: t.Optional[int] = None
) -> DbtOperations:
if not isinstance(ctx.obj, functools.partial):
raise ValueError(f"Unexpected click context object: {type(ctx.obj)}")

dbt_operations = ctx.obj(vars=vars)
dbt_operations = ctx.obj(vars=vars, threads=threads)

if not isinstance(dbt_operations, DbtOperations):
raise ValueError(f"Unexpected dbt operations type: {type(dbt_operations)}")
Expand Down Expand Up @@ -128,16 +130,22 @@ def dbt(
@click.option(
"--empty/--no-empty", default=False, help="If specified, limit input refs and sources"
)
@click.option(
"--threads",
type=int,
help="Specify number of threads to use while executing models. Overrides settings in profiles.yml.",
)
@vars_option
@click.pass_context
def run(
ctx: click.Context,
vars: t.Optional[t.Dict[str, t.Any]],
threads: t.Optional[int],
env: t.Optional[str] = None,
**kwargs: t.Any,
) -> None:
"""Compile SQL and execute against the current target database."""
_get_dbt_operations(ctx, vars).run(environment=env, **kwargs)
_get_dbt_operations(ctx, vars, threads).run(environment=env, **kwargs)


@dbt.command(name="list")
Expand Down
5 changes: 4 additions & 1 deletion sqlmesh_dbt/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def create(
profile: t.Optional[str] = None,
target: t.Optional[str] = None,
vars: t.Optional[t.Dict[str, t.Any]] = None,
threads: t.Optional[int] = None,
debug: bool = False,
) -> DbtOperations:
with Progress(transient=True) as progress:
Expand Down Expand Up @@ -265,7 +266,9 @@ def create(

sqlmesh_context = Context(
paths=[project_dir],
config_loader_kwargs=dict(profile=profile, target=target, variables=vars),
config_loader_kwargs=dict(
profile=profile, target=target, variables=vars, threads=threads
),
load=True,
# DbtSelector selects based on dbt model fqn's rather than SQLMesh model names
selector=DbtSelector,
Expand Down
30 changes: 30 additions & 0 deletions tests/dbt/cli/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,33 @@ def test_run_option_full_refresh_with_selector(jaffle_shop_duckdb: Path):
assert not plan.empty_backfill
assert not plan.skip_backfill
assert plan.models_to_backfill == set(['"jaffle_shop"."main"."stg_customers"'])


def test_create_sets_concurrent_tasks_based_on_threads(create_empty_project: EmptyProjectCreator):
project_dir, _ = create_empty_project(project_name="test")

# add a postgres target because duckdb overrides to concurrent_tasks=1 regardless of what gets specified
profiles_yml_file = project_dir / "profiles.yml"
profiles_yml = yaml.load(profiles_yml_file)
profiles_yml["test"]["outputs"]["postgres"] = {
"type": "postgres",
"host": "localhost",
"port": 5432,
"user": "postgres",
"password": "postgres",
"dbname": "test",
"schema": "test",
}
profiles_yml_file.write_text(yaml.dump(profiles_yml))

operations = create(project_dir=project_dir, target="postgres")

assert operations.context.concurrent_tasks == 1 # 1 is the default

operations = create(project_dir=project_dir, threads=16, target="postgres")

assert operations.context.concurrent_tasks == 16
assert all(
g.connection and g.connection.concurrent_tasks == 16
for g in operations.context.config.gateways.values()
)
8 changes: 8 additions & 0 deletions tests/dbt/cli/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,11 @@ def test_run_with_changes_and_full_refresh(
("foo", "bar", "changed"),
("baz", "bing", "changed"),
]


def test_run_with_threads(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]):
result = invoke_cli(["run", "--threads", "4"])
assert result.exit_code == 0
assert not result.exception

assert "Model batches executed" in result.output