Skip to content

Commit 9300f27

Browse files
authored
Feat: BigQuery - Handle forward_only changes to clustered_by (#3231)
1 parent 52132d8 commit 9300f27

File tree

6 files changed

+351
-85
lines changed

6 files changed

+351
-85
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 96 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict
66

77
import pandas as pd
8-
from sqlglot import exp
8+
from sqlglot import exp, parse_one
99
from sqlglot.transforms import remove_precision_parameterized_types
1010

1111
from sqlmesh.core.dialect import to_schema
@@ -40,6 +40,9 @@
4040
NestedField = t.Tuple[str, str, t.List[str]]
4141
NestedFieldsDict = t.Dict[str, t.List[NestedField]]
4242

43+
# used to tag AST nodes to be specially handled in alter_table()
44+
_CLUSTERING_META_KEY = "__sqlmesh_update_table_clustering"
45+
4346

4447
@set_catalog()
4548
class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin):
@@ -243,6 +246,18 @@ def alter_table(
243246
if nested_fields:
244247
self._update_table_schema_nested_fields(nested_fields, alter_expressions[0].this)
245248

249+
# this is easier than trying to detect exp.Cluster nodes
250+
# or exp.Command nodes that contain the string "DROP CLUSTERING KEY"
251+
clustering_change_operations = [
252+
e for e in non_nested_expressions if _CLUSTERING_META_KEY in e.meta
253+
]
254+
for op in clustering_change_operations:
255+
non_nested_expressions.remove(op)
256+
table, cluster_by = op.meta[_CLUSTERING_META_KEY]
257+
assert isinstance(table, str) or isinstance(table, exp.Table)
258+
259+
self._update_clustering_key(table, cluster_by)
260+
246261
if non_nested_expressions:
247262
super().alter_table(non_nested_expressions)
248263

@@ -847,25 +862,55 @@ def _get_data_objects(
847862
# resort to using SQL instead.
848863
schema = to_schema(schema_name)
849864
catalog = schema.catalog or self.default_catalog
850-
query = exp.select(
851-
exp.column("table_catalog").as_("catalog"),
852-
exp.column("table_name").as_("name"),
853-
exp.column("table_schema").as_("schema_name"),
854-
exp.case()
855-
.when(exp.column("table_type").eq("BASE TABLE"), exp.Literal.string("TABLE"))
856-
.when(exp.column("table_type").eq("CLONE"), exp.Literal.string("TABLE"))
857-
.when(exp.column("table_type").eq("EXTERNAL"), exp.Literal.string("TABLE"))
858-
.when(exp.column("table_type").eq("SNAPSHOT"), exp.Literal.string("TABLE"))
859-
.when(exp.column("table_type").eq("VIEW"), exp.Literal.string("VIEW"))
860-
.when(
861-
exp.column("table_type").eq("MATERIALIZED VIEW"),
862-
exp.Literal.string("MATERIALIZED_VIEW"),
865+
query = (
866+
exp.select(
867+
exp.column("table_catalog").as_("catalog"),
868+
exp.column("table_name").as_("name"),
869+
exp.column("table_schema").as_("schema_name"),
870+
exp.case()
871+
.when(exp.column("table_type").eq("BASE TABLE"), exp.Literal.string("TABLE"))
872+
.when(exp.column("table_type").eq("CLONE"), exp.Literal.string("TABLE"))
873+
.when(exp.column("table_type").eq("EXTERNAL"), exp.Literal.string("TABLE"))
874+
.when(exp.column("table_type").eq("SNAPSHOT"), exp.Literal.string("TABLE"))
875+
.when(exp.column("table_type").eq("VIEW"), exp.Literal.string("VIEW"))
876+
.when(
877+
exp.column("table_type").eq("MATERIALIZED VIEW"),
878+
exp.Literal.string("MATERIALIZED_VIEW"),
879+
)
880+
.else_(exp.column("table_type"))
881+
.as_("type"),
882+
exp.column("clustering_key", "ci").as_("clustering_key"),
883+
)
884+
.with_(
885+
"clustering_info",
886+
as_=exp.select(
887+
exp.column("table_catalog"),
888+
exp.column("table_schema"),
889+
exp.column("table_name"),
890+
parse_one(
891+
"string_agg(column_name order by clustering_ordinal_position)",
892+
dialect=self.dialect,
893+
).as_("clustering_key"),
894+
)
895+
.from_(
896+
exp.to_table(
897+
f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.COLUMNS",
898+
dialect=self.dialect,
899+
)
900+
)
901+
.where(exp.column("clustering_ordinal_position").is_(exp.not_(exp.null())))
902+
.group_by("1", "2", "3"),
863903
)
864-
.else_(exp.column("table_type"))
865-
.as_("type"),
866-
).from_(
867-
exp.to_table(
868-
f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.TABLES", dialect=self.dialect
904+
.from_(
905+
exp.to_table(
906+
f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.TABLES", dialect=self.dialect
907+
)
908+
)
909+
.join(
910+
"clustering_info",
911+
using=["table_catalog", "table_schema", "table_name"],
912+
join_type="left",
913+
join_alias="ci",
869914
)
870915
)
871916
if object_names:
@@ -886,10 +931,41 @@ def _get_data_objects(
886931
schema=row.schema_name, # type: ignore
887932
name=row.name, # type: ignore
888933
type=DataObjectType.from_str(row.type), # type: ignore
934+
clustering_key=f"({row.clustering_key})" if row.clustering_key else None, # type: ignore
889935
)
890936
for row in df.itertuples()
891937
]
892938

939+
def _change_clustering_key_expr(
940+
self, table: exp.Table, cluster_by: t.List[exp.Expression]
941+
) -> exp.Alter:
942+
expr = super()._change_clustering_key_expr(table=table, cluster_by=cluster_by)
943+
expr.meta[_CLUSTERING_META_KEY] = (table, cluster_by)
944+
return expr
945+
946+
def _drop_clustering_key_expr(self, table: exp.Table) -> exp.Alter:
947+
expr = super()._drop_clustering_key_expr(table=table)
948+
expr.meta[_CLUSTERING_META_KEY] = (table, None)
949+
return expr
950+
951+
def _update_clustering_key(
952+
self, table_name: TableName, cluster_by: t.Optional[t.List[exp.Expression]]
953+
) -> None:
954+
cluster_by = cluster_by or []
955+
bq_table = self._get_table(table_name)
956+
957+
rendered_columns = [c.sql(dialect=self.dialect) for c in cluster_by]
958+
bq_table.clustering_fields = (
959+
rendered_columns or None
960+
) # causes a drop of the key if cluster_by is empty or None
961+
962+
self._db_call(self.client.update_table, table=bq_table, fields=["clustering_fields"])
963+
964+
if cluster_by:
965+
# BigQuery only applies new clustering going forward, so this rewrites the columns to apply the new clustering to historical data
966+
# ref: https://cloud.google.com/bigquery/docs/creating-clustered-tables#modifying-cluster-spec
967+
self.execute(exp.update(table_name, {c: c for c in cluster_by}, where=exp.true()))
968+
893969
@property
894970
def _query_data(self) -> t.Any:
895971
return self._connection_pool.get_attribute("query_data")
@@ -971,7 +1047,7 @@ def select_partitions_expr(
9711047
"""Generates a SQL expression that aggregates partition values for a table.
9721048
9731049
Args:
974-
schema: The schema (BigQueyr dataset) of the table.
1050+
schema: The schema (BigQuery dataset) of the table.
9751051
table_name: The name of the table.
9761052
data_type: The data type of the partition column.
9771053
granularity: The granularity of the partition. Supported values are: 'day', 'month', 'year' and 'hour'.

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import logging
44
import typing as t
55

6-
from sqlglot import exp
6+
from sqlglot import exp, parse_one
7+
from sqlglot.helper import seq_get
78

89
from sqlmesh.core.engine_adapter.base import EngineAdapter
910
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery
@@ -337,3 +338,61 @@ def _build_clustered_by_exp(
337338
**kwargs: t.Any,
338339
) -> t.Optional[exp.Cluster]:
339340
return exp.Cluster(expressions=[exp.column(col) for col in clustered_by])
341+
342+
def _parse_clustering_key(self, clustering_key: t.Optional[str]) -> t.List[exp.Expression]:
343+
if not clustering_key:
344+
return []
345+
346+
# Note: Assumes `clustering_key` as a string like:
347+
# - "(col_a)"
348+
# - "(col_a, col_b)"
349+
# - "func(col_a, transform(col_b))"
350+
parsed_cluster_key = parse_one(clustering_key, dialect=self.dialect)
351+
352+
return parsed_cluster_key.expressions or [parsed_cluster_key.this]
353+
354+
def get_alter_expressions(
355+
self, current_table_name: TableName, target_table_name: TableName
356+
) -> t.List[exp.Alter]:
357+
expressions = super().get_alter_expressions(current_table_name, target_table_name)
358+
359+
# check for a change in clustering
360+
current_table = exp.to_table(current_table_name)
361+
target_table = exp.to_table(target_table_name)
362+
363+
current_table_info = seq_get(
364+
self.get_data_objects(current_table.db, {current_table.name}), 0
365+
)
366+
target_table_info = seq_get(self.get_data_objects(target_table.db, {target_table.name}), 0)
367+
368+
if current_table_info and target_table_info:
369+
if target_table_info.is_clustered:
370+
if target_table_info.clustering_key and (
371+
current_table_info.clustering_key != target_table_info.clustering_key
372+
):
373+
expressions.append(
374+
self._change_clustering_key_expr(
375+
current_table,
376+
self._parse_clustering_key(target_table_info.clustering_key),
377+
)
378+
)
379+
elif current_table_info.is_clustered:
380+
expressions.append(self._drop_clustering_key_expr(current_table))
381+
382+
return expressions
383+
384+
def _change_clustering_key_expr(
385+
self, table: exp.Table, cluster_by: t.List[exp.Expression]
386+
) -> exp.Alter:
387+
return exp.Alter(
388+
this=table,
389+
kind="TABLE",
390+
actions=[exp.Cluster(expressions=cluster_by)],
391+
)
392+
393+
def _drop_clustering_key_expr(self, table: exp.Table) -> exp.Alter:
394+
return exp.Alter(
395+
this=table,
396+
kind="TABLE",
397+
actions=[exp.Command(this="DROP", expression="CLUSTERING KEY")],
398+
)

sqlmesh/core/engine_adapter/shared.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,13 @@ class DataObject(PydanticModel):
164164
name: str
165165
type: DataObjectType
166166

167+
# for type=DataObjectType.Table, only if the DB supports it
168+
clustering_key: t.Optional[str] = None
169+
170+
@property
171+
def is_clustered(self) -> bool:
172+
return bool(self.clustering_key)
173+
167174

168175
class CatalogSupport(Enum):
169176
UNSUPPORTED = 1

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
import pandas as pd
88
from pandas.api.types import is_datetime64_any_dtype # type: ignore
9-
from sqlglot import exp, parse_one
10-
from sqlglot.helper import seq_get
9+
from sqlglot import exp
1110
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1211
from sqlglot.optimizer.qualify_columns import quote_identifiers
1312

@@ -33,14 +32,6 @@
3332
from sqlmesh.core.node import IntervalUnit
3433

3534

36-
class SnowflakeDataObject(DataObject):
37-
clustering_key: t.Optional[str] = None
38-
39-
@property
40-
def is_clustered(self) -> bool:
41-
return bool(self.clustering_key)
42-
43-
4435
@set_catalog(
4536
override_mapping={
4637
"_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG,
@@ -348,7 +339,7 @@ def _get_data_objects(
348339
if df.empty:
349340
return []
350341
return [
351-
SnowflakeDataObject(
342+
DataObject(
352343
catalog=row.catalog, # type: ignore
353344
schema=row.schema_name, # type: ignore
354345
name=row.name, # type: ignore
@@ -433,50 +424,3 @@ def _create_column_comments(
433424
f"Column comments for table '{table.alias_or_name}' not registered - this may be due to limited permissions.",
434425
exc_info=True,
435426
)
436-
437-
def get_alter_expressions(
438-
self, current_table_name: TableName, target_table_name: TableName
439-
) -> t.List[exp.Alter]:
440-
schema_expressions = super().get_alter_expressions(current_table_name, target_table_name)
441-
additional_expressions = []
442-
443-
# check for a change in clustering
444-
current_table = exp.to_table(current_table_name)
445-
target_table = exp.to_table(target_table_name)
446-
447-
current_table_info = t.cast(
448-
SnowflakeDataObject,
449-
seq_get(self.get_data_objects(current_table.db, {current_table.name}), 0),
450-
)
451-
target_table_info = t.cast(
452-
SnowflakeDataObject,
453-
seq_get(self.get_data_objects(target_table.db, {target_table.name}), 0),
454-
)
455-
456-
if current_table_info and target_table_info:
457-
if target_table_info.is_clustered:
458-
if target_table_info.clustering_key and (
459-
current_table_info.clustering_key != target_table_info.clustering_key
460-
):
461-
# Note: If you create a table with eg `CLUSTER BY (c2, c1)` and read the info back from information_schema,
462-
# it gets returned as a string like "LINEAR(c2, c1)" which we need to parse back into a list of columns
463-
parsed_cluster_key = parse_one(
464-
target_table_info.clustering_key, dialect=self.dialect
465-
)
466-
additional_expressions.append(
467-
exp.Alter(
468-
this=current_table,
469-
kind="TABLE",
470-
actions=[exp.Cluster(expressions=parsed_cluster_key.expressions)],
471-
)
472-
)
473-
elif current_table_info.is_clustered:
474-
additional_expressions.append(
475-
exp.Alter(
476-
this=current_table,
477-
kind="TABLE",
478-
actions=[exp.Command(this="DROP", expression="CLUSTERING KEY")],
479-
)
480-
)
481-
482-
return schema_expressions + additional_expressions

0 commit comments

Comments
 (0)