Skip to content

Commit 5ffed10

Browse files
Feat(dbt): Add support for transaction in dbt pre and post hooks (#5480)
1 parent 8e9fe23 commit 5ffed10

File tree

8 files changed

+347
-30
lines changed

8 files changed

+347
-30
lines changed

sqlmesh/core/model/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,7 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any:
663663

664664
class ParsableSql(PydanticModel):
665665
sql: str
666+
transaction: t.Optional[bool] = None
666667

667668
_parsed: t.Optional[exp.Expression] = None
668669
_parsed_dialect: t.Optional[str] = None

sqlmesh/core/model/definition.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def render_pre_statements(
363363
expand: t.Iterable[str] = tuple(),
364364
deployability_index: t.Optional[DeployabilityIndex] = None,
365365
engine_adapter: t.Optional[EngineAdapter] = None,
366+
inside_transaction: t.Optional[bool] = True,
366367
**kwargs: t.Any,
367368
) -> t.List[exp.Expression]:
368369
"""Renders pre-statements for a model.
@@ -384,7 +385,11 @@ def render_pre_statements(
384385
The list of rendered expressions.
385386
"""
386387
return self._render_statements(
387-
self.pre_statements,
388+
[
389+
stmt
390+
for stmt in self.pre_statements
391+
if stmt.args.get("transaction", True) == inside_transaction
392+
],
388393
start=start,
389394
end=end,
390395
execution_time=execution_time,
@@ -405,6 +410,7 @@ def render_post_statements(
405410
expand: t.Iterable[str] = tuple(),
406411
deployability_index: t.Optional[DeployabilityIndex] = None,
407412
engine_adapter: t.Optional[EngineAdapter] = None,
413+
inside_transaction: t.Optional[bool] = True,
408414
**kwargs: t.Any,
409415
) -> t.List[exp.Expression]:
410416
"""Renders post-statements for a model.
@@ -420,13 +426,18 @@ def render_post_statements(
420426
that depend on materialized tables. Model definitions are inlined and can thus be run end to
421427
end on the fly.
422428
deployability_index: Determines snapshots that are deployable in the context of this render.
429+
inside_transaction: Whether to render hooks with transaction=True (inside) or transaction=False (outside).
423430
kwargs: Additional kwargs to pass to the renderer.
424431
425432
Returns:
426433
The list of rendered expressions.
427434
"""
428435
return self._render_statements(
429-
self.post_statements,
436+
[
437+
stmt
438+
for stmt in self.post_statements
439+
if stmt.args.get("transaction", True) == inside_transaction
440+
],
430441
start=start,
431442
end=end,
432443
execution_time=execution_time,
@@ -567,6 +578,8 @@ def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]:
567578
result = []
568579
for v in value:
569580
parsed = v.parse(self.dialect)
581+
if getattr(v, "transaction", None) is not None:
582+
parsed.set("transaction", v.transaction)
570583
if not isinstance(parsed, exp.Semicolon):
571584
result.append(parsed)
572585
return result
@@ -2592,9 +2605,17 @@ def _create_model(
25922605
if statement_field in kwargs:
25932606
# Macros extracted from these statements need to be treated as metadata only
25942607
is_metadata = statement_field == "on_virtual_update"
2595-
statements.extend((stmt, is_metadata) for stmt in kwargs[statement_field])
2608+
for stmt in kwargs[statement_field]:
2609+
# Extract the expression if it's ParsableSql already
2610+
expr = stmt.parse(dialect) if isinstance(stmt, ParsableSql) else stmt
2611+
statements.append((expr, is_metadata))
25962612
kwargs[statement_field] = [
2597-
ParsableSql.from_parsed_expression(stmt, dialect, use_meta_sql=use_original_sql)
2613+
# this to retain the transaction information
2614+
stmt
2615+
if isinstance(stmt, ParsableSql)
2616+
else ParsableSql.from_parsed_expression(
2617+
stmt, dialect, use_meta_sql=use_original_sql
2618+
)
25982619
for stmt in kwargs[statement_field]
25992620
]
26002621

sqlmesh/core/snapshot/evaluator.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -750,13 +750,19 @@ def _evaluate_snapshot(
750750
**render_statements_kwargs
751751
)
752752

753+
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
754+
evaluation_strategy.run_pre_statements(
755+
snapshot=snapshot,
756+
render_kwargs={**render_statements_kwargs, "inside_transaction": False},
757+
)
758+
753759
with (
754760
adapter.transaction(),
755761
adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)),
756762
):
757-
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
758763
evaluation_strategy.run_pre_statements(
759-
snapshot=snapshot, render_kwargs=render_statements_kwargs
764+
snapshot=snapshot,
765+
render_kwargs={**render_statements_kwargs, "inside_transaction": True},
760766
)
761767

762768
if not target_table_exists or (model.is_seed and not snapshot.intervals):
@@ -828,10 +834,16 @@ def _evaluate_snapshot(
828834
)
829835

830836
evaluation_strategy.run_post_statements(
831-
snapshot=snapshot, render_kwargs=render_statements_kwargs
837+
snapshot=snapshot,
838+
render_kwargs={**render_statements_kwargs, "inside_transaction": True},
832839
)
833840

834-
return wap_id
841+
evaluation_strategy.run_post_statements(
842+
snapshot=snapshot,
843+
render_kwargs={**render_statements_kwargs, "inside_transaction": False},
844+
)
845+
846+
return wap_id
835847

836848
def create_snapshot(
837849
self,
@@ -865,6 +877,11 @@ def create_snapshot(
865877
deployability_index=deployability_index,
866878
)
867879

880+
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
881+
evaluation_strategy.run_pre_statements(
882+
snapshot=snapshot, render_kwargs={**create_render_kwargs, "inside_transaction": False}
883+
)
884+
868885
with (
869886
adapter.transaction(),
870887
adapter.session(snapshot.model.render_session_properties(**create_render_kwargs)),
@@ -896,6 +913,10 @@ def create_snapshot(
896913
dry_run=True,
897914
)
898915

916+
evaluation_strategy.run_post_statements(
917+
snapshot=snapshot, render_kwargs={**create_render_kwargs, "inside_transaction": False}
918+
)
919+
899920
if on_complete is not None:
900921
on_complete(snapshot)
901922

@@ -1097,6 +1118,11 @@ def _migrate_snapshot(
10971118
)
10981119
target_table_name = snapshot.table_name()
10991120

1121+
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
1122+
evaluation_strategy.run_pre_statements(
1123+
snapshot=snapshot, render_kwargs={**render_kwargs, "inside_transaction": False}
1124+
)
1125+
11001126
with (
11011127
adapter.transaction(),
11021128
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
@@ -1134,6 +1160,10 @@ def _migrate_snapshot(
11341160
dry_run=True,
11351161
)
11361162

1163+
evaluation_strategy.run_post_statements(
1164+
snapshot=snapshot, render_kwargs={**render_kwargs, "inside_transaction": False}
1165+
)
1166+
11371167
# Retry in case when the table is migrated concurrently from another plan application
11381168
@retry(
11391169
reraise=True,
@@ -1454,7 +1484,8 @@ def _execute_create(
14541484
}
14551485
if run_pre_post_statements:
14561486
evaluation_strategy.run_pre_statements(
1457-
snapshot=snapshot, render_kwargs=create_render_kwargs
1487+
snapshot=snapshot,
1488+
render_kwargs={**create_render_kwargs, "inside_transaction": True},
14581489
)
14591490
evaluation_strategy.create(
14601491
table_name=table_name,
@@ -1471,7 +1502,8 @@ def _execute_create(
14711502
)
14721503
if run_pre_post_statements:
14731504
evaluation_strategy.run_post_statements(
1474-
snapshot=snapshot, render_kwargs=create_render_kwargs
1505+
snapshot=snapshot,
1506+
render_kwargs={**create_render_kwargs, "inside_transaction": True},
14751507
)
14761508

14771509
def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool:
@@ -2944,12 +2976,20 @@ def append(
29442976
)
29452977

29462978
def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
2947-
# in dbt custom materialisations it's up to the user when to run the pre hooks
2948-
pass
2979+
# in dbt custom materialisations it's up to the user to run the pre hooks inside the transaction
2980+
if not render_kwargs.get("inside_transaction", True):
2981+
super().run_pre_statements(
2982+
snapshot=snapshot,
2983+
render_kwargs=render_kwargs,
2984+
)
29492985

29502986
def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
2951-
# in dbt custom materialisations it's up to the user when to run the post hooks
2952-
pass
2987+
# in dbt custom materialisations it's up to the user to run the post hooks inside the transaction
2988+
if not render_kwargs.get("inside_transaction", True):
2989+
super().run_post_statements(
2990+
snapshot=snapshot,
2991+
render_kwargs=render_kwargs,
2992+
)
29532993

29542994
def _execute_materialization(
29552995
self,
@@ -2985,14 +3025,15 @@ def _execute_materialization(
29853025
"sql": str(query_or_df),
29863026
"is_first_insert": is_first_insert,
29873027
"create_only": create_only,
2988-
# FIXME: Add support for transaction=False
29893028
"pre_hooks": [
2990-
AttributeDict({"sql": s.this.this, "transaction": True})
3029+
AttributeDict({"sql": s.this.this, "transaction": transaction})
29913030
for s in model.pre_statements
3031+
if (transaction := s.args.get("transaction", True))
29923032
],
29933033
"post_hooks": [
2994-
AttributeDict({"sql": s.this.this, "transaction": True})
3034+
AttributeDict({"sql": s.this.this, "transaction": transaction})
29953035
for s in model.post_statements
3036+
if (transaction := s.args.get("transaction", True))
29963037
],
29973038
"model_instance": model,
29983039
**kwargs,

sqlmesh/dbt/basemodel.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sqlmesh.core.config.base import UpdateStrategy
1414
from sqlmesh.core.config.common import VirtualEnvironmentMode
1515
from sqlmesh.core.model import Model
16+
from sqlmesh.core.model.common import ParsableSql
1617
from sqlmesh.core.node import DbtNodeInfo
1718
from sqlmesh.dbt.column import (
1819
ColumnConfig,
@@ -87,7 +88,7 @@ class Hook(DbtConfig):
8788
"""
8889

8990
sql: SqlStr
90-
transaction: bool = True # TODO not yet supported
91+
transaction: bool = True
9192

9293
_sql_validator = sql_str_validator
9394

@@ -339,8 +340,14 @@ def sqlmesh_model_kwargs(
339340
),
340341
"jinja_macros": jinja_macros,
341342
"path": self.path,
342-
"pre_statements": [d.jinja_statement(hook.sql) for hook in self.pre_hook],
343-
"post_statements": [d.jinja_statement(hook.sql) for hook in self.post_hook],
343+
"pre_statements": [
344+
ParsableSql(sql=d.jinja_statement(hook.sql).sql(), transaction=hook.transaction)
345+
for hook in self.pre_hook
346+
],
347+
"post_statements": [
348+
ParsableSql(sql=d.jinja_statement(hook.sql).sql(), transaction=hook.transaction)
349+
for hook in self.post_hook
350+
],
344351
"tags": self.tags,
345352
"physical_schema_mapping": context.sqlmesh_config.physical_schema_mapping,
346353
"default_catalog": context.target.database,

tests/core/test_snapshot_evaluator.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3232,11 +3232,11 @@ def test_create_post_statements_use_non_deployable_table(
32323232
evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable())
32333233

32343234
call_args = adapter_mock.execute.call_args_list
3235-
pre_calls = call_args[0][0][0]
3235+
pre_calls = call_args[1][0][0]
32363236
assert len(pre_calls) == 1
32373237
assert pre_calls[0].sql(dialect="postgres") == expected_call
32383238

3239-
post_calls = call_args[1][0][0]
3239+
post_calls = call_args[2][0][0]
32403240
assert len(post_calls) == 1
32413241
assert post_calls[0].sql(dialect="postgres") == expected_call
32423242

@@ -3294,11 +3294,11 @@ def model_with_statements(context, **kwargs):
32943294
expected_call = f'CREATE INDEX IF NOT EXISTS "idx" ON "sqlmesh__db"."db__test_model__{snapshot.version}__dev" /* db.test_model */("id")'
32953295

32963296
call_args = adapter_mock.execute.call_args_list
3297-
pre_calls = call_args[0][0][0]
3297+
pre_calls = call_args[1][0][0]
32983298
assert len(pre_calls) == 1
32993299
assert pre_calls[0].sql(dialect="postgres") == expected_call
33003300

3301-
post_calls = call_args[1][0][0]
3301+
post_calls = call_args[2][0][0]
33023302
assert len(post_calls) == 1
33033303
assert post_calls[0].sql(dialect="postgres") == expected_call
33043304

@@ -3356,14 +3356,14 @@ def create_log_table(evaluator, view_name):
33563356
)
33573357

33583358
call_args = adapter_mock.execute.call_args_list
3359-
post_calls = call_args[1][0][0]
3359+
post_calls = call_args[2][0][0]
33603360
assert len(post_calls) == 1
33613361
assert (
33623362
post_calls[0].sql(dialect="postgres")
33633363
== f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev" /* test_schema.test_model */("a")'
33643364
)
33653365

3366-
on_virtual_update_calls = call_args[2][0][0]
3366+
on_virtual_update_calls = call_args[4][0][0]
33673367
assert (
33683368
on_virtual_update_calls[0].sql(dialect="postgres")
33693369
== 'GRANT SELECT ON VIEW "test_schema__test_env"."test_model" /* test_schema.test_model */ TO ROLE "admin"'
@@ -3441,7 +3441,7 @@ def model_with_statements(context, **kwargs):
34413441
)
34423442

34433443
call_args = adapter_mock.execute.call_args_list
3444-
on_virtual_update_call = call_args[2][0][0][0]
3444+
on_virtual_update_call = call_args[4][0][0][0]
34453445
assert (
34463446
on_virtual_update_call.sql(dialect="postgres")
34473447
== 'CREATE INDEX IF NOT EXISTS "idx" ON "db"."test_model_3" /* db.test_model_3 */("id")'
@@ -4187,11 +4187,11 @@ def test_multiple_engine_creation(snapshot: Snapshot, adapters, make_snapshot):
41874187
assert view_args[1][0][0] == "test_schema__test_env.test_model"
41884188

41894189
call_args = engine_adapters["secondary"].execute.call_args_list
4190-
pre_calls = call_args[0][0][0]
4190+
pre_calls = call_args[1][0][0]
41914191
assert len(pre_calls) == 1
41924192
assert pre_calls[0].sql(dialect="postgres") == expected_call
41934193

4194-
post_calls = call_args[1][0][0]
4194+
post_calls = call_args[2][0][0]
41954195
assert len(post_calls) == 1
41964196
assert post_calls[0].sql(dialect="postgres") == expected_call
41974197

@@ -4459,7 +4459,7 @@ def model_with_statements(context, **kwargs):
44594459

44604460
# For the pre/post statements verify the model-specific gateway was used
44614461
engine_adapters["default"].execute.assert_called_once()
4462-
assert len(engine_adapters["secondary"].execute.call_args_list) == 2
4462+
assert len(engine_adapters["secondary"].execute.call_args_list) == 4
44634463

44644464
# Validate that the get_catalog_type method was called only on the secondary engine from the macro evaluator
44654465
engine_adapters["default"].get_catalog_type.assert_not_called()

0 commit comments

Comments
 (0)