Skip to content

Commit 7997477

Browse files
authored
Feat: add dbt trino support (#2223)
* Add dbt trino support * PR feedback
1 parent 899af66 commit 7997477

File tree

6 files changed

+205
-4
lines changed

6 files changed

+205
-4
lines changed

examples/sushi_dbt/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
config = sqlmesh_config(Path(__file__).parent)
77

8-
98
test_config = config
109

1110

examples/sushi_dbt/models/waiter_as_customer_by_day.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
}}
99

1010
SELECT
11-
w.ds as ds,
1211
w.waiter_id as waiter_id,
13-
wn.name as waiter_name
12+
wn.name as waiter_name,
13+
w.ds as ds,
1414
FROM {{ ref('waiters') }} AS w
1515
JOIN {{ ref('customers') }} as c ON w.waiter_id = c.customer_id
1616
JOIN {{ ref('waiter_names') }} as wn ON w.waiter_id = wn.id

examples/sushi_dbt/profiles.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,13 @@ sushi:
4444
schema: sushi
4545
threads: 1
4646
type: sqlserver
47+
trino:
48+
type: trino
49+
method: none
50+
http_scheme: http
51+
user: "{{ env_var('TRINO_USER') }}"
52+
host: "{{ env_var('TRINO_HOST') }}"
53+
database: "{{ env_var('TRINO_DATABASE') }}"
54+
schema: sushi
55+
threads: 1
4756
target: in_memory

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
"dbt-redshift",
115115
"dbt-snowflake",
116116
"dbt-sqlserver",
117+
"dbt-trino",
117118
],
118119
"dbt": [
119120
"dbt-core<2",

sqlmesh/dbt/target.py

Lines changed: 164 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
PostgresConnectionConfig,
2222
RedshiftConnectionConfig,
2323
SnowflakeConnectionConfig,
24+
TrinoAuthenticationMethod,
25+
TrinoConnectionConfig,
2426
)
2527
from sqlmesh.core.model import (
2628
IncrementalByTimeRangeKind,
@@ -105,6 +107,8 @@ def load(cls, data: t.Dict[str, t.Any]) -> TargetConfig:
105107
return BigQueryConfig(**data)
106108
elif db_type == "sqlserver":
107109
return MSSQLConfig(**data)
110+
elif db_type == "trino":
111+
return TrinoConfig(**data)
108112

109113
raise ConfigError(f"{db_type} not supported.")
110114

@@ -215,6 +219,11 @@ class SnowflakeConfig(TargetConfig):
215219
connect_timeout: Number of seconds to wait between failed attempts
216220
retry_on_database_errors: A boolean flag to retry if a Snowflake connector Database error is encountered
217221
retry_all: A boolean flag to retry on all Snowflake connector errors
222+
authenticator: SSO authentication: Snowflake authentication method
223+
private_key: Key pair authentication: Private key
224+
private_key_path: Key pair authentication: Path to the private key, used instead of private_key
225+
private_key_passphrase: Key pair authentication: passphrase used to decrypt private key (if encrypted)
226+
token: OAuth authentication: The Snowflake OAuth 2.0 access token
218227
"""
219228

220229
type: Literal["snowflake"] = "snowflake"
@@ -373,7 +382,6 @@ class RedshiftConfig(TargetConfig):
373382
password: User's password
374383
port: The port to connect to
375384
dbname: Name of the database
376-
keepalives_idle: Seconds between TCP keepalive packets
377385
connect_timeout: Number of seconds to wait between failed attempts
378386
ra3_node: Enables cross-database sources
379387
search_path: Overrides the default search path
@@ -495,6 +503,13 @@ class BigQueryConfig(TargetConfig):
495503
client_secret: The BigQuery client secret
496504
token_uri: The BigQuery token URI
497505
scopes: The BigQuery scopes
506+
job_execution_timeout_seconds: The maximum amount of time, in seconds, to wait for the underlying job to complete
507+
timeout_seconds: Alias for job_execution_timeout_seconds
508+
job_retries: The number of times to retry the underlying job if it fails
509+
retries: Alias for job_retries
510+
job_retry_deadline_seconds: Total number of seconds to wait while retrying the same query
511+
priority: The priority of the underlying job
512+
maximum_bytes_billed: The maximum number of bytes to be billed for the underlying job
498513
"""
499514

500515
type: Literal["bigquery"] = "bigquery"
@@ -587,11 +602,25 @@ class MSSQLConfig(TargetConfig):
587602
588603
Args:
589604
host: The MSSQL server host to connect to
605+
server: Alias for host
590606
port: The MSSQL server port to connect to
591607
user: User name for authentication
608+
username: Alias for user
609+
UID: Alias for user
592610
password: User password for authentication
611+
PWD: Alias for password
593612
login_timeout: The number of seconds to wait for a login to complete
594613
query_timeout: The number of seconds to wait for a query to complete
614+
authentication: The authentication method to use (only "sql" is supported)
615+
schema_authorization: The principal who should own created schemas, not supported by SQLMesh
616+
driver: ODBC driver to use, not used by SQLMesh
617+
encrypt: A boolean flag to enable server connection encryption, not used by SQLMesh
618+
trust_cert: A boolean flag to trust the server certificate, not used by SQLMesh
619+
retries: Number of times to retry if the SQL Server connector encounters an error, not used by SQLMesh
620+
windows_login: A boolean flag to use Windows Authentication, not used by SQLMesh
621+
tenant_id: The tenant ID of the Azure Active Directory instance, not used by SQLMesh
622+
client_id: The client ID of the Azure Active Directory service principal, not used by SQLMesh
623+
client_secret: The client secret of the Azure Active Directory service principal, not used by SQLMesh
595624
"""
596625

597626
type: Literal["sqlserver"] = "sqlserver"
@@ -682,6 +711,139 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig:
682711
)
683712

684713

714+
class TrinoConfig(TargetConfig):
715+
"""
716+
Project connection and operational configuration for the Trino target.
717+
718+
Args:
719+
method: The Trino authentication method to use
720+
host: The server host to connect to
721+
port: The MSSQL server port to connect to
722+
database: Name of the Trino database/catalog
723+
schema: Name of the Trino schema
724+
user: User name for authentication
725+
password: User password for authentication
726+
roles: Trino catalog roles
727+
session_properties: Trino session properties
728+
retries: Number of times to retry if the Trino connector encounters an error
729+
timezone: The timezone to use for the Trino session
730+
http_headers: HTTP Headers to send alongside requests to Trino
731+
http_scheme: The HTTP scheme to use for requests to Trino (default: http, or https if kerberos, ldap or jwt auth)
732+
threads: The number of threads to run on
733+
impersonation_user: LDAP authentication: override the provided username
734+
keytab: Kerberos authentication: Path to keytab
735+
krb5_config: Kerberos authentication: Path to config
736+
principal: Kerberos authentication: Principal
737+
service_name: Kerberos authentication: Service name
738+
hostname_override: Kerberos authentication: hostname for a host whose DNS name doesn't match
739+
mutual_authentication: Kerberos authentication: Boolean flag for mutual authentication.
740+
force_preemptive: Kerberos authentication: Boolean flag to preemptively initiate the GSS exchange.
741+
sanitize_mutual_error_response: Kerberos authentication: Boolean flag to strip content and headers from error responses.
742+
delegate: Kerberos authentication: Boolean flag for credential delegation (`GSS_C_DELEG_FLAG`)
743+
jwt_token: JWT authentication: JWT string
744+
client_certificate: Certification authentication: Path to client certificate
745+
client_private_key: Certification authentication: Path to client private key
746+
cert: Certification authentication: Full path to a certificate file
747+
"""
748+
749+
_method_to_auth_enum: t.ClassVar[t.Dict[str, TrinoAuthenticationMethod]] = {
750+
"none": TrinoAuthenticationMethod.NO_AUTH,
751+
"ldap": TrinoAuthenticationMethod.LDAP,
752+
"kerberos": TrinoAuthenticationMethod.KERBEROS,
753+
"jwt": TrinoAuthenticationMethod.JWT,
754+
"certificate": TrinoAuthenticationMethod.CERTIFICATE,
755+
"oauth": TrinoAuthenticationMethod.OAUTH,
756+
"oauth_console": TrinoAuthenticationMethod.OAUTH,
757+
}
758+
759+
type: Literal["trino"] = "trino"
760+
host: str
761+
database: str
762+
schema_: str = Field(alias="schema")
763+
port: int = 443
764+
method: str
765+
user: t.Optional[str] = None
766+
767+
threads: int = 1
768+
roles: t.Optional[t.Dict[str, str]] = None
769+
session_properties: t.Optional[t.Dict[str, str]] = None
770+
retries: int = 3
771+
timezone: t.Optional[str] = None
772+
http_headers: t.Optional[t.Dict[str, str]] = None
773+
http_scheme: t.Optional[str] = None
774+
prepared_statements_enabled: bool = True # not used by SQLMesh
775+
776+
# ldap authentication
777+
password: t.Optional[str] = None
778+
impersonation_user: t.Optional[str] = None
779+
780+
# kerberos authentication
781+
keytab: t.Optional[str] = None
782+
krb5_config: t.Optional[str] = None
783+
principal: t.Optional[str] = None
784+
service_name: str = "trino"
785+
hostname_override: t.Optional[str] = None
786+
mutual_authentication: bool = False
787+
force_preemptive: bool = False
788+
sanitize_mutual_error_response: bool = True
789+
delegate: bool = False
790+
791+
# jwt authentication
792+
jwt_token: t.Optional[str] = None
793+
794+
# certificate authentication
795+
client_certificate: t.Optional[str] = None
796+
client_private_key: t.Optional[str] = None
797+
cert: t.Optional[str] = None
798+
799+
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
800+
return "append"
801+
802+
@classproperty
803+
def relation_class(cls) -> t.Type[BaseRelation]:
804+
from dbt.adapters.trino.relation import TrinoRelation
805+
806+
return TrinoRelation
807+
808+
@classproperty
809+
def column_class(cls) -> t.Type[Column]:
810+
from dbt.adapters.trino.column import TrinoColumn
811+
812+
return TrinoColumn
813+
814+
def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig:
815+
return TrinoConnectionConfig(
816+
method=self._method_to_auth_enum[self.method],
817+
host=self.host,
818+
user=self.user,
819+
catalog=self.database,
820+
port=self.port,
821+
http_scheme=self.http_scheme,
822+
roles=self.roles,
823+
http_headers=self.http_headers,
824+
session_properties=self.session_properties,
825+
retries=self.retries,
826+
timezone=self.timezone,
827+
password=self.password,
828+
impersonation_user=self.impersonation_user,
829+
keytab=self.keytab,
830+
krb5_config=self.krb5_config,
831+
principal=self.principal,
832+
service_name=self.service_name,
833+
hostname_override=self.hostname_override,
834+
mutual_authentication=self.mutual_authentication,
835+
force_preemptive=self.force_preemptive,
836+
sanitize_mutual_error_response=self.sanitize_mutual_error_response,
837+
delegate=self.delegate,
838+
jwt_token=self.jwt_token,
839+
client_certificate=self.client_certificate,
840+
client_private_key=self.client_private_key,
841+
cert=self.cert,
842+
concurrent_tasks=self.threads,
843+
**kwargs,
844+
)
845+
846+
685847
TARGET_TYPE_TO_CONFIG_CLASS: t.Dict[str, t.Type[TargetConfig]] = {
686848
"databricks": DatabricksConfig,
687849
"duckdb": DuckDbConfig,
@@ -690,4 +852,5 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig:
690852
"snowflake": SnowflakeConfig,
691853
"bigquery": BigQueryConfig,
692854
"sqlserver": MSSQLConfig,
855+
"trino": TrinoConfig,
693856
}

tests/dbt/test_config.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
RedshiftConfig,
2828
SnowflakeConfig,
2929
TargetConfig,
30+
TrinoConfig,
3031
)
3132
from sqlmesh.dbt.test import TestConfig
3233
from sqlmesh.utils.errors import ConfigError
@@ -670,6 +671,30 @@ def test_sqlserver_config():
670671
)
671672

672673

674+
def test_trino_config():
675+
_test_warehouse_config(
676+
"""
677+
dbt-trino:
678+
target: dev
679+
outputs:
680+
dev:
681+
type: trino
682+
method: ldap
683+
user: user
684+
password: password
685+
host: localhost
686+
database: database
687+
schema: dbt_schema
688+
port: 443
689+
threads: 1
690+
""",
691+
TrinoConfig,
692+
"dbt-trino",
693+
"outputs",
694+
"dev",
695+
)
696+
697+
673698
def test_connection_args(tmp_path):
674699
dbt_project_dir = "tests/fixtures/dbt/sushi_test"
675700

@@ -687,12 +712,14 @@ def test_db_type_to_relation_class():
687712
from dbt.adapters.duckdb.relation import DuckDBRelation
688713
from dbt.adapters.redshift import RedshiftRelation
689714
from dbt.adapters.snowflake import SnowflakeRelation
715+
from dbt.adapters.trino.relation import TrinoRelation
690716

691717
assert (TARGET_TYPE_TO_CONFIG_CLASS["bigquery"].relation_class) == BigQueryRelation
692718
assert (TARGET_TYPE_TO_CONFIG_CLASS["databricks"].relation_class) == DatabricksRelation
693719
assert (TARGET_TYPE_TO_CONFIG_CLASS["duckdb"].relation_class) == DuckDBRelation
694720
assert (TARGET_TYPE_TO_CONFIG_CLASS["redshift"].relation_class) == RedshiftRelation
695721
assert (TARGET_TYPE_TO_CONFIG_CLASS["snowflake"].relation_class) == SnowflakeRelation
722+
assert (TARGET_TYPE_TO_CONFIG_CLASS["trino"].relation_class) == TrinoRelation
696723

697724

698725
@pytest.mark.cicdonly
@@ -701,12 +728,14 @@ def test_db_type_to_column_class():
701728
from dbt.adapters.databricks.column import DatabricksColumn
702729
from dbt.adapters.snowflake import SnowflakeColumn
703730
from dbt.adapters.sqlserver.sql_server_column import SQLServerColumn
731+
from dbt.adapters.trino.column import TrinoColumn
704732

705733
assert (TARGET_TYPE_TO_CONFIG_CLASS["bigquery"].column_class) == BigQueryColumn
706734
assert (TARGET_TYPE_TO_CONFIG_CLASS["databricks"].column_class) == DatabricksColumn
707735
assert (TARGET_TYPE_TO_CONFIG_CLASS["duckdb"].column_class) == Column
708736
assert (TARGET_TYPE_TO_CONFIG_CLASS["snowflake"].column_class) == SnowflakeColumn
709737
assert (TARGET_TYPE_TO_CONFIG_CLASS["sqlserver"].column_class) == SQLServerColumn
738+
assert (TARGET_TYPE_TO_CONFIG_CLASS["trino"].column_class) == TrinoColumn
710739

711740

712741
def test_db_type_to_quote_policy():

0 commit comments

Comments
 (0)