Skip to content

Commit 87f7461

Browse files
committed
feat: add catalog type overrides
1 parent 906be74 commit 87f7461

File tree

6 files changed

+139
-15
lines changed

6 files changed

+139
-15
lines changed

docs/integrations/engines/trino.md

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,21 @@ hive.metastore.glue.default-warehouse-dir=s3://my-bucket/
8181

8282
### Connection options
8383

84-
| Option | Description | Type | Required |
85-
|----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:|
86-
| `type` | Engine type name - must be `trino` | string | Y |
87-
| `user` | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y |
88-
| `host` | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y |
89-
| `catalog` | The name of a catalog in your cluster. | string | Y |
90-
| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N |
91-
| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N |
92-
| `roles` | Mapping of catalog name to a role | dict | N |
93-
| `http_headers` | Additional HTTP headers to send with each request. | dict | N |
94-
| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N |
95-
| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N |
96-
| `timezone` | Timezone to use for the connection. Default: client-side local timezone | string | N |
84+
| Option | Description | Type | Required |
85+
|---------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:|
86+
| `type` | Engine type name - must be `trino` | string | Y |
87+
| `user` | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y |
88+
| `host` | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y |
89+
| `catalog` | The name of a catalog in your cluster. | string | Y |
90+
| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N |
91+
| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N |
92+
| `roles` | Mapping of catalog name to a role | dict | N |
93+
| `http_headers` | Additional HTTP headers to send with each request. | dict | N |
94+
| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N |
95+
| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N |
96+
| `timezone` | Timezone to use for the connection. Default: client-side local timezone | string | N |
97+
| `schema_location_mapping` | A mapping of regex patterns to S3 locations to use for the `LOCATION` property when creating schemas. See [Table and Schema locations](#table-and-schema-locations) for more details. | dict | N |
98+
| `catalog_type_overrides`  | A mapping of catalog names to their connector type. This is used to enable/disable connector-specific behavior. See [Catalog Type Overrides](#catalog-type-overrides) for more details.   | dict   | N        |
9799

98100
## Table and Schema locations
99101

@@ -204,6 +206,25 @@ SELECT ...
204206

205207
This will cause SQLMesh to set the specified `LOCATION` when issuing a `CREATE TABLE` statement.
206208

209+
## Catalog Type Overrides
210+
211+
SQLMesh attempts to determine the connector type of a catalog by querying the `system.metadata.catalogs` table and checking the `connector_name` column.
212+
It checks if the connector name is `hive` for Hive connector behavior or contains `iceberg` or `delta_lake` for Iceberg or Delta Lake connector behavior respectively.
213+
However, the connector name may not always be a reliable way to determine the connector type, for example when using a custom connector or a fork of an existing connector.
214+
To handle such cases, you can use the `catalog_type_overrides` connection property to explicitly specify the connector type for specific catalogs.
215+
For example, to specify that the `datalake` catalog is using the Iceberg connector and the `analytics` catalog is using the Hive connector, you can configure the connection as follows:
216+
217+
```yaml title="config.yaml"
218+
gateways:
219+
trino:
220+
connection:
221+
type: trino
222+
...
223+
catalog_type_overrides:
224+
datalake: iceberg
225+
analytics: hive
226+
```
227+
207228
## Authentication
208229

209230
=== "No Auth"

sqlmesh/core/config/connection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class ConnectionConfig(abc.ABC, BaseConfig):
101101
pre_ping: bool
102102
pretty_sql: bool = False
103103
schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None
104+
catalog_type_overrides: t.Optional[t.Dict[str, str]] = None
104105

105106
# Whether to share a single connection across threads or create a new connection per thread.
106107
shared_connection: t.ClassVar[bool] = False
@@ -176,6 +177,7 @@ def create_engine_adapter(
176177
pretty_sql=self.pretty_sql,
177178
shared_connection=self.shared_connection,
178179
schema_differ_overrides=self.schema_differ_overrides,
180+
catalog_type_overrides=self.catalog_type_overrides,
179181
**self._extra_engine_config,
180182
)
181183

sqlmesh/core/engine_adapter/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ def schema_differ(self) -> SchemaDiffer:
223223
}
224224
)
225225

