2222
2323if t .TYPE_CHECKING :
2424 from sqlmesh .core ._typing import SchemaName , TableName
25+ from sqlmesh .core .engine_adapter ._typing import QueryOrDF
26+
27+ TableType = t .Union [t .Literal ["hive" ], t .Literal ["iceberg" ]]
2528
2629logger = logging .getLogger (__name__ )
2730
@@ -30,8 +33,10 @@ class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin):
3033 DIALECT = "athena"
3134 SUPPORTS_TRANSACTIONS = False
3235 SUPPORTS_REPLACE_TABLE = False
33- # Athena has the concept of catalogs but no notion of current_catalog or setting the current catalog
34- CATALOG_SUPPORT = CatalogSupport .UNSUPPORTED
36+ # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that
37+ # It also can't create new catalogs; you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog"
38+ # are pointers to the "awsdatacatalog" of other AWS accounts
39+ CATALOG_SUPPORT = CatalogSupport .SINGLE_CATALOG_ONLY
3540 # Athena's support for table and column comments is too patchy to consider "supported"
3641 # Hive tables: Table + Column comments are supported
3742 # Iceberg tables: Column comments only
@@ -48,6 +53,8 @@ def __init__(
4853 super ().__init__ (* args , s3_warehouse_location = s3_warehouse_location , ** kwargs )
4954 self .s3_warehouse_location = s3_warehouse_location
5055
56+ self ._default_catalog = self ._default_catalog or "awsdatacatalog"
57+
5158 @property
5259 def s3_warehouse_location (self ) -> t .Optional [str ]:
5360 return self ._s3_warehouse_location
@@ -90,14 +97,7 @@ def _get_data_objects(
9097 schema = schema_name .db
9198 query = (
9299 exp .select (
93- exp .case ()
94- .when (
95- # calling code expects data objects in the default catalog to have their catalog set to None
96- exp .column ("table_catalog" , table = "t" ).eq ("awsdatacatalog" ),
97- exp .Null (),
98- )
99- .else_ (exp .column ("table_catalog" ))
100- .as_ ("catalog" ),
100+ exp .column ("table_catalog" ).as_ ("catalog" ),
101101 exp .column ("table_schema" , table = "t" ).as_ ("schema" ),
102102 exp .column ("table_name" , table = "t" ).as_ ("name" ),
103103 exp .case ()
@@ -130,6 +130,7 @@ def columns(
130130 self , table_name : TableName , include_pseudo_columns : bool = False
131131 ) -> t .Dict [str , exp .DataType ]:
132132 table = exp .to_table (table_name )
133+ # note: the data_type column contains the full parameterized type, eg 'varchar(10)'
133134 query = (
134135 exp .select ("column_name" , "data_type" )
135136 .from_ ("information_schema.columns" )
@@ -305,24 +306,29 @@ def _build_table_properties_exp(
305306
306307 return None
307308
def drop_table(self, table_name: "TableName", exists: bool = True) -> None:
    """Drop ``table_name``, clearing Hive table data first.

    NOTE(review): Hive-format tables have their storage cleared via
    :meth:`_truncate_table` before the drop — presumably because dropping
    the metadata alone would leave the files in S3 behind; confirm against
    the Athena docs. The actual DROP is delegated to the parent class.
    """
    target = exp.to_table(table_name)

    # _query_table_type returns None for a missing table, so a non-existent
    # table skips the truncation and falls straight through to the drop.
    if self._query_table_type(target) == "hive":
        self._truncate_table(target)

    return super().drop_table(table_name=target, exists=exists)
316+
def _truncate_table(self, table_name: "TableName") -> None:
    """Remove all records from ``table_name``.

    Iceberg tables support row-level deletes, so truncation is a plain
    unconditional DELETE. Hive tables do not: partitioned ones are
    truncated by dropping every partition along with its data, and
    non-partitioned ones by clearing the files under the table's S3
    location.
    """
    target = exp.to_table(table_name)

    # Iceberg: truncating is just DELETE FROM <table>
    if self._query_table_type(target) == "iceberg":
        return self.delete_from(target, exp.true())

    if self._is_hive_partitioned_table(target):
        # Partitioned Hive: drop all partitions and delete their data from S3
        self._clear_partition_data(target, exp.true())
    else:
        location = self._query_table_s3_location(target)
        if location:
            # Non-partitioned Hive: clear out all data under the table's Location
            self._clear_s3_location(location)
323- def _table_type (
324- self , table_format : t .Optional [str ] = None
325- ) -> t .Union [t .Literal ["hive" ], t .Literal ["iceberg" ]]:
331+ def _table_type (self , table_format : t .Optional [str ] = None ) -> TableType :
326332 """
327333 Interpret the "table_format" property to check if this is a Hive or an Iceberg table
328334 """
@@ -332,12 +338,19 @@ def _table_type(
332338 # if we can't detect any indication of Iceberg, this is a Hive table
333339 return "hive"
334340
341+ def _query_table_type (self , table : exp .Table ) -> t .Optional [TableType ]:
342+ if self .table_exists (table ):
343+ return self ._query_table_type_or_raise (table )
344+ return None
345+
335346 @lru_cache ()
336- def _query_table_type (
337- self , table : exp .Table
338- ) -> t .Union [t .Literal ["hive" ], t .Literal ["iceberg" ]]:
347+ def _query_table_type_or_raise (self , table : exp .Table ) -> TableType :
339348 """
340- Hit the DB to check if this is a Hive or an Iceberg table
349+ Hit the DB to check if this is a Hive or an Iceberg table.
350+
351+ Note that in order to @lru_cache() this method, we have the following assumptions:
352+ - The table must exist (otherwise we would cache None if this method was called before table creation and always return None after creation)
353+ - The table type will not change within the same SQLMesh session
341354 """
342355 # Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here
343356 # This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks)
@@ -404,6 +417,29 @@ def _find_matching_columns(
404417 matches .append ((key , match_dtype ))
405418 return matches
406419
def replace_query(
    self,
    table_name: "TableName",
    query_or_df: "QueryOrDF",
    columns_to_types: t.Optional[t.Dict[str, "exp.DataType"]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    **kwargs: t.Any,
) -> None:
    """Replace the contents of ``table_name`` with ``query_or_df``.

    NOTE(review): Hive-format tables are dropped up front before delegating
    to the parent implementation — presumably because Athena cannot replace
    a Hive table in place (this adapter sets ``SUPPORTS_REPLACE_TABLE = False``);
    confirm against the base class behaviour.
    """
    target = exp.to_table(table_name)

    # A missing table yields None from _query_table_type, so the drop only
    # happens for an existing Hive table.
    if self._query_table_type(table=target) == "hive":
        self.drop_table(target)

    return super().replace_query(
        table_name=target,
        query_or_df=query_or_df,
        columns_to_types=columns_to_types,
        table_description=table_description,
        column_descriptions=column_descriptions,
        **kwargs,
    )
442+
407443 def _insert_overwrite_by_time_partition (
408444 self ,
409445 table_name : TableName ,
@@ -412,23 +448,22 @@ def _insert_overwrite_by_time_partition(
412448 where : exp .Condition ,
413449 ** kwargs : t .Any ,
414450 ) -> None :
415- if isinstance (table_name , str ):
416- table_name = exp .to_table (table_name )
451+ table = exp .to_table (table_name )
417452
418- table_type = self ._query_table_type (table_name )
453+ table_type = self ._query_table_type (table )
419454
420455 if table_type == "iceberg" :
421456 # Iceberg tables work as expected, we can use the default behaviour
422457 return super ()._insert_overwrite_by_time_partition (
423- table_name , source_queries , columns_to_types , where , ** kwargs
458+ table , source_queries , columns_to_types , where , ** kwargs
424459 )
425460
426461 # For Hive tables, we need to drop all the partitions covered by the query and delete the data from S3
427- self ._clear_partition_data (table_name , where )
462+ self ._clear_partition_data (table , where )
428463
429464 # Now the data is physically gone, we can continue with inserting a new partition
430465 return super ()._insert_overwrite_by_time_partition (
431- table_name ,
466+ table ,
432467 source_queries ,
433468 columns_to_types ,
434469 where ,
@@ -500,21 +535,20 @@ def _drop_partitions_from_metastore(
500535 )
501536
502537 def delete_from (self , table_name : TableName , where : t .Union [str , exp .Expression ]) -> None :
503- if isinstance (table_name , str ):
504- table_name = exp .to_table (table_name )
538+ table = exp .to_table (table_name )
505539
506- table_type = self ._query_table_type (table_name )
540+ table_type = self ._query_table_type (table )
507541
508542 # If Iceberg, DELETE operations work as expected
509543 if table_type == "iceberg" :
510- return super ().delete_from (table_name , where )
544+ return super ().delete_from (table , where )
511545
512546 # If Hive, DELETE is an error
513547 if table_type == "hive" :
514548 # However, if there are no actual records to delete, we can make DELETE a no-op
515549 # This simplifies a bunch of calling code that just assumes DELETE works (which to be fair is a reasonable assumption since it does for every other engine)
516550 empty_check = (
517- exp .select ("*" ).from_ (table_name ).where (where ).limit (1 )
551+ exp .select ("*" ).from_ (table ).where (where ).limit (1 )
518552 ) # deliberately not count(*) because we want the engine to stop as soon as it finds a record
519553 if len (self .fetchall (empty_check )) > 0 :
520554 raise SQLMeshError ("Cannot delete individual records from a Hive table" )
@@ -536,7 +570,9 @@ def _clear_s3_location(self, s3_uri: str) -> None:
536570 Bucket = bucket , Prefix = key , Delimiter = "/"
537571 ):
538572 # list_objects_v2() returns 1000 keys per page so that lines up nicely with delete_objects() being able to delete 1000 keys at a time
539- keys_to_delete .append ([item ["Key" ] for item in page .get ("Contents" , [])])
573+ keys = [item ["Key" ] for item in page .get ("Contents" , [])]
574+ if keys :
575+ keys_to_delete .append (keys )
540576
541577 for chunk in keys_to_delete :
542578 s3 .delete_objects (Bucket = bucket , Delete = {"Objects" : [{"Key" : k } for k in chunk ]})
@@ -558,3 +594,6 @@ def _boto3_client(self, name: str) -> t.Any:
558594 config = conn .config ,
559595 ** conn ._client_kwargs ,
560596 ) # type: ignore
597+
def get_current_catalog(self) -> t.Optional[str]:
    """Return the catalog name fixed by the connection parameters.

    Athena's current catalog is set when the connection is created and
    cannot be changed afterwards, so this simply reads it back off the
    live connection object.
    """
    conn = self.connection
    return conn.catalog_name
0 commit comments