Skip to content

Commit 2de2d05

Browse files
committed
Feat: Add 'state_schema_naming_pattern' to infer the state schema per dbt target
1 parent a303011 commit 2de2d05

File tree

9 files changed

+139
-10
lines changed

9 files changed

+139
-10
lines changed

sqlmesh/cli/project_init.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,13 @@ def _gen_config(
116116
- invalidselectstarexpansion
117117
- noambiguousprojections
118118
""",
119-
ProjectTemplate.DBT: f"""# --- Virtual Data Environment Mode ---
119+
ProjectTemplate.DBT: f"""# --- State ---
120+
# This default configuration ensures that each dbt target gets its own isolated state.
121+
# If this is undesirable, you may configure the state connection manually.
122+
# https://sqlmesh.readthedocs.io/en/stable/integrations/dbt/?h=dbt#selecting-a-different-state-connection
123+
state_schema_naming_pattern: sqlmesh_state_@{{dbt_profile_name}}_@{{dbt_target_name}}
124+
125+
# --- Virtual Data Environment Mode ---
120126
# Enable Virtual Data Environments (VDE) for *development* environments.
121127
# Note that the production environment in dbt projects is not virtual by default to maintain compatibility with existing tooling.
122128
# https://sqlmesh.readthedocs.io/en/stable/guides/configuration/#virtual-data-environment-modes

sqlmesh/core/config/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@
3636
from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig
3737
from sqlmesh.core.config.linter import LinterConfig as LinterConfig
3838
from sqlmesh.core.config.plan import PlanConfig as PlanConfig
39-
from sqlmesh.core.config.root import Config as Config
39+
from sqlmesh.core.config.root import Config as Config, DbtConfigInfo as DbtConfigInfo
4040
from sqlmesh.core.config.run import RunConfig as RunConfig
4141
from sqlmesh.core.config.scheduler import BuiltInSchedulerConfig as BuiltInSchedulerConfig

sqlmesh/core/config/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,17 @@ def update_with(self: T, other: t.Union[t.Dict[str, t.Any], T]) -> T:
140140
setattr(updated, field, value)
141141

142142
return updated
143+
144+
145+
class DbtConfigInfo(PydanticModel):
146+
"""
147+
This is like DbtNodeInfo except it applies to config instead of DAG nodes.
148+
149+
It's intended to capture information from a dbt project loaded by the DbtLoader so that it can be used for things like
150+
variable substitutions in regular project config.
151+
"""
152+
153+
profile_name: str
154+
"""Which profile in the dbt project is being used"""
155+
target_name: str
156+
"""Which target of the specified profile is being used"""

sqlmesh/core/config/gateway.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import typing as t
44

5-
from sqlmesh.core import constants as c
65
from sqlmesh.core.config.base import BaseConfig
76
from sqlmesh.core.config.model import ModelDefaultsConfig
87
from sqlmesh.core.config.common import variables_validator
@@ -33,7 +32,7 @@ class GatewayConfig(BaseConfig):
3332
state_connection: t.Optional[SerializableConnectionConfig] = None
3433
test_connection: t.Optional[SerializableConnectionConfig] = None
3534
scheduler: t.Optional[SchedulerConfig] = None
36-
state_schema: t.Optional[str] = c.SQLMESH
35+
state_schema: t.Optional[str] = None
3736
variables: t.Dict[str, t.Any] = {}
3837
model_defaults: t.Optional[ModelDefaultsConfig] = None
3938

sqlmesh/core/config/root.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import re
55
import typing as t
66
import zlib
7+
import logging
78

89
from pydantic import Field
910
from pydantic.functional_validators import BeforeValidator
@@ -19,7 +20,7 @@
1920
TableNamingConvention,
2021
VirtualEnvironmentMode,
2122
)
22-
from sqlmesh.core.config.base import BaseConfig, UpdateStrategy
23+
from sqlmesh.core.config.base import BaseConfig, UpdateStrategy, DbtConfigInfo
2324
from sqlmesh.core.config.common import variables_validator, compile_regex_mapping
2425
from sqlmesh.core.config.connection import (
2526
ConnectionConfig,
@@ -49,6 +50,8 @@
4950
from sqlmesh.utils.errors import ConfigError
5051
from sqlmesh.utils.pydantic import model_validator
5152

53+
logger = logging.getLogger(__name__)
54+
5255

5356
def validate_no_past_ttl(v: str) -> str:
5457
current_time = now()
@@ -96,6 +99,8 @@ class Config(BaseConfig):
9699
default_test_connection: The default connection to use for tests if one is not specified in a gateway.
97100
default_scheduler: The default scheduler configuration to use if one is not specified in a gateway.
98101
default_gateway: The default gateway.
102+
state_schema_naming_pattern: A pattern supporting variable substitutions to determine the state schema name, rather than just using 'sqlmesh'.
103+
Only applies when the state schema is not explicitly set in the gateway config
99104
notification_targets: The notification targets to use.
100105
project: The project name of this config. Used for multi-repo setups.
101106
snapshot_ttl: The period of time that a model snapshot that is not a part of any environment should exist before being deleted.
@@ -128,6 +133,7 @@ class Config(BaseConfig):
128133
before_all: SQL statements or macros to be executed at the start of the `sqlmesh plan` and `sqlmesh run` commands.
129134
after_all: SQL statements or macros to be executed at the end of the `sqlmesh plan` and `sqlmesh run` commands.
130135
cache_dir: The directory to store the SQLMesh cache. Defaults to .cache in the project folder.
136+
dbt_config_info: Dbt-specific properties (such as profile and target) for dbt projects loaded by the dbt loader
131137
"""
132138

133139
gateways: GatewayDict = {"": GatewayConfig()}
@@ -137,6 +143,7 @@ class Config(BaseConfig):
137143
)
138144
default_scheduler: SchedulerConfig = BuiltInSchedulerConfig()
139145
default_gateway: str = ""
146+
state_schema_naming_pattern: t.Optional[str] = None
140147
notification_targets: t.List[NotificationTarget] = []
141148
project: str = ""
142149
snapshot_ttl: NoPastTTLString = c.DEFAULT_SNAPSHOT_TTL
@@ -173,6 +180,7 @@ class Config(BaseConfig):
173180
linter: LinterConfig = LinterConfig()
174181
janitor: JanitorConfig = JanitorConfig()
175182
cache_dir: t.Optional[str] = None
183+
dbt_config_info: t.Optional[DbtConfigInfo] = None
176184

