Skip to content

Commit 64609d0

Browse files
authored
Merge branch 'SQLMesh:main' into comment_suppport_mssql
2 parents da6f6a2 + 192fbe9 commit 64609d0

6 files changed

Lines changed: 350 additions & 1 deletion

File tree

docs/integrations/engines/databricks.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,28 @@ The only relevant SQLMesh configuration parameter is the optional `catalog` para
271271
| `disable_databricks_connect` | When running locally, disable the use of Databricks Connect for all model operations (so use SQL Connector for all models) | bool | N |
272272
| `disable_spark_session` | Do not use SparkSession if it is available (like when running in a notebook). | bool | N |
273273

274+
### Query tags
275+
276+
Databricks SQL Connector supports per-query tags through the `query_tags` model session property. Specify tags as a `MAP(...)` of string keys to string or `NULL` values:
277+
278+
```sql
279+
MODEL (
280+
name sqlmesh_example.tagged_model,
281+
dialect databricks,
282+
session_properties (
283+
query_tags = MAP(
284+
'team', 'data-eng',
285+
'app', 'sqlmesh',
286+
'feature', NULL
287+
)
288+
)
289+
);
290+
291+
SELECT 1 AS id;
292+
```
293+
294+
Query tags are only applied when SQLMesh executes SQL through the Databricks SQL Connector. They are not applied when SQLMesh routes execution through Databricks Connect, a Databricks notebook SparkSession, or the Spark engine adapter.
295+
274296
## Model table properties to support altering tables
275297