226+
@property
227+
def _catalog_type_overrides(self) -> t.Dict[str, str]:
228+
return self._extra_config.get("catalog_type_overrides", {})
229+
226230
@classmethod
227231
def _casted_columns(
228232
cls,
@@ -430,7 +434,11 @@ def get_catalog_type(self, catalog: t.Optional[str]) -> str:
430434
raise UnsupportedCatalogOperationError(
431435
f"{self.dialect} does not support catalogs and a catalog was provided: {catalog}"
432436
)
433-
return self.DEFAULT_CATALOG_TYPE
437+
return (
438+
self._catalog_type_overrides.get(catalog, self.DEFAULT_CATALOG_TYPE)
439+
if catalog
440+
else self.DEFAULT_CATALOG_TYPE
441+
)
434442

435443
def get_catalog_type_from_table(self, table: TableName) -> str:
436444
"""Get the catalog type from a table name if it has a catalog specified, otherwise return the current catalog type"""

sqlmesh/core/engine_adapter/trino.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class TrinoEngineAdapter(
7171
MAX_TIMESTAMP_PRECISION = 3
7272

7373
@property
74-
def schema_location_mapping(self) -> t.Optional[dict[re.Pattern, str]]:
74+
def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]:
7575
return self._extra_config.get("schema_location_mapping")
7676

7777
@property
@@ -86,6 +86,8 @@ def set_current_catalog(self, catalog: str) -> None:
8686
def get_catalog_type(self, catalog: t.Optional[str]) -> str:
8787
row: t.Tuple = tuple()
8888
if catalog:
89+
if catalog_type_overrides := self._catalog_type_overrides.get(catalog):
90+
return catalog_type_overrides
8991
row = (
9092
self.fetchone(
9193
f"select connector_name from system.metadata.catalogs where catalog_name='{catalog}'"

tests/core/engine_adapter/test_trino.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlmesh.core.model import load_sql_based_model
1212
from sqlmesh.core.model.definition import SqlModel
1313
from sqlmesh.core.dialect import schema_
14+
from sqlmesh.utils.date import to_ds
1415
from sqlmesh.utils.errors import SQLMeshError
1516
from tests.core.engine_adapter import to_sql_calls
1617

@@ -683,3 +684,74 @@ def test_replace_table_catalog_support(
683684
sql_calls[0]
684685
== f'CREATE TABLE IF NOT EXISTS "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"'
685686
)
687+
688+
689+
@pytest.mark.parametrize(
690+
"catalog_type_overrides", [{}, {"my_catalog": "hive"}, {"other_catalog": "iceberg"}]
691+
)
692+
def test_insert_overwrite_time_partition_hive(
693+
make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str]
694+
):
695+
config = TrinoConnectionConfig(
696+
user="user",
697+
host="host",
698+
catalog="catalog",
699+
catalog_type_overrides=catalog_type_overrides,
700+
)
701+
adapter: TrinoEngineAdapter = make_mocked_engine_adapter(
702+
TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides
703+
)
704+
adapter.fetchone = MagicMock(return_value=None) # type: ignore
705+
706+
adapter.insert_overwrite_by_time_partition(
707+
table_name=".".join(["my_catalog", "schema", "test_table"]),
708+
query_or_df=parse_one("SELECT a, b FROM tbl"),
709+
start="2022-01-01",
710+
end="2022-01-02",
711+
time_column="b",
712+
time_formatter=lambda x, _: exp.Literal.string(to_ds(x)),
713+
target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")},
714+
)
715+
716+
assert to_sql_calls(adapter) == [
717+
"SET SESSION my_catalog.insert_existing_partitions_behavior='OVERWRITE'",
718+
'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
719+
"SET SESSION my_catalog.insert_existing_partitions_behavior='APPEND'",
720+
]
721+
722+
723+
@pytest.mark.parametrize(
724+
"catalog_type_overrides",
725+
[
726+
{"my_catalog": "iceberg"},
727+
{"my_catalog": "unknown"},
728+
],
729+
)
730+
def test_insert_overwrite_time_partition_iceberg(
731+
make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str]
732+
):
733+
config = TrinoConnectionConfig(
734+
user="user",
735+
host="host",
736+
catalog="catalog",
737+
catalog_type_overrides=catalog_type_overrides,
738+
)
739+
adapter: TrinoEngineAdapter = make_mocked_engine_adapter(
740+
TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides
741+
)
742+
adapter.fetchone = MagicMock(return_value=None) # type: ignore
743+
744+
adapter.insert_overwrite_by_time_partition(
745+
table_name=".".join(["my_catalog", "schema", "test_table"]),
746+
query_or_df=parse_one("SELECT a, b FROM tbl"),
747+
start="2022-01-01",
748+
end="2022-01-02",
749+
time_column="b",
750+
time_formatter=lambda x, _: exp.Literal.string(to_ds(x)),
751+
target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")},
752+
)
753+
754+
assert to_sql_calls(adapter) == [
755+
'DELETE FROM "my_catalog"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
756+
'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'',
757+
]

tests/core/test_connection_config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,25 @@ def test_trino_schema_location_mapping(make_config):
425425
assert all((isinstance(v, str) for v in config.schema_location_mapping.values()))
426426

427427

428+
def test_trino_catalog_type_override(make_config):
429+
required_kwargs = dict(
430+
type="trino",
431+
user="user",
432+
host="host",
433+
catalog="catalog",
434+
)
435+
436+
config = make_config(
437+
**required_kwargs,
438+
catalog_type_overrides={"my_catalog": "iceberg"},
439+
)
440+
441+
assert config.catalog_type_overrides is not None
442+
assert len(config.catalog_type_overrides) == 1
443+
444+
assert config.catalog_type_overrides == {"my_catalog": "iceberg"}
445+
446+
428447
def test_duckdb(make_config):
429448
config = make_config(
430449
type="duckdb",

0 commit comments

Comments
 (0)