
Commit 09d2bbd

Merge branch 'main' into fabric_alter_table_no_op
2 parents: 46fe7e8 + f7f5af9

13 files changed: +354 −30 lines

.circleci/continue_config.yml

Lines changed: 7 additions & 3 deletions
@@ -93,9 +93,13 @@ jobs:
       - run:
           name: Run linters and code style checks
           command: make py-style
-      - run:
-          name: Exercise the benchmarks
-          command: make benchmark-ci
+      - unless:
+          condition:
+            equal: ["3.9", << parameters.python_version >>]
+          steps:
+            - run:
+                name: Exercise the benchmarks
+                command: make benchmark-ci
       - run:
           name: Run cicd tests
           command: make cicd-test

docs/reference/model_configuration.md

Lines changed: 1 addition & 1 deletion
@@ -282,7 +282,7 @@ Configuration options for [`SCD_TYPE_2` models](../concepts/models/model_kinds.m
 | `unique_key` | The model column(s) containing each row's unique key | array[str] | Y |
 | `valid_from_name` | The model column containing each row's valid from date. (Default: `valid_from`) | str | N |
 | `valid_to_name` | The model column containing each row's valid to date. (Default: `valid_to`) | str | N |
-| `invalidate_hard_deletes` | If set to true, when a record is missing from the source table it will be marked as invalid - see [here](../concepts/models/model_kinds.md#deletes) for more information. (Default: `True`) | bool | N |
+| `invalidate_hard_deletes` | If set to true, when a record is missing from the source table it will be marked as invalid - see [here](../concepts/models/model_kinds.md#deletes) for more information. (Default: `False`) | bool | N |

 ##### SCD Type 2 By Time

sqlmesh/core/context.py

Lines changed: 1 addition & 0 deletions
@@ -2279,6 +2279,7 @@ def audit(
             snapshot=snapshot,
             start=start,
             end=end,
+            execution_time=execution_time,
             snapshots=self.snapshots,
         ):
             audit_id = f"{audit_result.audit.name}"

sqlmesh/core/engine_adapter/fabric.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,6 @@
 from functools import cached_property
 from sqlglot import exp
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result
-from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin
 from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter
 from sqlmesh.core.engine_adapter.shared import (
     InsertOverwriteStrategy,
@@ -21,7 +20,7 @@
 logger = logging.getLogger(__name__)


-class FabricEngineAdapter(LogicalMergeMixin, MSSQLEngineAdapter):
+class FabricEngineAdapter(MSSQLEngineAdapter):
     """
     Adapter for Microsoft Fabric.
     """

sqlmesh/core/renderer.py

Lines changed: 7 additions & 8 deletions
@@ -196,7 +196,14 @@ def _resolve_table(table: str | exp.Table) -> str:
             **kwargs,
         }

+        if this_model:
+            render_kwargs["this_model"] = this_model
+
+        macro_evaluator.locals.update(render_kwargs)
+
         variables = kwargs.pop("variables", {})
+        if variables:
+            macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)

         expressions = [self._expression]
         if isinstance(self._expression, d.Jinja):
@@ -268,14 +275,6 @@ def _resolve_table(table: str | exp.Table) -> str:
                     f"Could not parse the rendered jinja at '{self._path}'.\n{ex}"
                 ) from ex

-        if this_model:
-            render_kwargs["this_model"] = this_model
-
-        macro_evaluator.locals.update(render_kwargs)
-
-        if variables:
-            macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)
-
         for definition in self._macro_definitions:
             try:
                 macro_evaluator.evaluate(definition)
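
Moving this block earlier matters because jinja rendering happens between the old and new locations: a Python macro invoked from inside a `JINJA_QUERY_BEGIN` block used to execute before `macro_evaluator.locals` was populated, so lookups such as `evaluator.locals.get("execution_date")` came back empty (see the new test in tests/core/test_model.py below). A toy sketch of the ordering issue; `Evaluator` here is a stand-in, not SQLMesh's actual macro evaluator:

class Evaluator:
    def __init__(self) -> None:
        self.locals: dict = {}


def render_jinja(evaluator: Evaluator) -> str:
    # A Python macro invoked during jinja rendering reads evaluator.locals.
    return evaluator.locals.get("execution_date", "<missing>")


ev = Evaluator()
before = render_jinja(ev)  # old order: rendering ran first -> "<missing>"
ev.locals.update({"execution_date": "1970-01-01"})
after = render_jinja(ev)  # new order: locals populated first -> "1970-01-01"
assert (before, after) == ("<missing>", "1970-01-01")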

sqlmesh/core/scheduler.py

Lines changed: 30 additions & 13 deletions
@@ -659,6 +659,7 @@ def _dag(
         }
         snapshots_to_create = snapshots_to_create or set()
         original_snapshots_to_create = snapshots_to_create.copy()
+        upstream_dependencies_cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]] = {}

         snapshot_dag = snapshot_dag or snapshots_to_dag(batches)
         dag = DAG[SchedulingUnit]()
@@ -670,12 +671,15 @@
             snapshot = self.snapshots_by_name[snapshot_id.name]
             intervals = intervals_per_snapshot.get(snapshot.name, [])

-            upstream_dependencies: t.List[SchedulingUnit] = []
+            upstream_dependencies: t.Set[SchedulingUnit] = set()

             for p_sid in snapshot.parents:
-                upstream_dependencies.extend(
+                upstream_dependencies.update(
                     self._find_upstream_dependencies(
-                        p_sid, intervals_per_snapshot, original_snapshots_to_create
+                        p_sid,
+                        intervals_per_snapshot,
+                        original_snapshots_to_create,
+                        upstream_dependencies_cache,
                     )
                 )

@@ -726,29 +730,42 @@ def _find_upstream_dependencies(
         parent_sid: SnapshotId,
         intervals_per_snapshot: t.Dict[str, Intervals],
         snapshots_to_create: t.Set[SnapshotId],
-    ) -> t.List[SchedulingUnit]:
+        cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]],
+    ) -> t.Set[SchedulingUnit]:
         if parent_sid not in self.snapshots:
-            return []
+            return set()
+        if parent_sid in cache:
+            return cache[parent_sid]

         p_intervals = intervals_per_snapshot.get(parent_sid.name, [])

+        parent_node: t.Optional[SchedulingUnit] = None
         if p_intervals:
             if len(p_intervals) > 1:
-                return [DummyNode(snapshot_name=parent_sid.name)]
-            interval = p_intervals[0]
-            return [EvaluateNode(snapshot_name=parent_sid.name, interval=interval, batch_index=0)]
-        if parent_sid in snapshots_to_create:
-            return [CreateNode(snapshot_name=parent_sid.name)]
+                parent_node = DummyNode(snapshot_name=parent_sid.name)
+            else:
+                interval = p_intervals[0]
+                parent_node = EvaluateNode(
+                    snapshot_name=parent_sid.name, interval=interval, batch_index=0
+                )
+        elif parent_sid in snapshots_to_create:
+            parent_node = CreateNode(snapshot_name=parent_sid.name)
+
+        if parent_node is not None:
+            cache[parent_sid] = {parent_node}
+            return {parent_node}
+
         # This snapshot has no intervals and doesn't need creation which means
         # that it can be a transitive dependency
-        transitive_deps: t.List[SchedulingUnit] = []
+        transitive_deps: t.Set[SchedulingUnit] = set()
         parent_snapshot = self.snapshots[parent_sid]
         for grandparent_sid in parent_snapshot.parents:
-            transitive_deps.extend(
+            transitive_deps.update(
                 self._find_upstream_dependencies(
-                    grandparent_sid, intervals_per_snapshot, snapshots_to_create
+                    grandparent_sid, intervals_per_snapshot, snapshots_to_create, cache
                 )
             )
+        cache[parent_sid] = transitive_deps
         return transitive_deps

     def _run_or_audit(
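
Two things change above: scheduling units are collected into sets rather than lists, deduplicating repeated dependencies, and the recursive resolution is memoized via `upstream_dependencies_cache`, so a shared ancestor is resolved once rather than once per path that reaches it. Without the cache, a diamond-shaped DAG makes the recursion's cost grow with the number of paths instead of the number of nodes. A generic sketch of the pattern under simplified types (all names here are illustrative, not the scheduler's real API):

def find_deps(
    node: str,
    parents: dict[str, list[str]],
    resolve: dict[str, set[str]],
    cache: dict[str, set[str]],
) -> set[str]:
    if node in cache:  # reuse a previously computed result
        return cache[node]
    if node in resolve:  # node maps directly to a scheduling unit
        cache[node] = resolve[node]
        return cache[node]
    deps: set[str] = set()
    for parent in parents.get(node, []):  # otherwise walk ancestors transitively
        deps.update(find_deps(parent, parents, resolve, cache))
    cache[node] = deps
    return deps


# Diamond DAG: d depends on b and c, both of which depend on a.
parents = {"d": ["b", "c"], "b": ["a"], "c": ["a"]}
resolve = {"a": {"evaluate:a"}}
cache: dict[str, set[str]] = {}
assert find_deps("d", parents, resolve, cache) == {"evaluate:a"}
assert "b" in cache and "c" in cache  # each ancestor resolved exactly once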

sqlmesh/core/test/definition.py

Lines changed: 1 addition & 1 deletion
@@ -807,7 +807,7 @@ def runTest(self) -> None:
         actual_df.reset_index(drop=True, inplace=True)
         expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial)

-        self.assert_equal(expected, actual_df, sort=False, partial=partial)
+        self.assert_equal(expected, actual_df, sort=True, partial=partial)

     def _execute_model(self) -> pd.DataFrame:
         """Executes the python model and returns a DataFrame."""

sqlmesh/dbt/common.py

Lines changed: 4 additions & 2 deletions
@@ -46,7 +46,9 @@ def load_yaml(source: str | Path) -> t.Dict:
         raise ConfigError(f"{source}: {ex}" if isinstance(source, Path) else f"{ex}")


-def parse_meta(v: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
+def parse_meta(v: t.Optional[t.Dict[str, t.Any]]) -> t.Dict[str, t.Any]:
+    if v is None:
+        return {}
     for key, value in v.items():
         if isinstance(value, str):
             v[key] = try_str_to_bool(value)
@@ -115,7 +117,7 @@ def _validate_list(cls, v: t.Union[str, t.List[str]]) -> t.List[str]:

     @field_validator("meta", mode="before")
     @classmethod
-    def _validate_meta(cls, v: t.Dict[str, t.Union[str, t.Any]]) -> t.Dict[str, t.Any]:
+    def _validate_meta(cls, v: t.Optional[t.Dict[str, t.Union[str, t.Any]]]) -> t.Dict[str, t.Any]:
         return parse_meta(v)

     _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
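
The `Optional` widening covers dbt YAML where a `meta:` key is present but left empty, which YAML parses as `None`; the old `parse_meta` would then fail on `v.items()`. A minimal reproduction of the input now being guarded against:

import yaml

doc = yaml.safe_load("meta:")  # an empty meta key loads as None
assert doc == {"meta": None}
# The old parse_meta called v.items() on that None and raised
# AttributeError; the new guard returns {} instead.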

tests/core/engine_adapter/test_fabric.py

Lines changed: 141 additions & 0 deletions
@@ -2,6 +2,7 @@

 import typing as t

+import pandas as pd  # noqa: TID253
 import pytest
 from pytest_mock import MockerFixture
 from sqlglot import exp, parse_one
@@ -143,3 +144,143 @@ def test_alter_table_direct_alteration(adapter: FabricEngineAdapter, mocker: Moc
     ]

     assert to_sql_calls(adapter) == expected_calls
+def test_merge_pandas(
+    make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
+):
+    mocker.patch(
+        "sqlmesh.core.engine_adapter.fabric.FabricEngineAdapter.table_exists",
+        return_value=False,
+    )
+
+    adapter = make_mocked_engine_adapter(FabricEngineAdapter)
+
+    temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
+    table_name = "target"
+    temp_table_id = "abcdefgh"
+    temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)
+
+    df = pd.DataFrame({"id": [1, 2, 3], "ts": [1, 2, 3], "val": [4, 5, 6]})
+
+    # 1 key
+    adapter.merge(
+        target_table=table_name,
+        source_table=df,
+        target_columns_to_types={
+            "id": exp.DataType.build("int"),
+            "ts": exp.DataType.build("TIMESTAMP"),
+            "val": exp.DataType.build("int"),
+        },
+        unique_key=[exp.to_identifier("id")],
+    )
+    adapter._connection_pool.get().bulk_copy.assert_called_with(
+        f"__temp_target_{temp_table_id}", [(1, 1, 4), (2, 2, 5), (3, 3, 6)]
+    )
+
+    assert to_sql_calls(adapter) == [
+        f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6), [val] INT)');""",
+        f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts], CAST([val] AS INT) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);",
+        f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];",
+    ]
+
+    # 2 keys
+    adapter.cursor.reset_mock()
+    adapter._connection_pool.get().reset_mock()
+    temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)
+    adapter.merge(
+        target_table=table_name,
+        source_table=df,
+        target_columns_to_types={
+            "id": exp.DataType.build("int"),
+            "ts": exp.DataType.build("TIMESTAMP"),
+            "val": exp.DataType.build("int"),
+        },
+        unique_key=[exp.to_identifier("id"), exp.to_column("ts")],
+    )
+    adapter._connection_pool.get().bulk_copy.assert_called_with(
+        f"__temp_target_{temp_table_id}", [(1, 1, 4), (2, 2, 5), (3, 3, 6)]
+    )
+
+    assert to_sql_calls(adapter) == [
+        f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6), [val] INT)');""",
+        f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts], CAST([val] AS INT) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] AND [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);",
+        f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];",
+    ]
+
+
+def test_merge_exists(
+    make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
+):
+    mocker.patch(
+        "sqlmesh.core.engine_adapter.fabric.FabricEngineAdapter.table_exists",
+        return_value=False,
+    )
+
+    adapter = make_mocked_engine_adapter(FabricEngineAdapter)
+
+    temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
+    table_name = "target"
+    temp_table_id = "abcdefgh"
+    temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)
+
+    df = pd.DataFrame({"id": [1, 2, 3], "ts": [1, 2, 3], "val": [4, 5, 6]})
+
+    # regular implementation
+    adapter.merge(
+        target_table=table_name,
+        source_table=df,
+        target_columns_to_types={
+            "id": exp.DataType.build("int"),
+            "ts": exp.DataType.build("TIMESTAMP"),
+            "val": exp.DataType.build("int"),
+        },
+        unique_key=[exp.to_identifier("id")],
+    )
+
+    assert to_sql_calls(adapter) == [
+        f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6), [val] INT)');""",
+        f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts], CAST([val] AS INT) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);",
+        f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];",
+    ]
+
+    # merge exists implementation
+    adapter.cursor.reset_mock()
+    adapter._connection_pool.get().reset_mock()
+    temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)
+    adapter.merge(
+        target_table=table_name,
+        source_table=df,
+        target_columns_to_types={
+            "id": exp.DataType.build("int"),
+            "ts": exp.DataType.build("TIMESTAMP"),
+            "val": exp.DataType.build("int"),
+        },
+        unique_key=[exp.to_identifier("id")],
+        physical_properties={"mssql_merge_exists": True},
+    )
+
+    assert to_sql_calls(adapter) == [
+        f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6), [val] INT)');""",
+        f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts], CAST([val] AS INT) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] WHEN MATCHED AND EXISTS(SELECT [__MERGE_TARGET__].[ts], [__MERGE_TARGET__].[val] EXCEPT SELECT [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]) THEN UPDATE SET [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);",
+        f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];",
+    ]
+
+    # merge exists and all model columns are keys
+    adapter.cursor.reset_mock()
+    adapter._connection_pool.get().reset_mock()
+    temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)
+    adapter.merge(
+        target_table=table_name,
+        source_table=df,
+        target_columns_to_types={
+            "id": exp.DataType.build("int"),
+            "ts": exp.DataType.build("TIMESTAMP"),
+        },
+        unique_key=[exp.to_identifier("id"), exp.to_column("ts")],
+        physical_properties={"mssql_merge_exists": True},
+    )
+
+    assert to_sql_calls(adapter) == [
+        f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6))');""",
+        f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] AND [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts] WHEN NOT MATCHED THEN INSERT ([id], [ts]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts]);",
+        f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];",
+    ]
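
For context on the expected SQL: with `mssql_merge_exists` enabled, the `WHEN MATCHED` branch gains an `AND EXISTS(SELECT <target non-key columns> EXCEPT SELECT <source non-key columns>)` guard, so updates fire only for rows whose non-key columns actually changed, and when every column is part of the unique key the `WHEN MATCHED` branch is omitted entirely (third scenario above). A pure-Python analogue of what the `EXCEPT` predicate tests:

# The set difference is non-empty exactly when some non-key column differs.
def row_changed(target_row: tuple, source_row: tuple) -> bool:
    return bool({target_row} - {source_row})


assert row_changed((1, 4), (1, 4)) is False  # identical rows: update skipped
assert row_changed((1, 4), (1, 5)) is True  # a column differs: update emitted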

tests/core/test_model.py

Lines changed: 18 additions & 0 deletions
@@ -12158,3 +12158,21 @@ def test_grants_empty_values():
 def test_grants_table_type(kind: t.Union[str, _ModelKind], expected: DataObjectType):
     model = create_sql_model("test_table", parse_one("SELECT 1 as id"), kind=kind)
     assert model.grants_table_type == expected
+
+
+def test_model_macro_using_locals_called_from_jinja(assert_exp_eq) -> None:
+    @macro()
+    def execution_date(evaluator):
+        return f"""'{evaluator.locals.get("execution_date")}'"""
+
+    expressions = d.parse(
+        """
+        MODEL (name db.table);
+
+        JINJA_QUERY_BEGIN;
+        SELECT {{ execution_date() }} AS col;
+        JINJA_END;
+        """
+    )
+    model = load_sql_based_model(expressions)
+    assert_exp_eq(model.render_query(), '''SELECT '1970-01-01' AS "col"''')
