Skip to content

Commit bb57ed9

Browse files
authored
Feat: Support DBT Athena adapter (#3222)
1 parent e4b2817 commit bb57ed9

File tree

5 files changed

+291
-37
lines changed

5 files changed

+291
-37
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
"pytz",
8787
"snowflake-connector-python[pandas,secure-local-storage]>=3.0.2",
8888
"sqlalchemy-stubs",
89-
"tenacity==8.1.0",
89+
"tenacity",
9090
"types-croniter",
9191
"types-dateparser",
9292
"types-python-dateutil",
@@ -100,6 +100,7 @@
100100
"dbt-redshift",
101101
"dbt-sqlserver>=1.7.0",
102102
"dbt-trino",
103+
"dbt-athena-community",
103104
],
104105
"dbt": [
105106
"dbt-core<2",

sqlmesh/core/engine_adapter/athena.py

Lines changed: 75 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222

2323
if t.TYPE_CHECKING:
2424
from sqlmesh.core._typing import SchemaName, TableName
25+
from sqlmesh.core.engine_adapter._typing import QueryOrDF
26+
27+
TableType = t.Union[t.Literal["hive"], t.Literal["iceberg"]]
2528

2629
logger = logging.getLogger(__name__)
2730

@@ -30,8 +33,10 @@ class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin):
3033
DIALECT = "athena"
3134
SUPPORTS_TRANSACTIONS = False
3235
SUPPORTS_REPLACE_TABLE = False
33-
# Athena has the concept of catalogs but no notion of current_catalog or setting the current catalog
34-
CATALOG_SUPPORT = CatalogSupport.UNSUPPORTED
36+
# Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that
37+
# It also cant create new catalogs, you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog"
38+
# are pointers to the "awsdatacatalog" of other AWS accounts
39+
CATALOG_SUPPORT = CatalogSupport.SINGLE_CATALOG_ONLY
3540
# Athena's support for table and column comments is too patchy to consider "supported"
3641
# Hive tables: Table + Column comments are supported
3742
# Iceberg tables: Column comments only
@@ -48,6 +53,8 @@ def __init__(
4853
super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs)
4954
self.s3_warehouse_location = s3_warehouse_location
5055

