Skip to content

Commit 8edfae4

Browse files
authored
Fix: Support macros for 'authorization' and 'query_label' keys in session properties (#4660)
1 parent 5da8a11 commit 8edfae4

File tree

6 files changed

+68
-5
lines changed

6 files changed

+68
-5
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ def _begin_session(self, properties: SessionProperties) -> None:
199199
(label_tuple.expressions[0].name, label_tuple.expressions[1].name)
200200
for label_tuple in label_tuples
201201
)
202+
elif query_label_property is not None:
203+
raise SQLMeshError(
204+
"Invalid value for `session_properties.query_label`. Must be an array or tuple."
205+
)
202206

203207
if parsed_query_label:
204208
query_label_str = ",".join([":".join(label) for label in parsed_query_label])

sqlmesh/core/engine_adapter/trino.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
set_catalog,
2828
)
2929
from sqlmesh.core.schema_diff import SchemaDiffer
30+
from sqlmesh.utils.errors import SQLMeshError
3031
from sqlmesh.utils.date import TimeLike
3132

3233
if t.TYPE_CHECKING:
@@ -99,6 +100,11 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
99100
if not isinstance(authorization, exp.Expression):
100101
authorization = exp.Literal.string(authorization)
101102

103+
if not authorization.is_string:
104+
raise SQLMeshError(
105+
"Invalid value for `session_properties.authorization`. Must be a string literal."
106+
)
107+
102108
authorization_sql = authorization.sql(dialect=self.dialect)
103109

104110
self.execute(f"SET SESSION AUTHORIZATION {authorization_sql}")

sqlmesh/core/model/meta.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,8 @@ def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
323323

324324
if prop_name == "query_label":
325325
query_label = eq.right
326-
if not (
327-
isinstance(query_label, exp.Array)
328-
or isinstance(query_label, exp.Tuple)
329-
or isinstance(query_label, exp.Paren)
326+
if not isinstance(
327+
query_label, (exp.Array, exp.Tuple, exp.Paren, d.MacroFunc, d.MacroVar)
330328
):
331329
raise ConfigError(
332330
"Invalid value for `session_properties.query_label`. Must be an array or tuple."
@@ -349,7 +347,9 @@ def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
349347
)
350348
elif prop_name == "authorization":
351349
authorization = eq.right
352-
if not (isinstance(authorization, exp.Literal) and authorization.is_string):
350+
if not (
351+
isinstance(authorization, exp.Literal) and authorization.is_string
352+
) and not isinstance(authorization, (d.MacroFunc, d.MacroVar)):
353353
raise ConfigError(
354354
"Invalid value for `session_properties.authorization`. Must be a string literal."
355355
)

tests/core/engine_adapter/test_bigquery.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sqlmesh.core.engine_adapter.bigquery import select_partitions_expr
1515
from sqlmesh.core.node import IntervalUnit
1616
from sqlmesh.utils import AttributeDict
17+
from sqlmesh.utils.errors import SQLMeshError
1718

1819
pytestmark = [pytest.mark.bigquery, pytest.mark.engine]
1920

@@ -564,6 +565,14 @@ def test_begin_end_session(mocker: MockerFixture):
564565
begin_new_session_call = connection_mock._client.query.call_args_list[5]
565566
assert begin_new_session_call[0][0] == 'SET @@query_label = "key1:value1";SELECT 1;'
566567

568+
# test invalid query_label value
569+
with pytest.raises(
570+
SQLMeshError,
571+
match="Invalid value for `session_properties.query_label`. Must be an array or tuple.",
572+
):
573+
with adapter.session({"query_label": parse_one("'key1:value1'")}):
574+
adapter.execute("SELECT 6;")
575+
567576

568577
def _to_sql_calls(execute_mock: t.Any, identify: bool = True) -> t.List[str]:
569578
output = []

tests/core/engine_adapter/test_trino.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlmesh.core.model import load_sql_based_model
1212
from sqlmesh.core.model.definition import SqlModel
1313
from sqlmesh.core.dialect import schema_
14+
from sqlmesh.utils.errors import SQLMeshError
1415
from tests.core.engine_adapter import to_sql_calls
1516

1617
pytestmark = [pytest.mark.engine, pytest.mark.trino]
@@ -632,6 +633,14 @@ def test_session_authorization(trino_mocked_engine_adapter: TrinoEngineAdapter):
632633
except RuntimeError:
633634
pass
634635

636+
# Test 5: Invalid authorization value
637+
with pytest.raises(
638+
SQLMeshError,
639+
match="Invalid value for `session_properties.authorization`. Must be a string literal.",
640+
):
641+
with adapter.session({"authorization": exp.Literal.number(1)}):
642+
adapter.execute("SELECT 1")
643+
635644
assert to_sql_calls(adapter) == [
636645
"SET SESSION AUTHORIZATION 'test_user'",
637646
"SELECT 1",

tests/core/test_model.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10244,3 +10244,38 @@ def test_invalid_sql_model_query() -> None:
1024410244
match=r"^A query is required and must be a SELECT statement, a UNION statement, or a JINJA_QUERY block.*",
1024510245
):
1024610246
load_sql_based_model(expressions)
10247+
10248+
10249+
def test_query_label_and_authorization_macro():
10250+
@macro()
10251+
def test_query_label_macro(evaluator):
10252+
return "[('key', 'value')]"
10253+
10254+
@macro()
10255+
def test_authorization_macro(evaluator):
10256+
return exp.Literal.string("test_authorization")
10257+
10258+
expressions = d.parse(
10259+
"""
10260+
MODEL (
10261+
name db.table,
10262+
session_properties (
10263+
query_label = @test_query_label_macro(),
10264+
authorization = @test_authorization_macro()
10265+
)
10266+
);
10267+
10268+
SELECT 1 AS c;
10269+
"""
10270+
)
10271+
10272+
model = load_sql_based_model(expressions)
10273+
assert model.session_properties == {
10274+
"query_label": d.parse_one("@test_query_label_macro()"),
10275+
"authorization": d.parse_one("@test_authorization_macro()"),
10276+
}
10277+
10278+
assert model.render_session_properties() == {
10279+
"query_label": d.parse_one("[('key', 'value')]"),
10280+
"authorization": d.parse_one("'test_authorization'"),
10281+
}

0 commit comments

Comments
 (0)