Skip to content

Commit e0f5dcc

Browse files
treyspizeigerman
and Iaroslav Zeigerman authored
Feat: enable dbt transpilation (#2806)
Co-authored-by: Iaroslav Zeigerman <zeigerman.ia@gmail.com>
1 parent 14ad71d commit e0f5dcc

File tree

16 files changed

+215
-48
lines changed

16 files changed

+215
-48
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ def __init__(
2929
self,
3030
jinja_macros: JinjaMacroRegistry,
3131
jinja_globals: t.Optional[t.Dict[str, t.Any]] = None,
32-
dialect: t.Optional[str] = None,
32+
project_dialect: t.Optional[str] = None,
3333
):
3434
self.jinja_macros = jinja_macros
3535
self.jinja_globals = jinja_globals.copy() if jinja_globals else {}
3636
self.jinja_globals["adapter"] = self
37-
self.dialect = dialect
37+
self.project_dialect = project_dialect
3838

3939
@abc.abstractmethod
4040
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
@@ -93,7 +93,7 @@ def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]:
9393

9494
def quote(self, identifier: str) -> str:
9595
"""Returns a quoted identifier."""
96-
return exp.to_column(identifier).sql(dialect=self.dialect, identify=True)
96+
return exp.to_column(identifier).sql(dialect=self.project_dialect, identify=True)
9797

9898
def dispatch(self, name: str, package: t.Optional[str] = None) -> t.Callable:
9999
"""Returns a dialect-specific version of a macro with the given name."""
@@ -189,12 +189,17 @@ def __init__(
189189
snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
190190
table_mapping: t.Optional[t.Dict[str, str]] = None,
191191
deployability_index: t.Optional[DeployabilityIndex] = None,
192+
project_dialect: t.Optional[str] = None,
192193
):
193194
from dbt.adapters.base import BaseRelation
194195
from dbt.adapters.base.column import Column
195196
from dbt.adapters.base.relation import Policy
196197

197-
super().__init__(jinja_macros, jinja_globals=jinja_globals, dialect=engine_adapter.dialect)
198+
super().__init__(
199+
jinja_macros,
200+
jinja_globals=jinja_globals,
201+
project_dialect=project_dialect or engine_adapter.dialect,
202+
)
198203

199204
table_mapping = table_mapping or {}
200205

@@ -259,7 +264,9 @@ def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]:
259264

260265
mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation)))
261266
return [
262-
Column.from_description(name=name, raw_data_type=dtype.sql(dialect=self.dialect))
267+
Column.from_description(
268+
name=name, raw_data_type=dtype.sql(dialect=self.project_dialect)
269+
)
263270
for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items()
264271
]
265272

@@ -298,12 +305,12 @@ def execute(
298305
self.engine_adapter.fetchdf if fetch else self.engine_adapter.execute # type: ignore
299306
)
300307

301-
expression = parse_one(sql, read=self.dialect)
308+
expression = parse_one(sql, read=self.project_dialect)
302309
with normalize_and_quote(
303-
expression, t.cast(str, self.dialect), self.engine_adapter.default_catalog
310+
expression, t.cast(str, self.project_dialect), self.engine_adapter.default_catalog
304311
) as expression:
305312
expression = exp.replace_tables(
306-
expression, self.table_mapping, dialect=self.dialect, copy=False
313+
expression, self.table_mapping, dialect=self.project_dialect, copy=False
307314
)
308315

309316
if auto_begin:
@@ -328,17 +335,17 @@ def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]:
328335
return identifier if identifier else None
329336

330337
def _map_table_name(self, table: exp.Table) -> exp.Table:
331-
name = table.sql()
338+
name = table.sql(dialect=self.project_dialect)
332339
physical_table_name = self.table_mapping.get(name)
333340
if not physical_table_name:
334341
return table
335342

336343
logger.debug("Resolved ref '%s' to snapshot table '%s'", name, physical_table_name)
337344

338-
return exp.to_table(physical_table_name, dialect=self.dialect)
345+
return exp.to_table(physical_table_name, dialect=self.project_dialect)
339346

340347
def _relation_to_table(self, relation: BaseRelation) -> exp.Table:
341-
return exp.to_table(relation.render(), dialect=self.dialect)
348+
return exp.to_table(relation.render(), dialect=self.project_dialect)
342349