276298
If you are making a change to the structure of a table that is [forward only](../../guides/incremental_time.md#forward-only-models), then you may need to add the following to your model's `physical_properties`:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ bigquery = [
5151
# pinned an older SQLGlot which is incompatible with SQLMesh
5252
bigframes = ["bigframes>=1.32.0"]
5353
clickhouse = ["clickhouse-connect"]
54-
databricks = ["databricks-sql-connector[pyarrow]"]
54+
databricks = ["databricks-sql-connector[pyarrow]>=4.2.6"]
5555
dev = [
5656
"agate",
5757
"beautifulsoup4",

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,43 @@
3030
logger = logging.getLogger(__name__)
3131

3232

33+
def _query_tags(
    query_tags: t.Optional[t.Union[exp.Expr, str, int, float, bool]],
) -> t.Optional[t.Dict[str, t.Optional[str]]]:
    """Convert the `query_tags` session property into a plain dictionary.

    Args:
        query_tags: The raw `session_properties.query_tags` value. Falsy values mean
            that no tags were configured; any other value must be a map expression
            whose keys are string literals and whose values are string literals or NULL.

    Returns:
        None when no tags were configured, otherwise a dictionary that maps each tag
        name to its string value (or to None for NULL values).

    Raises:
        SQLMeshError: If the value is not a map, its keys/values are not arrays, the
            number of keys differs from the number of values, a key is not a string
            literal, or a value is neither a string literal nor NULL.
    """
    if not query_tags:
        return None

    if not isinstance(query_tags, (exp.Map, exp.VarMap)):
        raise SQLMeshError("Invalid value for `session_properties.query_tags`. Must be a map.")

    keys = query_tags.args.get("keys")
    values = query_tags.args.get("values")
    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
        raise SQLMeshError(
            "Invalid value for `session_properties.query_tags`. Must be a map with array "
            "keys and array values."
        )

    # zip() stops at the shorter input, so a map built from two arrays of different
    # lengths would silently lose tags. Reject it explicitly instead.
    if len(keys.expressions) != len(values.expressions):
        raise SQLMeshError(
            "Invalid value for `session_properties.query_tags`. Must be a map with the same "
            "number of keys and values."
        )

    tags: t.Dict[str, t.Optional[str]] = {}
    for key, value in zip(keys.expressions, values.expressions):
        if not isinstance(key, exp.Literal) or not key.is_string:
            raise SQLMeshError(
                "Invalid key in `session_properties.query_tags`. Keys must be string literals."
            )

        if isinstance(value, exp.Null):
            tags[key.this] = None
        elif isinstance(value, exp.Literal) and value.is_string:
            tags[key.this] = value.this
        else:
            raise SQLMeshError(
                "Invalid value in `session_properties.query_tags`. Values must be string "
                "literals or NULL."
            )

    return tags
68+
69+
3370
class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
3471
DIALECT = "databricks"
3572
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
@@ -98,6 +135,12 @@ def _use_spark_session(self) -> bool:
98135
def is_spark_session_connection(self) -> bool:
99136
return isinstance(self.connection, SparkSessionConnection)
100137

138+
@property
def _is_databricks_sql_connector_connection(self) -> bool:
    """Whether SQL is executed through something other than Spark.

    True only when the connection is not a SparkSession connection and the
    `use_spark_engine_adapter` connection pool attribute is not set.
    """
    if self.is_spark_session_connection:
        return False
    return not self._connection_pool.get_attribute("use_spark_engine_adapter")
143+
101144
def _set_spark_engine_adapter_if_needed(self) -> None:
102145
self._spark_engine_adapter = None
103146

@@ -181,10 +224,23 @@ def _begin_session(self, properties: SessionProperties) -> t.Any:
181224
"""Begin a new session."""
182225
# Align the different possible connectors to a single catalog
183226
self.set_current_catalog(self.default_catalog) # type: ignore
227+
self._connection_pool.set_attribute("query_tags", _query_tags(properties.get("query_tags")))
184228

185229
def _end_session(self) -> None:
    """Reset the session-scoped attributes stored on the connection pool."""
    session_defaults = (
        ("query_tags", None),
        ("use_spark_engine_adapter", False),
    )
    for attribute_name, reset_value in session_defaults:
        self._connection_pool.set_attribute(attribute_name, reset_value)
187232

233+
def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None:
    """Execute SQL, attaching the session's query tags when applicable.

    The tags stored in the `query_tags` connection pool attribute are added to the
    keyword arguments only when they are non-empty, the caller did not pass its own
    `query_tags`, and `_is_databricks_sql_connector_connection` is true. Execution
    itself is delegated to the parent implementation.
    """
    session_tags = self._connection_pool.get_attribute("query_tags")
    caller_supplied_tags = "query_tags" in kwargs

    # Explicit per-call tags always take precedence over the session-level ones.
    if (
        session_tags
        and not caller_supplied_tags
        and self._is_databricks_sql_connector_connection
    ):
        kwargs["query_tags"] = session_tags

    return super()._execute(sql, track_rows_processed, **kwargs)
243+
188244
def _df_to_source_queries(
189245
self,
190246
df: DF,

sqlmesh/core/model/meta.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,38 @@ def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
396396
raise ConfigError(
397397
"Invalid value for `session_properties.authorization`. Must be a string literal."
398398
)
399+
elif prop_name == "query_tags":
400+
query_tags = eq.right
401+
if isinstance(query_tags, (d.MacroFunc, d.MacroVar)):
402+
continue
403+
404+
if not isinstance(query_tags, (exp.Map, exp.VarMap)):
405+
raise ConfigError(
406+
"Invalid value for `session_properties.query_tags`. Must be a map."
407+
)
408+
409+
keys = query_tags.args.get("keys")
410+
values = query_tags.args.get("values")
411+
if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
412+
raise ConfigError(
413+
"Invalid value for `session_properties.query_tags`. Must be a map with array "
414+
"keys and array values."
415+
)
416+
417+
for key, value in zip(keys.expressions, values.expressions):
418+
if not isinstance(key, exp.Literal) or not key.is_string:
419+
raise ConfigError(
420+
"Invalid key in `session_properties.query_tags`. Keys must be string literals."
421+
)
422+
423+
if not (
424+
isinstance(value, exp.Null)
425+
or (isinstance(value, exp.Literal) and value.is_string)
426+
):
427+
raise ConfigError(
428+
"Invalid value in `session_properties.query_tags`. Values must be string "
429+
"literals or NULL."
430+
)
399431

400432
return parsed_session_properties
401433

tests/core/engine_adapter/test_databricks.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,23 @@
1010
from sqlmesh.core.engine_adapter import DatabricksEngineAdapter
1111
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
1212
from sqlmesh.core.node import IntervalUnit
13+
from sqlmesh.utils.errors import SQLMeshError
1314
from tests.core.engine_adapter import to_sql_calls
1415

1516
pytestmark = [pytest.mark.databricks, pytest.mark.engine]
1617

1718

19+
def _query_tags_map(*items: t.Optional[str]) -> exp.Map:
    """Build an `exp.Map` from alternating key/value arguments.

    Even positions are used as keys and odd positions as values; a value of None
    is represented as `exp.Null()`, everything else as a string literal.
    """
    key_nodes = [exp.Literal.string(name) for name in items[::2]]

    value_nodes = []
    for raw_value in items[1::2]:
        if raw_value is None:
            value_nodes.append(exp.Null())
        else:
            value_nodes.append(exp.Literal.string(raw_value))

    return exp.Map(
        keys=exp.Array(expressions=key_nodes),
        values=exp.Array(expressions=value_nodes),
    )
28+
29+
1830
def test_replace_query_not_exists(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
1931
mocker.patch(
2032
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists",
@@ -117,6 +129,120 @@ def test_set_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.
117129
assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog2`"]
118130

119131

132+
def test_session_query_tags(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
    """Query tags set on a session reach the cursor and are cleared afterwards."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

    parsed_tags = d.parse_one("MAP('team', 'data-eng', 'app', 'sqlmesh')", dialect="databricks")
    session_properties = {"query_tags": parsed_tags}
    expected_tags = {"team": "data-eng", "app": "sqlmesh"}

    with adapter.session(session_properties):
        adapter.execute("SELECT 1")

    adapter.cursor.execute.assert_called_with("SELECT 1", query_tags=expected_tags)

    # A statement executed after the session must not carry the tags any more.
    adapter.execute("SELECT 2")

    adapter.cursor.execute.assert_called_with("SELECT 2")
154+
155+
156+
def test_session_query_tags_allow_none_values(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """A NULL tag value is passed to the cursor as None."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

    session_properties = {"query_tags": _query_tags_map("team", "data-eng", "feature", None)}
    expected_tags = {"team": "data-eng", "feature": None}

    with adapter.session(session_properties):
        adapter.execute("SELECT 1")

    adapter.cursor.execute.assert_called_with("SELECT 1", query_tags=expected_tags)
170+
171+
172+
def test_session_query_tags_do_not_override_explicit_query_tags(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """Tags passed directly to `execute` win over the session-level tags."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

    session_properties = {"query_tags": _query_tags_map("team", "data-eng")}
    explicit_tags = {"team": "analytics"}

    with adapter.session(session_properties):
        adapter.execute("SELECT 1", query_tags=explicit_tags)

    adapter.cursor.execute.assert_called_with("SELECT 1", query_tags={"team": "analytics"})
184+
185+
186+
def test_session_query_tags_not_applied_to_spark_session_connection(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """No tags are forwarded when the adapter reports a SparkSession connection."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

    # Force the adapter to behave as if it were backed by a SparkSession.
    mocker.patch.object(
        DatabricksEngineAdapter,
        "is_spark_session_connection",
        new_callable=mocker.PropertyMock,
        return_value=True,
    )

    session_properties = {"query_tags": _query_tags_map("team", "data-eng")}
    with adapter.session(session_properties):
        adapter.execute("SELECT 1")

    adapter.cursor.execute.assert_called_with("SELECT 1")
204+
205+
206+
def test_session_query_tags_not_applied_to_spark_engine_adapter(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """No tags are forwarded when the `use_spark_engine_adapter` attribute is set."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")

    spark_cursor = mocker.Mock()
    adapter._spark_engine_adapter = mocker.Mock(cursor=spark_cursor)
    pool = adapter._connection_pool
    pool.set_attribute("use_spark_engine_adapter", True)

    session_properties = {"query_tags": _query_tags_map("team", "data-eng")}
    with adapter.session(session_properties):
        # Set the flag again inside the session, exactly as before entering it.
        pool.set_attribute("use_spark_engine_adapter", True)
        adapter.execute("SELECT 1")

    spark_cursor.execute.assert_called_with("SELECT 1")
222+
223+
224+
@pytest.mark.parametrize(
    "query_tags",
    [
        "team:data-eng",
        exp.Map(
            keys=exp.Array(expressions=[exp.Literal.number(1)]),
            values=exp.Array(expressions=[exp.Literal.string("data-eng")]),
        ),
        exp.Map(
            keys=exp.Array(expressions=[exp.Literal.string("team")]),
            values=exp.Array(expressions=[exp.Literal.number(1)]),
        ),
    ],
)
def test_session_query_tags_invalid(query_tags, make_mocked_engine_adapter: t.Callable):
    """A non-map value, a non-string key and a non-string value are all rejected."""
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
    session_properties = {"query_tags": query_tags}

    with pytest.raises(SQLMeshError, match="session_properties.query_tags"):
        with adapter.session(session_properties):
            pass
244+
245+
120246
def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
121247
mocker.patch(
122248
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"

0 commit comments

Comments
 (0)