Skip to content

Commit 6f601d4

Browse files
authored
fix: set airflow default catalog (#1923)
1 parent: a9834b3 · commit: 6f601d4

File tree

5 files changed

+37
-2
lines changed

5 files changed

+37
-2
lines changed

sqlmesh/engines/spark/app.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ def get_or_create_spark_session(dialect: str) -> SparkSession:
3030

3131

3232
def main(
33-
dialect: str, command_type: commands.CommandType, ddl_concurrent_tasks: int, payload_path: str
33+
dialect: str,
34+
default_catalog: str,
35+
command_type: commands.CommandType,
36+
ddl_concurrent_tasks: int,
37+
payload_path: str,
3438
) -> None:
3539
if dialect not in ("databricks", "spark"):
3640
raise NotSupportedError(
@@ -50,6 +54,7 @@ def main(
5054
create_engine_adapter(
5155
lambda: spark_session_db.connection(spark),
5256
dialect,
57+
default_catalog=default_catalog,
5358
multithreaded=ddl_concurrent_tasks > 1,
5459
execute_log_level=logging.INFO,
5560
),
@@ -79,6 +84,10 @@ def main(
7984
"--dialect",
8085
help="The dialect to use when creating the engine adapter.",
8186
)
87+
parser.add_argument(
88+
"--default_catalog",
89+
help="The default catalog to use when creating the engine adapter.",
90+
)
8291
parser.add_argument(
8392
"--command_type",
8493
type=commands.CommandType,
@@ -96,4 +105,10 @@ def main(
96105
help="Path to the payload object. Can be a local or remote path.",
97106
)
98107
args = parser.parse_args()
99-
main(args.dialect, args.command_type, args.ddl_concurrent_tasks, args.payload_path)
108+
main(
109+
args.dialect,
110+
args.default_catalog,
111+
args.command_type,
112+
args.ddl_concurrent_tasks,
113+
args.payload_path,
114+
)

sqlmesh/schedulers/airflow/operators/databricks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def execute(self, context: Context) -> None:
101101
)
102102
task_arguments = {
103103
"dialect": "databricks",
104+
"default_catalog": self._target.default_catalog,
104105
"command_type": self._target.command_type.value if self._target.command_type else None,
105106
"ddl_concurrent_tasks": self._target.ddl_concurrent_tasks,
106107
"payload_path": remote_payload_path,

sqlmesh/schedulers/airflow/operators/spark_submit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def _get_hook(
136136
) -> SparkSubmitHook:
137137
application_args = {
138138
"dialect": "spark",
139+
"default_catalog": self._target.default_catalog,
139140
"command_type": command_type.value if command_type else None,
140141
"ddl_concurrent_tasks": ddl_concurrent_tasks,
141142
"payload_path": command_payload_file_path.split("/")[-1]

sqlmesh/schedulers/airflow/operators/targets.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing as t
44

55
from airflow.exceptions import AirflowSkipException
6+
from airflow.models import Variable
67
from airflow.utils.context import Context
78
from airflow.utils.session import provide_session
89
from sqlalchemy.orm import Session
@@ -29,6 +30,10 @@ class BaseTarget(abc.ABC, t.Generic[CP]):
2930
command_handler: t.Callable[[SnapshotEvaluator, CP], None]
3031
ddl_concurrent_tasks: int
3132

33+
@property
34+
def default_catalog(self) -> str:
35+
return Variable.get(common.DEFAULT_CATALOG_VARIABLE_NAME)
36+
3237
def serialized_command_payload(self, context: Context) -> str:
3338
"""Returns the serialized command payload for the Spark application.
3439
@@ -62,6 +67,7 @@ def execute(
6267
dialect,
6368
multithreaded=self.ddl_concurrent_tasks > 1,
6469
execute_log_level=logging.INFO,
70+
default_catalog=self.default_catalog,
6571
**kwargs,
6672
),
6773
ddl_concurrent_tasks=self.ddl_concurrent_tasks,

tests/schedulers/airflow/operators/test_targets.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ def test_evaluation_target_execute(mocker: MockerFixture, make_snapshot: t.Calla
4848
"sqlmesh.core.state_sync.engine_adapter.EngineAdapterStateSync.add_interval"
4949
)
5050

51+
variable_get_mock = mocker.patch("sqlmesh.schedulers.airflow.operators.targets.Variable.get")
52+
53+
variable_get_mock.return_value = "default_catalog"
54+
5155
snapshot = make_snapshot(model)
5256
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
5357
parent_snapshots = {snapshot.name: snapshot}
@@ -83,6 +87,10 @@ def test_evaluation_target_execute_seed_model(mocker: MockerFixture, make_snapsh
8387
dag_run_mock.data_interval_end = interval_ds
8488
dag_run_mock.logical_date = logical_ds
8589

90+
variable_get_mock = mocker.patch("sqlmesh.schedulers.airflow.operators.targets.Variable.get")
91+
92+
variable_get_mock.return_value = "default_catalog"
93+
8694
context = Context(dag_run=dag_run_mock) # type: ignore
8795

8896
snapshot = make_snapshot(
@@ -150,6 +158,10 @@ def test_cleanup_target_execute(mocker: MockerFixture, make_snapshot: t.Callable
150158
task_instance_mock = mocker.Mock()
151159
task_instance_mock.xcom_pull.return_value = command.json()
152160

161+
variable_get_mock = mocker.patch("sqlmesh.schedulers.airflow.operators.targets.Variable.get")
162+
163+
variable_get_mock.return_value = "default_catalog"
164+
153165
context = Context(ti=task_instance_mock) # type: ignore
154166

155167
evaluator_cleanup_mock = mocker.patch(

0 commit comments

Comments (0)