Skip to content

Commit 7f5adcf

Browse files
Feat(dbt): Add config object to provide methods aligned with dbt (#5271)
1 parent cabbd5c commit 7f5adcf

File tree

2 files changed

+169
-1
lines changed

2 files changed

+169
-1
lines changed

sqlmesh/dbt/builtin.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,58 @@ def has_var(self, name: str) -> bool:
164164
return name in self.variables
165165

166166

167+
class Config:
168+
def __init__(self, config_dict: t.Dict[str, t.Any]) -> None:
169+
self._config = config_dict
170+
171+
def __call__(self, **kwargs: t.Any) -> str:
172+
self._config.update(**kwargs)
173+
return ""
174+
175+
def set(self, name: str, value: t.Any) -> str:
176+
self._config.update({name: value})
177+
return ""
178+
179+
def _validate(self, name: str, validator: t.Callable, value: t.Optional[t.Any] = None) -> None:
180+
try:
181+
validator(value)
182+
except Exception as e:
183+
raise ConfigError(f"Config validation failed for '{name}': {e}")
184+
185+
def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any:
186+
if name not in self._config:
187+
raise ConfigError(f"Missing required config: {name}")
188+
189+
value = self._config[name]
190+
191+
if validator is not None:
192+
self._validate(name, validator, value)
193+
194+
return value
195+
196+
def get(
197+
self, name: str, default: t.Any = None, validator: t.Optional[t.Callable] = None
198+
) -> t.Any:
199+
value = self._config.get(name, default)
200+
201+
if validator is not None and value is not None:
202+
self._validate(name, validator, value)
203+
204+
return value
205+
206+
def persist_relation_docs(self) -> bool:
207+
persist_docs = self.get("persist_docs", default={})
208+
if not isinstance(persist_docs, dict):
209+
return False
210+
return persist_docs.get("relation", False)
211+
212+
def persist_column_docs(self) -> bool:
213+
persist_docs = self.get("persist_docs", default={})
214+
if not isinstance(persist_docs, dict):
215+
return False
216+
return persist_docs.get("columns", False)
217+
218+
167219
def env_var(name: str, default: t.Optional[str] = None) -> t.Optional[str]:
168220
if name not in os.environ and default is None:
169221
raise ConfigError(f"Missing environment variable '{name}'")
@@ -395,6 +447,8 @@ def create_builtin_globals(
395447
if variables is not None:
396448
builtin_globals["var"] = Var(variables)
397449

450+
builtin_globals["config"] = Config(jinja_globals.pop("config", {}))
451+
398452
deployability_index = (
399453
jinja_globals.get("deployability_index") or DeployabilityIndex.all_deployable()
400454
)

tests/dbt/test_transformation.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ def test_schema_jinja(sushi_test_project: Project, assert_exp_eq):
946946

947947
@pytest.mark.xdist_group("dbt_manifest")
948948
def test_config_jinja(sushi_test_project: Project):
949-
hook = "{{ config(alias='bar') }} {{ config.alias }}"
949+
hook = "{{ config(alias='bar') }} {{ config.get('alias') }}"
950950
model_config = ModelConfig(
951951
name="model",
952952
package_name="package",
@@ -961,6 +961,120 @@ def test_config_jinja(sushi_test_project: Project):
961961
assert model.render_pre_statements()[0].sql() == '"bar"'
962962

963963

964+
@pytest.mark.xdist_group("dbt_manifest")
965+
def test_config_jinja_get_methods(sushi_test_project: Project):
966+
model_config = ModelConfig(
967+
name="model_conf",
968+
package_name="package",
969+
schema="sushi",
970+
sql="""SELECT 1 AS one FROM foo""",
971+
alias="model_alias",
972+
**{
973+
"pre-hook": [
974+
"{{ config(materialized='incremental', unique_key='id') }}"
975+
"{{ config.get('missed', 'a') + config.get('missed', default='b')}}",
976+
"{{ config.set('alias', 'new_alias')}}",
977+
"{{ config.get('package_name') + '_' + config.require('unique_key')}}",
978+
"{{ config.get('alias') or 'default'}}",
979+
]
980+
},
981+
**{"post-hook": "{{config.require('missing_key')}}"},
982+
)
983+
context = sushi_test_project.context
984+
model = t.cast(SqlModel, model_config.to_sqlmesh(context))
985+
986+
assert model.render_pre_statements()[0].sql() == '"ab"'
987+
assert model.render_pre_statements()[1].sql() == '"package_id"'
988+
assert model.render_pre_statements()[2].sql() == '"new_alias"'
989+
990+
with pytest.raises(ConfigError, match="Missing required config: missing_key"):
991+
model.render_post_statements()
992+
993+
# test get methods with operations
994+
model_2_config = ModelConfig(
995+
name="model_2",
996+
package_name="package",
997+
schema="sushi",
998+
sql="""SELECT 1 AS one FROM foo""",
999+
alias="mod",
1000+
materialized="table",
1001+
threads=8,
1002+
partition_by="date",
1003+
cluster_by=["user_id", "product_id"],
1004+
**{
1005+
"pre-hook": [
1006+
"{{ config.get('partition_by', default='none') }}",
1007+
"{{ config.get('cluster_by', default=[]) | length }}",
1008+
"{% if config.get('threads') > 4 %}high_threads{% else %}low_threads{% endif %}",
1009+
]
1010+
},
1011+
)
1012+
model2 = t.cast(SqlModel, model_2_config.to_sqlmesh(context))
1013+
1014+
pre_statements2 = model2.render_pre_statements()
1015+
assert pre_statements2[0].sql() == "ARRAY('date')"
1016+
assert pre_statements2[1].sql() == "2"
1017+
assert pre_statements2[2].sql() == '"high_threads"'
1018+
1019+
# test seting variable and conditional
1020+
model_invalid_timeout = ModelConfig(
1021+
name="invalid_timeout_test",
1022+
package_name="package",
1023+
schema="sushi",
1024+
sql="""SELECT 1 AS one FROM foo""",
1025+
alias="invalid_timeout_alias",
1026+
connection_timeout=44,
1027+
**{
1028+
"pre-hook": [
1029+
"""
1030+
{%- set value = config.require('connection_timeout') -%}
1031+
{%- set is_valid = value >= 10 and value <= 30 -%}
1032+
{%- if not is_valid -%}
1033+
{{ exceptions.raise_compiler_error("Validation failed for 'connection_timeout': Value must be between 10 and 30, got: " ~ value) }}
1034+
{%- endif -%}
1035+
{{ value }}
1036+
""",
1037+
]
1038+
},
1039+
)
1040+
1041+
model_invalid = t.cast(SqlModel, model_invalid_timeout.to_sqlmesh(context))
1042+
with pytest.raises(
1043+
ConfigError,
1044+
match="Validation failed for 'connection_timeout': Value must be between 10 and 30, got: 44",
1045+
):
1046+
model_invalid.render_pre_statements()
1047+
1048+
# test persist_docs methods
1049+
model_config_persist = ModelConfig(
1050+
name="persist_docs_model",
1051+
package_name="package",
1052+
schema="sushi",
1053+
sql="""SELECT 1 AS one FROM foo""",
1054+
alias="persist_alias",
1055+
**{
1056+
"pre-hook": [
1057+
"{{ config(persist_docs={'relation': true, 'columns': true}) }}",
1058+
"{{ config.persist_relation_docs() }}",
1059+
"{{ config.persist_column_docs() }}",
1060+
"{{ config(persist_docs={'relation': false, 'columns': true}) }}",
1061+
"{{ config.persist_relation_docs() }}",
1062+
"{{ config.persist_column_docs() }}",
1063+
]
1064+
},
1065+
)
1066+
model3 = t.cast(SqlModel, model_config_persist.to_sqlmesh(context))
1067+
1068+
pre_statements3 = model3.render_pre_statements()
1069+
1070+
# it should filter out empty returns, so we get 4 statements
1071+
assert len(pre_statements3) == 4
1072+
assert pre_statements3[0].sql() == "TRUE"
1073+
assert pre_statements3[1].sql() == "TRUE"
1074+
assert pre_statements3[2].sql() == "FALSE"
1075+
assert pre_statements3[3].sql() == "TRUE"
1076+
1077+
9641078
@pytest.mark.xdist_group("dbt_manifest")
9651079
def test_model_this(assert_exp_eq, sushi_test_project: Project):
9661080
model_config = ModelConfig(

0 commit comments

Comments
 (0)