177185
_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
178186
"gateways": UpdateStrategy.NESTED_UPDATE,
@@ -344,8 +352,27 @@ def get_test_connection(
344352
def get_scheduler(self, gateway_name: t.Optional[str] = None) -> SchedulerConfig:
345353
return self.get_gateway(gateway_name).scheduler or self.default_scheduler
346354

347-
def get_state_schema(self, gateway_name: t.Optional[str] = None) -> t.Optional[str]:
348-
return self.get_gateway(gateway_name).state_schema
355+
def get_state_schema(self, gateway_name: t.Optional[str] = None) -> str:
356+
state_schema = self.get_gateway(gateway_name).state_schema
357+
358+
if state_schema is None and self.state_schema_naming_pattern:
359+
substitutions = {}
360+
if dbt := self.dbt_config_info:
361+
# TODO: keeping this simple for now rather than trying to set up a Jinja or SQLMesh Macro rendering context
362+
substitutions.update(
363+
{
364+
"@{dbt_profile_name}": dbt.profile_name,
365+
# TODO @iaroslav: what was the problem with using target name instead of the default schema name again?
366+
"@{dbt_target_name}": dbt.target_name,
367+
}
368+
)
369+
state_schema = self.state_schema_naming_pattern
370+
for pattern, value in substitutions.items():
371+
state_schema = state_schema.replace(pattern, value)
372+
373+
logger.info("Inferring state schema: %s", state_schema)
374+
375+
return state_schema or c.SQLMESH
349376

350377
@property
351378
def default_gateway_name(self) -> str:

sqlmesh/dbt/loader.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ConnectionConfig,
1212
GatewayConfig,
1313
ModelDefaultsConfig,
14+
DbtConfigInfo,
1415
)
1516
from sqlmesh.core.environment import EnvironmentStatements
1617
from sqlmesh.core.loader import CacheBase, LoadedProject, Loader
@@ -67,10 +68,19 @@ def sqlmesh_config(
6768
if not issubclass(loader, DbtLoader):
6869
raise ConfigError("The loader must be a DbtLoader.")
6970

71+
if context.profile_name is None:
72+
# Note: Profile.load() mutates `context` and will have already raised an exception if profile_name is not set,
73+
# but mypy doesn't know this because the field is defined as t.Optional[str]
74+
raise ConfigError(f"profile name must be set")
75+
7076
return Config(
7177
loader=loader,
7278
model_defaults=model_defaults,
7379
variables=variables or {},
80+
dbt_config_info=DbtConfigInfo(
81+
profile_name=dbt_profile_name or context.profile_name,
82+
target_name=dbt_target_name or profile.target_name,
83+
),
7484
**{
7585
"default_gateway": profile.target_name if "gateways" not in kwargs else "",
7686
"gateways": {

tests/dbt/test_config.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sqlmesh.core.dialect import jinja_query
1616
from sqlmesh.core.model import SqlModel
1717
from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange
18+
from sqlmesh.core.state_sync import CachingStateSync, EngineAdapterStateSync
1819
from sqlmesh.dbt.builtin import Api
1920
from sqlmesh.dbt.column import ColumnConfig
2021
from sqlmesh.dbt.common import Dependencies
@@ -46,7 +47,8 @@
4647
)
4748
from sqlmesh.dbt.test import TestConfig
4849
from sqlmesh.utils.errors import ConfigError
49-
from sqlmesh.utils.yaml import load as yaml_load
50+
from sqlmesh.utils.yaml import load as yaml_load, dump as yaml_dump
51+
from tests.dbt.conftest import EmptyProjectCreator
5052

5153
pytestmark = pytest.mark.dbt
5254

@@ -1211,3 +1213,35 @@ def test_empty_vars_config(tmp_path):
12111213
# Verify the variables are empty (not causing any issues)
12121214
assert project.packages["test_empty_vars"].variables == {}
12131215
assert project.context.variables == {}
1216+
1217+
1218+
def test_state_schema_naming_pattern(create_empty_project: EmptyProjectCreator):
1219+
project_dir, _ = create_empty_project("test_foo", "dev")
1220+
1221+
# no state_schema_naming_pattern, creating python config manually doesn't take into account
1222+
# any config yaml files that may be present, so we get the default state schema
1223+
config = sqlmesh_config(project_root=project_dir)
1224+
assert not config.state_schema_naming_pattern
1225+
assert config.get_state_schema() == "sqlmesh"
1226+
1227+
# create_empty_project() uses the default dbt template for sqlmesh yaml config which
1228+
# sets state_schema_naming_pattern
1229+
ctx = Context(paths=[project_dir])
1230+
assert ctx.config.state_schema_naming_pattern
1231+
assert ctx.config.get_state_schema() == "sqlmesh_state_test_foo_dev"
1232+
assert isinstance(ctx.state_sync, CachingStateSync)
1233+
assert isinstance(ctx.state_sync.state_sync, EngineAdapterStateSync)
1234+
assert ctx.state_sync.state_sync.schema == "sqlmesh_state_test_foo_dev"
1235+
1236+
# If the user deliberately overrides state_schema then we should respect this choice
1237+
config_file = project_dir / "sqlmesh.yaml"
1238+
config_yaml = yaml_load(config_file)
1239+
config_yaml["gateways"] = {"dev": {"state_schema": "state_override"}}
1240+
config_file.write_text(yaml_dump(config_yaml))
1241+
1242+
ctx = Context(paths=[project_dir])
1243+
assert ctx.config.state_schema_naming_pattern
1244+
assert ctx.config.get_state_schema() == "state_override"
1245+
assert isinstance(ctx.state_sync, CachingStateSync)
1246+
assert isinstance(ctx.state_sync.state_sync, EngineAdapterStateSync)
1247+
assert ctx.state_sync.state_sync.schema == "state_override"

tests/dbt/test_integration.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from sqlmesh.core.config.connection import DuckDBConnectionConfig
2020
from sqlmesh.core.engine_adapter import DuckDBEngineAdapter
2121
from sqlmesh.utils.pandas import columns_to_types_from_df
22-
from sqlmesh.utils.yaml import YAML
22+
from sqlmesh.utils.yaml import YAML, load as yaml_load, dump as yaml_dump
23+
from sqlmesh_dbt.operations import init_project_if_required
2324
from tests.utils.pandas import compare_dataframes, create_df
2425

2526
# Some developers had issues with this test freezing locally so we mark it as cicdonly
@@ -604,3 +605,41 @@ def test_dbt_node_info(jaffle_shop_duckdb_context: Context):
604605
relationship_audit.node.dbt_node_info.name
605606
== "relationships_orders_customer_id__customer_id__ref_customers_"
606607
)
608+
609+
610+
def test_state_schema_isolation_per_target(jaffle_shop_duckdb: Path):
611+
profiles_file = jaffle_shop_duckdb / "profiles.yml"
612+
613+
profiles_yml = yaml_load(profiles_file)
614+
615+
# make prod / dev config identical with the exception of a different default schema to simulate using the same warehouse
616+
profiles_yml["jaffle_shop"]["outputs"]["prod"] = {
617+
**profiles_yml["jaffle_shop"]["outputs"]["dev"]
618+
}
619+
profiles_yml["jaffle_shop"]["outputs"]["prod"]["schema"] = "prod_schema"
620+
profiles_yml["jaffle_shop"]["outputs"]["dev"]["schema"] = "dev_schema"
621+
622+
profiles_file.write_text(yaml_dump(profiles_yml))
623+
624+
init_project_if_required(jaffle_shop_duckdb)
625+
626+
# start off with the prod target
627+
prod_ctx = Context(paths=[jaffle_shop_duckdb], config_loader_kwargs={"target": "prod"})
628+
assert prod_ctx.config.get_state_schema() == "sqlmesh_state_jaffle_shop_prod"
629+
assert all("prod_schema" in fqn for fqn in prod_ctx.models)
630+
assert prod_ctx.plan(auto_apply=True).has_changes
631+
assert not prod_ctx.plan(auto_apply=True).has_changes
632+
633+
# dev target should have changes - new state separate from prod
634+
dev_ctx = Context(paths=[jaffle_shop_duckdb], config_loader_kwargs={"target": "dev"})
635+
assert dev_ctx.config.get_state_schema() == "sqlmesh_state_jaffle_shop_dev"
636+
assert all("dev_schema" in fqn for fqn in dev_ctx.models)
637+
assert dev_ctx.plan(auto_apply=True).has_changes
638+
assert not dev_ctx.plan(auto_apply=True).has_changes
639+
640+
# no explicitly specified target should use dev because that's what's set for the default in the profiles.yml
641+
assert profiles_yml["jaffle_shop"]["target"] == "dev"
642+
default_ctx = Context(paths=[jaffle_shop_duckdb])
643+
assert default_ctx.config.get_state_schema() == "sqlmesh_state_jaffle_shop_dev"
644+
assert all("dev_schema" in fqn for fqn in default_ctx.models)
645+
assert not default_ctx.plan(auto_apply=True).has_changes

tests/fixtures/dbt/empty_project/profiles.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ empty_project:
33
target: __DEFAULT_TARGET__
44

55
outputs:
6-
duckdb:
6+
__DEFAULT_TARGET__:
77
type: duckdb
88
path: 'empty_project.duckdb'
99
threads: 4

0 commit comments

Comments
 (0)