from collections import defaultdict

import pandas as pd
from sqlglot import exp, parse_one
from sqlglot.transforms import remove_precision_parameterized_types

from sqlmesh.core.dialect import to_schema
# Type aliases describing nested (STRUCT) column alterations; consumed by
# _update_table_schema_nested_fields via alter_table().
NestedField = t.Tuple[str, str, t.List[str]]
NestedFieldsDict = t.Dict[str, t.List[NestedField]]

# Used to tag clustering-change Alter nodes so alter_table() can intercept them
# and apply the change through the BigQuery client API instead of emitting SQL.
_CLUSTERING_META_KEY = "__sqlmesh_update_table_clustering"
45+
4346
4447@set_catalog ()
4548class BigQueryEngineAdapter (InsertOverwriteWithMergeMixin , ClusteredByMixin ):
@@ -243,6 +246,18 @@ def alter_table(
243246 if nested_fields :
244247 self ._update_table_schema_nested_fields (nested_fields , alter_expressions [0 ].this )
245248
249+ # this is easier than trying to detect exp.Cluster nodes
250+ # or exp.Command nodes that contain the string "DROP CLUSTERING KEY"
251+ clustering_change_operations = [
252+ e for e in non_nested_expressions if _CLUSTERING_META_KEY in e .meta
253+ ]
254+ for op in clustering_change_operations :
255+ non_nested_expressions .remove (op )
256+ table , cluster_by = op .meta [_CLUSTERING_META_KEY ]
257+ assert isinstance (table , str ) or isinstance (table , exp .Table )
258+
259+ self ._update_clustering_key (table , cluster_by )
260+
246261 if non_nested_expressions :
247262 super ().alter_table (non_nested_expressions )
248263
@@ -847,25 +862,55 @@ def _get_data_objects(
847862 # resort to using SQL instead.
848863 schema = to_schema (schema_name )
849864 catalog = schema .catalog or self .default_catalog
850- query = exp .select (
851- exp .column ("table_catalog" ).as_ ("catalog" ),
852- exp .column ("table_name" ).as_ ("name" ),
853- exp .column ("table_schema" ).as_ ("schema_name" ),
854- exp .case ()
855- .when (exp .column ("table_type" ).eq ("BASE TABLE" ), exp .Literal .string ("TABLE" ))
856- .when (exp .column ("table_type" ).eq ("CLONE" ), exp .Literal .string ("TABLE" ))
857- .when (exp .column ("table_type" ).eq ("EXTERNAL" ), exp .Literal .string ("TABLE" ))
858- .when (exp .column ("table_type" ).eq ("SNAPSHOT" ), exp .Literal .string ("TABLE" ))
859- .when (exp .column ("table_type" ).eq ("VIEW" ), exp .Literal .string ("VIEW" ))
860- .when (
861- exp .column ("table_type" ).eq ("MATERIALIZED VIEW" ),
862- exp .Literal .string ("MATERIALIZED_VIEW" ),
865+ query = (
866+ exp .select (
867+ exp .column ("table_catalog" ).as_ ("catalog" ),
868+ exp .column ("table_name" ).as_ ("name" ),
869+ exp .column ("table_schema" ).as_ ("schema_name" ),
870+ exp .case ()
871+ .when (exp .column ("table_type" ).eq ("BASE TABLE" ), exp .Literal .string ("TABLE" ))
872+ .when (exp .column ("table_type" ).eq ("CLONE" ), exp .Literal .string ("TABLE" ))
873+ .when (exp .column ("table_type" ).eq ("EXTERNAL" ), exp .Literal .string ("TABLE" ))
874+ .when (exp .column ("table_type" ).eq ("SNAPSHOT" ), exp .Literal .string ("TABLE" ))
875+ .when (exp .column ("table_type" ).eq ("VIEW" ), exp .Literal .string ("VIEW" ))
876+ .when (
877+ exp .column ("table_type" ).eq ("MATERIALIZED VIEW" ),
878+ exp .Literal .string ("MATERIALIZED_VIEW" ),
879+ )
880+ .else_ (exp .column ("table_type" ))
881+ .as_ ("type" ),
882+ exp .column ("clustering_key" , "ci" ).as_ ("clustering_key" ),
883+ )
884+ .with_ (
885+ "clustering_info" ,
886+ as_ = exp .select (
887+ exp .column ("table_catalog" ),
888+ exp .column ("table_schema" ),
889+ exp .column ("table_name" ),
890+ parse_one (
891+ "string_agg(column_name order by clustering_ordinal_position)" ,
892+ dialect = self .dialect ,
893+ ).as_ ("clustering_key" ),
894+ )
895+ .from_ (
896+ exp .to_table (
897+ f"`{ catalog } `.`{ schema .db } `.INFORMATION_SCHEMA.COLUMNS" ,
898+ dialect = self .dialect ,
899+ )
900+ )
901+ .where (exp .column ("clustering_ordinal_position" ).is_ (exp .not_ (exp .null ())))
902+ .group_by ("1" , "2" , "3" ),
863903 )
864- .else_ (exp .column ("table_type" ))
865- .as_ ("type" ),
866- ).from_ (
867- exp .to_table (
868- f"`{ catalog } `.`{ schema .db } `.INFORMATION_SCHEMA.TABLES" , dialect = self .dialect
904+ .from_ (
905+ exp .to_table (
906+ f"`{ catalog } `.`{ schema .db } `.INFORMATION_SCHEMA.TABLES" , dialect = self .dialect
907+ )
908+ )
909+ .join (
910+ "clustering_info" ,
911+ using = ["table_catalog" , "table_schema" , "table_name" ],
912+ join_type = "left" ,
913+ join_alias = "ci" ,
869914 )
870915 )
871916 if object_names :
@@ -886,10 +931,41 @@ def _get_data_objects(
886931 schema = row .schema_name , # type: ignore
887932 name = row .name , # type: ignore
888933 type = DataObjectType .from_str (row .type ), # type: ignore
934+ clustering_key = f"({ row .clustering_key } )" if row .clustering_key else None , # type: ignore
889935 )
890936 for row in df .itertuples ()
891937 ]
892938
    def _change_clustering_key_expr(
        self, table: exp.Table, cluster_by: t.List[exp.Expression]
    ) -> exp.Alter:
        """Build the ALTER expression for changing a table's clustering key.

        The expression is tagged via node metadata so that alter_table() can pull
        it out of the plain SQL path and apply it with _update_clustering_key()
        instead of executing it as a statement.

        Args:
            table: The table whose clustering key is being changed.
            cluster_by: The new clustering columns/expressions.

        Returns:
            The tagged Alter expression produced by the superclass.
        """
        expr = super()._change_clustering_key_expr(table=table, cluster_by=cluster_by)
        expr.meta[_CLUSTERING_META_KEY] = (table, cluster_by)
        return expr
945+
    def _drop_clustering_key_expr(self, table: exp.Table) -> exp.Alter:
        """Build the ALTER expression for dropping a table's clustering key.

        Tagged with a None cluster list so that alter_table() routes it through
        _update_clustering_key(), which drops the key when given no columns.

        Args:
            table: The table whose clustering key is being dropped.

        Returns:
            The tagged Alter expression produced by the superclass.
        """
        expr = super()._drop_clustering_key_expr(table=table)
        expr.meta[_CLUSTERING_META_KEY] = (table, None)
        return expr
950+
951+ def _update_clustering_key (
952+ self , table_name : TableName , cluster_by : t .Optional [t .List [exp .Expression ]]
953+ ) -> None :
954+ cluster_by = cluster_by or []
955+ bq_table = self ._get_table (table_name )
956+
957+ rendered_columns = [c .sql (dialect = self .dialect ) for c in cluster_by ]
958+ bq_table .clustering_fields = (
959+ rendered_columns or None
960+ ) # causes a drop of the key if cluster_by is empty or None
961+
962+ self ._db_call (self .client .update_table , table = bq_table , fields = ["clustering_fields" ])
963+
964+ if cluster_by :
965+ # BigQuery only applies new clustering going forward, so this rewrites the columns to apply the new clustering to historical data
966+ # ref: https://cloud.google.com/bigquery/docs/creating-clustered-tables#modifying-cluster-spec
967+ self .execute (exp .update (table_name , {c : c for c in cluster_by }, where = exp .true ()))
968+
    @property
    def _query_data(self) -> t.Any:
        """Per-connection query data stashed on the connection pool."""
        return self._connection_pool.get_attribute("query_data")
@@ -971,7 +1047,7 @@ def select_partitions_expr(
9711047 """Generates a SQL expression that aggregates partition values for a table.
9721048
9731049 Args:
            schema: The schema (BigQuery dataset) of the table.
9751051 table_name: The name of the table.
9761052 data_type: The data type of the partition column.
9771053 granularity: The granularity of the partition. Supported values are: 'day', 'month', 'year' and 'hour'.