|
18 | 18 |
|
19 | 19 | from sqlglot import Dialect, exp |
20 | 20 | from sqlglot.errors import ErrorLevel |
21 | | -from sqlglot.helper import ensure_list |
| 21 | +from sqlglot.helper import ensure_list, seq_get |
22 | 22 | from sqlglot.optimizer.qualify_columns import quote_identifiers |
23 | 23 |
|
24 | 24 | from sqlmesh.core.dialect import ( |
@@ -1772,7 +1772,7 @@ def scd_type_2_by_column( |
1772 | 1772 | valid_from_col: exp.Column, |
1773 | 1773 | valid_to_col: exp.Column, |
1774 | 1774 | execution_time: t.Union[TimeLike, exp.Column], |
1775 | | - check_columns: t.Union[exp.Star, t.Sequence[exp.Column]], |
| 1775 | + check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]], |
1776 | 1776 | invalidate_hard_deletes: bool = True, |
1777 | 1777 | execution_time_as_valid_from: bool = False, |
1778 | 1778 | target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, |
@@ -1810,7 +1810,7 @@ def _scd_type_2( |
1810 | 1810 | execution_time: t.Union[TimeLike, exp.Column], |
1811 | 1811 | invalidate_hard_deletes: bool = True, |
1812 | 1812 | updated_at_col: t.Optional[exp.Column] = None, |
1813 | | - check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None, |
| 1813 | + check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None, |
1814 | 1814 | updated_at_as_valid_from: bool = False, |
1815 | 1815 | execution_time_as_valid_from: bool = False, |
1816 | 1816 | target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, |
@@ -1885,8 +1885,10 @@ def remove_managed_columns( |
1885 | 1885 | # they are equal or not, the extra check is not a problem and we gain simplified logic here. |
1886 | 1886 | # If we want to change this, then we just need to check the expressions in unique_key and pull out the |
1887 | 1887 | # column names and then remove them from the unmanaged_columns |
1888 | | - if check_columns and check_columns == exp.Star(): |
1889 | | - check_columns = [exp.column(col) for col in unmanaged_columns_to_types] |
| 1888 | + if check_columns: |
| 1889 | + # Handle both Star directly and [Star()] (which can happen during serialization/deserialization) |
| 1890 | + if isinstance(seq_get(ensure_list(check_columns), 0), exp.Star): |
| 1891 | + check_columns = [exp.column(col) for col in unmanaged_columns_to_types] |
1890 | 1892 | execution_ts = ( |
1891 | 1893 | exp.cast(execution_time, time_data_type, dialect=self.dialect) |
1892 | 1894 | if isinstance(execution_time, exp.Column) |
@@ -1923,7 +1925,8 @@ def remove_managed_columns( |
1923 | 1925 | col_qualified.set("table", exp.to_identifier("joined")) |
1924 | 1926 |
|
1925 | 1927 | t_col = col_qualified.copy() |
1926 | | - t_col.this.set("this", f"t_{col.name}") |
| 1928 | + for column in t_col.find_all(exp.Column): |
| 1929 | + column.this.set("this", f"t_{column.name}") |
1927 | 1930 |
|
1928 | 1931 | row_check_conditions.extend( |
1929 | 1932 | [ |
|
0 commit comments