56+
self._default_catalog = self._default_catalog or "awsdatacatalog"
57+
5158
@property
5259
def s3_warehouse_location(self) -> t.Optional[str]:
5360
return self._s3_warehouse_location
@@ -90,14 +97,7 @@ def _get_data_objects(
9097
schema = schema_name.db
9198
query = (
9299
exp.select(
93-
exp.case()
94-
.when(
95-
# calling code expects data objects in the default catalog to have their catalog set to None
96-
exp.column("table_catalog", table="t").eq("awsdatacatalog"),
97-
exp.Null(),
98-
)
99-
.else_(exp.column("table_catalog"))
100-
.as_("catalog"),
100+
exp.column("table_catalog").as_("catalog"),
101101
exp.column("table_schema", table="t").as_("schema"),
102102
exp.column("table_name", table="t").as_("name"),
103103
exp.case()
@@ -130,6 +130,7 @@ def columns(
130130
self, table_name: TableName, include_pseudo_columns: bool = False
131131
) -> t.Dict[str, exp.DataType]:
132132
table = exp.to_table(table_name)
133+
# note: the data_type column contains the full parameterized type, eg 'varchar(10)'
133134
query = (
134135
exp.select("column_name", "data_type")
135136
.from_("information_schema.columns")
@@ -305,24 +306,29 @@ def _build_table_properties_exp(
305306

306307
return None
307308

309+
def drop_table(self, table_name: TableName, exists: bool = True) -> None:
310+
table = exp.to_table(table_name)
311+
312+
if self._query_table_type(table) == "hive":
313+
self._truncate_table(table)
314+
315+
return super().drop_table(table_name=table, exists=exists)
316+
308317
def _truncate_table(self, table_name: TableName) -> None:
309-
if isinstance(table_name, str):
310-
table_name = exp.to_table(table_name)
318+
table = exp.to_table(table_name)
311319

312320
# Truncating an Iceberg table is just DELETE FROM <table>
313-
if self._query_table_type(table_name) == "iceberg":
314-
return self.delete_from(table_name, exp.true())
321+
if self._query_table_type(table) == "iceberg":
322+
return self.delete_from(table, exp.true())
315323

316324
# Truncating a partitioned Hive table is dropping all partitions and deleting the data from S3
317-
if self._is_hive_partitioned_table(table_name):
318-
self._clear_partition_data(table_name, exp.true())
319-
elif s3_location := self._query_table_s3_location(table_name):
325+
if self._is_hive_partitioned_table(table):
326+
self._clear_partition_data(table, exp.true())
327+
elif s3_location := self._query_table_s3_location(table):
320328
# Truncating a non-partitioned Hive table is clearing out all data in its Location
321329
self._clear_s3_location(s3_location)
322330

323-
def _table_type(
324-
self, table_format: t.Optional[str] = None
325-
) -> t.Union[t.Literal["hive"], t.Literal["iceberg"]]:
331+
def _table_type(self, table_format: t.Optional[str] = None) -> TableType:
326332
"""
327333
Interpret the "table_format" property to check if this is a Hive or an Iceberg table
328334
"""
@@ -332,12 +338,19 @@ def _table_type(
332338
# if we cant detect any indication of Iceberg, this is a Hive table
333339
return "hive"
334340

341+
def _query_table_type(self, table: exp.Table) -> t.Optional[TableType]:
342+
if self.table_exists(table):
343+
return self._query_table_type_or_raise(table)
344+
return None
345+
335346
@lru_cache()
336-
def _query_table_type(
337-
self, table: exp.Table
338-
) -> t.Union[t.Literal["hive"], t.Literal["iceberg"]]:
347+
def _query_table_type_or_raise(self, table: exp.Table) -> TableType:
339348
"""
340-
Hit the DB to check if this is a Hive or an Iceberg table
349+
Hit the DB to check if this is a Hive or an Iceberg table.
350+
351+
Note that in order to @lru_cache() this method, we have the following assumptions:
352+
- The table must exist (otherwise we would cache None if this method was called before table creation and always return None after creation)
353+
- The table type will not change within the same SQLMesh session
341354
"""
342355
# Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here
343356
# This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks)
@@ -404,6 +417,29 @@ def _find_matching_columns(
404417
matches.append((key, match_dtype))
405418
return matches
406419

420+
def replace_query(
421+
self,
422+
table_name: TableName,
423+
query_or_df: QueryOrDF,
424+
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
425+
table_description: t.Optional[str] = None,
426+
column_descriptions: t.Optional[t.Dict[str, str]] = None,
427+
**kwargs: t.Any,
428+
) -> None:
429+
table = exp.to_table(table_name)
430+
431+
if self._query_table_type(table=table) == "hive":
432+
self.drop_table(table)
433+
434+
return super().replace_query(
435+
table_name=table,
436+
query_or_df=query_or_df,
437+
columns_to_types=columns_to_types,
438+
table_description=table_description,
439+
column_descriptions=column_descriptions,
440+
**kwargs,
441+
)
442+
407443
def _insert_overwrite_by_time_partition(
408444
self,
409445
table_name: TableName,
@@ -412,23 +448,22 @@ def _insert_overwrite_by_time_partition(
412448
where: exp.Condition,
413449
**kwargs: t.Any,
414450
) -> None:
415-
if isinstance(table_name, str):
416-
table_name = exp.to_table(table_name)
451+
table = exp.to_table(table_name)
417452

418-
table_type = self._query_table_type(table_name)
453+
table_type = self._query_table_type(table)
419454

420455
if table_type == "iceberg":
421456
# Iceberg tables work as expected, we can use the default behaviour
422457
return super()._insert_overwrite_by_time_partition(
423-
table_name, source_queries, columns_to_types, where, **kwargs
458+
table, source_queries, columns_to_types, where, **kwargs
424459
)
425460

426461
# For Hive tables, we need to drop all the partitions covered by the query and delete the data from S3
427-
self._clear_partition_data(table_name, where)
462+
self._clear_partition_data(table, where)
428463

429464
# Now the data is physically gone, we can continue with inserting a new partition
430465
return super()._insert_overwrite_by_time_partition(
431-
table_name,
466+
table,
432467
source_queries,
433468
columns_to_types,
434469
where,
@@ -500,21 +535,20 @@ def _drop_partitions_from_metastore(
500535
)
501536

502537
def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None:
503-
if isinstance(table_name, str):
504-
table_name = exp.to_table(table_name)
538+
table = exp.to_table(table_name)
505539

506-
table_type = self._query_table_type(table_name)
540+
table_type = self._query_table_type(table)
507541

508542
# If Iceberg, DELETE operations work as expected
509543
if table_type == "iceberg":
510-
return super().delete_from(table_name, where)
544+
return super().delete_from(table, where)
511545

512546
# If Hive, DELETE is an error
513547
if table_type == "hive":
514548
# However, if there are no actual records to delete, we can make DELETE a no-op
515549
# This simplifies a bunch of calling code that just assumes DELETE works (which to be fair is a reasonable assumption since it does for every other engine)
516550
empty_check = (
517-
exp.select("*").from_(table_name).where(where).limit(1)
551+
exp.select("*").from_(table).where(where).limit(1)
518552
) # deliberately not count(*) because we want the engine to stop as soon as it finds a record
519553
if len(self.fetchall(empty_check)) > 0:
520554
raise SQLMeshError("Cannot delete individual records from a Hive table")
@@ -536,7 +570,9 @@ def _clear_s3_location(self, s3_uri: str) -> None:
536570
Bucket=bucket, Prefix=key, Delimiter="/"
537571
):
538572
# list_objects_v2() returns 1000 keys per page so that lines up nicely with delete_objects() being able to delete 1000 keys at a time
539-
keys_to_delete.append([item["Key"] for item in page.get("Contents", [])])
573+
keys = [item["Key"] for item in page.get("Contents", [])]
574+
if keys:
575+
keys_to_delete.append(keys)
540576

541577
for chunk in keys_to_delete:
542578
s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]})
@@ -558,3 +594,6 @@ def _boto3_client(self, name: str) -> t.Any:
558594
config=conn.config,
559595
**conn._client_kwargs,
560596
) # type: ignore
597+
598+
def get_current_catalog(self) -> t.Optional[str]:
599+
return self.connection.catalog_name

sqlmesh/dbt/target.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
SnowflakeConnectionConfig,
2323
TrinoAuthenticationMethod,
2424
TrinoConnectionConfig,
25+
AthenaConnectionConfig,
2526
)
2627
from sqlmesh.core.model import (
2728
IncrementalByTimeRangeKind,
@@ -109,6 +110,8 @@ def load(cls, data: t.Dict[str, t.Any]) -> TargetConfig:
109110
return MSSQLConfig(**data)
110111
elif db_type == "trino":
111112
return TrinoConfig(**data)
113+
elif db_type == "athena":
114+
return AthenaConfig(**data)
112115

113116
raise ConfigError(f"{db_type} not supported.")
114117

@@ -849,6 +852,89 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig:
849852
)
850853

851854

855+
class AthenaConfig(TargetConfig):
856+
"""
857+
Project connection and operational configuration for the Athena target.
858+
859+
Args:
860+
s3_staging_dir: S3 location to store Athena query results and metadata
861+
s3_data_dir: Prefix for storing tables, if different from the connection's s3_staging_dir
862+
s3_data_naming: How to generate table paths in s3_data_dir
863+
s3_tmp_table_dir: Prefix for storing temporary tables, if different from the connection's s3_data_dir
864+
region_name: AWS region of your Athena instance
865+
schema: Specify the schema (Athena database) to build models into (lowercase only)
866+
database: Specify the database (Data catalog) to build models into (lowercase only)
867+
poll_interval: Interval in seconds to use for polling the status of query results in Athena
868+
debug_query_state: Flag if debug message with Athena query state is needed
869+
aws_access_key_id: Access key ID of the user performing requests
870+
aws_secret_access_key: Secret access key of the user performing requests
871+
aws_profile_name: Profile to use from your AWS shared credentials file
872+
work_group: Identifier of Athena workgroup
873+
skip_workgroup_check: Indicates if the WorkGroup check (additional AWS call) can be skipped
874+
num_retries: Number of times to retry a failing query
875+
num_boto3_retries: Number of times to retry boto3 requests (e.g. deleting S3 files for materialized tables)
876+
num_iceberg_retries: Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR
877+
spark_work_group: Identifier of Athena Spark workgroup for running Python models
878+
seed_s3_upload_args: Dictionary containing boto3 ExtraArgs when uploading to S3
879+
lf_tags_database: Default LF tags for new database if it's created by dbt
880+
"""
881+
882+
type: Literal["athena"] = "athena"
883+
threads: int = 4
884+
885+
s3_staging_dir: t.Optional[str] = None
886+
s3_data_dir: t.Optional[str] = None
887+
s3_data_naming: t.Optional[str] = None
888+
s3_tmp_table_dir: t.Optional[str] = None
889+
poll_interval: t.Optional[int] = None
890+
debug_query_state: bool = False
891+
work_group: t.Optional[str] = None
892+
skip_workgroup_check: t.Optional[bool] = None
893+
spark_work_group: t.Optional[str] = None
894+
895+
aws_access_key_id: t.Optional[str] = None
896+
aws_secret_access_key: t.Optional[str] = None
897+
aws_profile_name: t.Optional[str] = None
898+
region_name: t.Optional[str] = None
899+
900+
num_retries: t.Optional[int] = None
901+
num_boto3_retries: t.Optional[int] = None
902+
num_iceberg_retries: t.Optional[int] = None
903+
904+
seed_s3_upload_args: t.Dict[str, str] = {}
905+
lf_tags_database: t.Dict[str, str] = {}
906+
907+
@classproperty
908+
def relation_class(cls) -> t.Type[BaseRelation]:
909+
from dbt.adapters.athena.relation import AthenaRelation
910+
911+
return AthenaRelation
912+
913+
@classproperty
914+
def column_class(cls) -> t.Type[Column]:
915+
from dbt.adapters.athena.column import AthenaColumn
916+
917+
return AthenaColumn
918+
919+
def default_incremental_strategy(self, kind: IncrementalKind) -> str:
920+
return "insert_overwrite"
921+
922+
def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig:
923+
return AthenaConnectionConfig(
924+
type="athena",
925+
aws_access_key_id=self.aws_access_key_id,
926+
aws_secret_access_key=self.aws_secret_access_key,
927+
region_name=self.region_name,
928+
work_group=self.work_group,
929+
s3_staging_dir=self.s3_staging_dir,
930+
s3_warehouse_location=self.s3_data_dir,
931+
schema_name=self.schema_,
932+
catalog_name=self.database,
933+
concurrent_tasks=self.threads,
934+
**kwargs,
935+
)
936+
937+
852938
TARGET_TYPE_TO_CONFIG_CLASS: t.Dict[str, t.Type[TargetConfig]] = {
853939
"databricks": DatabricksConfig,
854940
"duckdb": DuckDbConfig,
@@ -859,4 +945,5 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig:
859945
"sqlserver": MSSQLConfig,
860946
"tsql": MSSQLConfig,
861947
"trino": TrinoConfig,
948+
"athena": AthenaConfig,
862949
}

0 commit comments

Comments
 (0)