|
18 | 18 |
|
19 | 19 | from sqlglot import Dialect, exp |
20 | 20 | from sqlglot.errors import ErrorLevel |
21 | | -from sqlglot.helper import ensure_list |
| 21 | +from sqlglot.helper import ensure_list, seq_get |
22 | 22 | from sqlglot.optimizer.qualify_columns import quote_identifiers |
23 | 23 |
|
24 | 24 | from sqlmesh.core.dialect import ( |
@@ -551,11 +551,13 @@ def replace_query( |
551 | 551 | target_table, |
552 | 552 | source_queries, |
553 | 553 | target_columns_to_types, |
| 554 | + **kwargs, |
554 | 555 | ) |
555 | 556 | return self._insert_overwrite_by_condition( |
556 | 557 | target_table, |
557 | 558 | source_queries, |
558 | 559 | target_columns_to_types, |
| 560 | + **kwargs, |
559 | 561 | ) |
560 | 562 |
|
561 | 563 | def create_index( |
@@ -1614,7 +1616,7 @@ def _insert_overwrite_by_time_partition( |
1614 | 1616 | **kwargs: t.Any, |
1615 | 1617 | ) -> None: |
1616 | 1618 | return self._insert_overwrite_by_condition( |
1617 | | - table_name, source_queries, target_columns_to_types, where |
| 1619 | + table_name, source_queries, target_columns_to_types, where, **kwargs |
1618 | 1620 | ) |
1619 | 1621 |
|
1620 | 1622 | def _values_to_sql( |
@@ -1772,7 +1774,7 @@ def scd_type_2_by_column( |
1772 | 1774 | valid_from_col: exp.Column, |
1773 | 1775 | valid_to_col: exp.Column, |
1774 | 1776 | execution_time: t.Union[TimeLike, exp.Column], |
1775 | | - check_columns: t.Union[exp.Star, t.Sequence[exp.Column]], |
| 1777 | + check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]], |
1776 | 1778 | invalidate_hard_deletes: bool = True, |
1777 | 1779 | execution_time_as_valid_from: bool = False, |
1778 | 1780 | target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, |
@@ -1810,7 +1812,7 @@ def _scd_type_2( |
1810 | 1812 | execution_time: t.Union[TimeLike, exp.Column], |
1811 | 1813 | invalidate_hard_deletes: bool = True, |
1812 | 1814 | updated_at_col: t.Optional[exp.Column] = None, |
1813 | | - check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None, |
| 1815 | + check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None, |
1814 | 1816 | updated_at_as_valid_from: bool = False, |
1815 | 1817 | execution_time_as_valid_from: bool = False, |
1816 | 1818 | target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, |
@@ -1885,8 +1887,10 @@ def remove_managed_columns( |
1885 | 1887 | # they are equal or not, the extra check is not a problem and we gain simplified logic here. |
1886 | 1888 | # If we want to change this, then we just need to check the expressions in unique_key and pull out the |
1887 | 1889 | # column names and then remove them from the unmanaged_columns |
1888 | | - if check_columns and check_columns == exp.Star(): |
1889 | | - check_columns = [exp.column(col) for col in unmanaged_columns_to_types] |
| 1890 | + if check_columns: |
| 1891 | + # Handle both Star directly and [Star()] (which can happen during serialization/deserialization) |
| 1892 | + if isinstance(seq_get(ensure_list(check_columns), 0), exp.Star): |
| 1893 | + check_columns = [exp.column(col) for col in unmanaged_columns_to_types] |
1890 | 1894 | execution_ts = ( |
1891 | 1895 | exp.cast(execution_time, time_data_type, dialect=self.dialect) |
1892 | 1896 | if isinstance(execution_time, exp.Column) |
@@ -1923,7 +1927,8 @@ def remove_managed_columns( |
1923 | 1927 | col_qualified.set("table", exp.to_identifier("joined")) |
1924 | 1928 |
|
1925 | 1929 | t_col = col_qualified.copy() |
1926 | | - t_col.this.set("this", f"t_{col.name}") |
| 1930 | + for column in t_col.find_all(exp.Column): |
| 1931 | + column.this.set("this", f"t_{column.name}") |
1927 | 1932 |
|
1928 | 1933 | row_check_conditions.extend( |
1929 | 1934 | [ |
|
0 commit comments