343350
def _table_to_relation(self, table: exp.Table) -> BaseRelation:
344351
return self.relation_type.create(
@@ -358,7 +365,7 @@ def _schema(self, schema_relation: BaseRelation) -> exp.Table:
358365

359366
def _normalize(self, input_table: exp.Table) -> exp.Table:
360367
normalized_name = normalize_model_name(
361-
input_table, self.engine_adapter.default_catalog, self.dialect
368+
input_table, self.engine_adapter.default_catalog, self.project_dialect
362369
)
363370
normalized_table = exp.to_table(normalized_name)
364371
if not input_table.this:

sqlmesh/dbt/basemodel.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class BaseModelConfig(GeneralConfig):
105105
path: Path = Path()
106106
dependencies: Dependencies = Dependencies()
107107
tests: t.List[TestConfig] = []
108+
dialect_: t.Optional[str] = Field(None, alias="dialect")
108109

109110
# DBT configuration fields
110111
name: str = ""
@@ -187,6 +188,9 @@ def config_name(self) -> str:
187188
"""
188189
return f"{self.package_name}.{self.name}"
189190

191+
def dialect(self, context: DbtContext) -> str:
192+
return self.dialect_ or context.default_dialect
193+
190194
def canonical_name(self, context: DbtContext) -> str:
191195
"""
192196
Get the sqlmesh model name
@@ -296,7 +300,7 @@ def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]:
296300
)
297301
return {
298302
"audits": [(test.name, {}) for test in self.tests],
299-
"columns": column_types_to_sqlmesh(self.columns, context.dialect) or None,
303+
"columns": column_types_to_sqlmesh(self.columns, self.dialect(context)) or None,
300304
"column_descriptions": column_descriptions_to_sqlmesh(self.columns) or None,
301305
"depends_on": {
302306
model.canonical_name(context) for model in model_context.refs.values()

sqlmesh/dbt/builtin.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def warn(self, msg: str) -> str:
4343

4444

4545
class Api:
46-
def __init__(self, target: t.Optional[AttributeDict] = None) -> None:
47-
if target:
48-
config_class = TARGET_TYPE_TO_CONFIG_CLASS[target["type"]]
46+
def __init__(self, dialect: t.Optional[str]) -> None:
47+
if dialect:
48+
config_class = TARGET_TYPE_TO_CONFIG_CLASS[dialect]
4949
self.Relation = config_class.relation_class
5050
self.Column = config_class.column_class
5151
self.quote_policy = config_class.quote_policy
@@ -301,8 +301,10 @@ def create_builtin_globals(
301301
jinja_globals = jinja_globals.copy()
302302

303303
target: t.Optional[AttributeDict] = jinja_globals.get("target", None)
304-
api = Api(target)
305-
dialect = target.dialect if target else None # type: ignore
304+
project_dialect = jinja_globals.pop("dialect", None) or (
305+
target.get("dialect") if target else None
306+
)
307+
api = Api(project_dialect)
306308

307309
builtin_globals["api"] = api
308310

@@ -349,13 +351,14 @@ def create_builtin_globals(
349351
snapshots=jinja_globals.get("snapshots", {}),
350352
table_mapping=jinja_globals.get("table_mapping", {}),
351353
deployability_index=jinja_globals.get("deployability_index"),
354+
project_dialect=project_dialect,
352355
)
353356
else:
354357
builtin_globals["flags"] = Flags(which="parse")
355358
adapter = ParsetimeAdapter(
356359
jinja_macros,
357360
jinja_globals={**builtin_globals, **jinja_globals},
358-
dialect=dialect,
361+
project_dialect=project_dialect,
359362
)
360363

361364
sql_execution = SQLExecution(adapter)

sqlmesh/dbt/context.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,13 @@ class DbtContext:
5959
_manifest: t.Optional[ManifestHelper] = None
6060

6161
@property
62-
def dialect(self) -> str:
62+
def default_dialect(self) -> str:
63+
if self.sqlmesh_config.dialect:
64+
return self.sqlmesh_config.dialect
6365
if not self.target:
64-
raise SQLMeshError("Target must be configured before calling the dialect property.")
66+
raise SQLMeshError(
67+
"Target must be configured before calling the default_dialect property."
68+
)
6569
return self.target.dialect
6670

6771
@property
@@ -229,6 +233,9 @@ def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]:
229233
output["project_name"] = self.project_name
230234
if self._target is not None:
231235
output["target"] = self._target.attribute_dict()
236+
# pass user-specified default dialect if we have already loaded the config
237+
if self.sqlmesh_config.dialect:
238+
output["dialect"] = self.sqlmesh_config.dialect
232239
return output
233240

234241
def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext:

sqlmesh/dbt/loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import typing as t
55
from pathlib import Path
6-
76
from sqlmesh.core import constants as c
87
from sqlmesh.core.audit import Audit
98
from sqlmesh.core.config import (
@@ -43,7 +42,8 @@ def sqlmesh_config(
4342
context = DbtContext(project_root=project_root)
4443
profile = Profile.load(context, target_name=dbt_target_name)
4544
model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig())
46-
model_defaults.dialect = profile.target.dialect
45+
if model_defaults.dialect is None:
46+
model_defaults.dialect = profile.target.dialect
4747

4848
target_to_sqlmesh_args = {}
4949
if register_comments is not None:

sqlmesh/dbt/model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def model_kind(self, context: DbtContext) -> ModelKind:
201201
if materialization == Materialization.VIEW:
202202
return ViewKind()
203203
if materialization == Materialization.INCREMENTAL:
204-
incremental_materialization_kwargs: t.Dict[str, t.Any] = {"dialect": context.dialect}
204+
incremental_materialization_kwargs: t.Dict[str, t.Any] = {
205+
"dialect": self.dialect(context)
206+
}
205207
for field in ("batch_size", "lookback", "forward_only"):
206208
field_val = getattr(self, field, None) or self.meta.get(field, None)
207209
if field_val:
@@ -277,7 +279,7 @@ def model_kind(self, context: DbtContext) -> ModelKind:
277279
f"{self.canonical_name(context)}: SQLMesh snapshot strategy is required for snapshot materialization."
278280
)
279281
shared_kwargs = {
280-
"dialect": context.dialect,
282+
"dialect": self.dialect(context),
281283
"unique_key": self.unique_key,
282284
"invalidate_hard_deletes": self.invalidate_hard_deletes,
283285
"valid_from_name": "dbt_valid_from",
@@ -358,14 +360,14 @@ def sqlmesh_config_fields(self) -> t.Set[str]:
358360

359361
def to_sqlmesh(self, context: DbtContext) -> Model:
360362
"""Converts the dbt model into a SQLMesh model."""
361-
dialect = context.dialect
363+
model_dialect = self.dialect(context)
362364
query = d.jinja_query(self.sql_no_config)
363365

364366
optional_kwargs: t.Dict[str, t.Any] = {}
365367

366368
if self.partition_by:
367369
optional_kwargs["partitioned_by"] = (
368-
[exp.to_column(val) for val in self.partition_by]
370+
[exp.to_column(val, dialect=model_dialect) for val in self.partition_by]
369371
if isinstance(self.partition_by, list)
370372
else self._big_query_partition_by_expr(context)
371373
)
@@ -374,7 +376,7 @@ def to_sqlmesh(self, context: DbtContext) -> Model:
374376
clustered_by = []
375377
for c in self.cluster_by:
376378
try:
377-
clustered_by.append(d.parse_one(c, dialect=dialect).name)
379+
clustered_by.append(d.parse_one(c, dialect=model_dialect).name)
378380
except SqlglotError as e:
379381
raise ConfigError(f"Failed to parse cluster_by field '{c}': {e}") from e
380382
optional_kwargs["clustered_by"] = clustered_by
@@ -403,7 +405,7 @@ def to_sqlmesh(self, context: DbtContext) -> Model:
403405
return create_sql_model(
404406
self.canonical_name(context),
405407
query,
406-
dialect=dialect,
408+
dialect=model_dialect,
407409
kind=self.model_kind(context),
408410
start=self.start,
409411
**optional_kwargs,

sqlmesh/dbt/seed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def to_sqlmesh(self, context: DbtContext) -> Model:
5050
return create_seed_model(
5151
self.canonical_name(context),
5252
SeedKind(path=seed_path),
53-
dialect=context.dialect,
53+
dialect=self.dialect(context),
5454
**kwargs,
5555
)
5656

sqlmesh/dbt/target.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,5 +852,6 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig:
852852
"snowflake": SnowflakeConfig,
853853
"bigquery": BigQueryConfig,
854854
"sqlserver": MSSQLConfig,
855+
"tsql": MSSQLConfig,
855856
"trino": TrinoConfig,
856857
}

sqlmesh/dbt/test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pathlib import Path
77

88
from pydantic import Field
9-
109
import sqlmesh.core.dialect as d
1110
from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit
1211
from sqlmesh.dbt.common import (
@@ -48,6 +47,7 @@ class TestConfig(GeneralConfig):
4847
interval_unit: The duration of an interval for the audit. By default, it is computed from the cron expression.
4948
column_name: The name of the column under test.
5049
dependencies: The macros, refs, and sources the test depends upon.
50+
dialect: SQL dialect of the test query.
5151
package_name: Name of the package that defines the test.
5252
alias: The alias for the materialized table where failures are stored (Not supported).
5353
schema: The schema for the materialized table where the failures are stored (Not supported).
@@ -73,6 +73,7 @@ class TestConfig(GeneralConfig):
7373
interval_unit: t.Optional[str] = None
7474
column_name: t.Optional[str] = None
7575
dependencies: Dependencies = Dependencies()
76+
dialect_: t.Optional[str] = Field(None, alias="dialect")
7677

7778
# dbt fields
7879
package_name: str = ""
@@ -110,6 +111,9 @@ def is_standalone(self) -> bool:
110111
def sqlmesh_config_fields(self) -> t.Set[str]:
111112
return {"description", "owner", "stamp", "cron", "interval_unit"}
112113

114+
def dialect(self, context: DbtContext) -> str:
115+
return self.dialect_ or context.default_dialect
116+
113117
def to_sqlmesh(self, context: DbtContext) -> Audit:
114118
"""Convert dbt Test to SQLMesh Audit
115119
@@ -142,7 +146,7 @@ def to_sqlmesh(self, context: DbtContext) -> Audit:
142146
jinja_macros.add_globals({"this": self.relation_info})
143147
audit = StandaloneAudit(
144148
name=self.name,
145-
dialect=context.dialect,
149+
dialect=self.dialect(context),
146150
skip=skip,
147151
query=query,
148152
jinja_macros=jinja_macros,
@@ -158,7 +162,7 @@ def to_sqlmesh(self, context: DbtContext) -> Audit:
158162
else:
159163
audit = ModelAudit(
160164
name=self.name,
161-
dialect=context.dialect,
165+
dialect=self.dialect(context),
162166
skip=skip,
163167
blocking=blocking,
164168
query=query,

tests/dbt/test_adapter.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,23 @@ def test_normalization(
7373
context = sushi_test_project.context
7474
assert context.target
7575

76+
# bla and bob will be normalized to lowercase since the target is duckdb
7677
adapter_mock = mocker.MagicMock()
77-
adapter_mock.dialect = "snowflake"
7878
adapter_mock.default_catalog = "test"
79+
adapter_mock.dialect = "duckdb"
80+
81+
duckdb_renderer = runtime_renderer(context, engine_adapter=adapter_mock)
82+
83+
schema_bla = schema_("bla", "test", quoted=True)
84+
relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True)
7985

86+
duckdb_renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}")
87+
adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)])
88+
89+
# bla and bob will be normalized to uppercase since the target is Snowflake, even though the default dialect is duckdb
90+
adapter_mock = mocker.MagicMock()
91+
adapter_mock.default_catalog = "test"
92+
adapter_mock.dialect = "snowflake"
8093
context.target = SnowflakeConfig(
8194
account="test",
8295
user="test",
@@ -85,16 +98,19 @@ def test_normalization(
8598
database="test",
8699
schema="test",
87100
)
88-
89101
renderer = runtime_renderer(context, engine_adapter=adapter_mock)
90102

91-
# bla and bob will be normalized to uppercase since we're dealing with Snowflake
92-
schema_bla = schema_("BLA", "TEST", quoted=True)
93-
relation_bla_bob = exp.table_("BOB", db="BLA", catalog="TEST", quoted=True)
103+
schema_bla = schema_("bla", "test", quoted=True)
104+
relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True)
94105

95106
renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}")
96107
adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)])
97108

109+
renderer("{{ adapter.get_relation(database='custom_db', schema='bla', identifier='bob') }}")
110+
adapter_mock.table_exists.assert_has_calls(
111+
[call(exp.table_("bob", db="bla", catalog="custom_db", quoted=True))]
112+
)
113+
98114
renderer(
99115
"{%- set relation = api.Relation.create(schema='bla') -%}"
100116
"{{ adapter.create_schema(relation) }}"
@@ -114,7 +130,7 @@ def test_normalization(
114130
adapter_mock.drop_table.assert_has_calls([call(relation_bla_bob)])
115131

116132
expected_star_query: exp.Select = exp.maybe_parse(
117-
'SELECT * FROM "T" as "T"', dialect="snowflake"
133+
'SELECT * FROM "t" as "t"', dialect="snowflake"
118134
)
119135

120136
# The following call to run_query won't return dataframes and so we're expected to

0 commit comments

Comments
 (0)