|
2 | 2 |
|
3 | 3 | import typing as t |
4 | 4 | from enum import Enum |
| 5 | +from typing_extensions import Self |
5 | 6 |
|
6 | 7 | from pydantic import Field |
7 | 8 | from sqlglot import exp |
@@ -240,43 +241,7 @@ class TimeColumn(PydanticModel): |
240 | 241 | @classmethod |
241 | 242 | def validator(cls) -> classmethod: |
242 | 243 | def _time_column_validator(v: t.Any, info: ValidationInfo) -> TimeColumn: |
243 | | - dialect = get_dialect(info.data) |
244 | | - |
245 | | - if isinstance(v, exp.Tuple): |
246 | | - column_expr = v.expressions[0] |
247 | | - column = ( |
248 | | - exp.column(column_expr) |
249 | | - if isinstance(column_expr, exp.Identifier) |
250 | | - else column_expr |
251 | | - ) |
252 | | - format = v.expressions[1].name if len(v.expressions) > 1 else None |
253 | | - elif isinstance(v, exp.Expression): |
254 | | - column = exp.column(v) if isinstance(v, exp.Identifier) else v |
255 | | - format = None |
256 | | - elif isinstance(v, str): |
257 | | - column = d.parse_one(v, dialect=dialect) |
258 | | - column.meta.pop("sql") |
259 | | - format = None |
260 | | - elif isinstance(v, dict): |
261 | | - column_raw = v["column"] |
262 | | - column = ( |
263 | | - d.parse_one(column_raw, dialect=dialect) |
264 | | - if isinstance(column_raw, str) |
265 | | - else column_raw |
266 | | - ) |
267 | | - format = v.get("format") |
268 | | - elif isinstance(v, TimeColumn): |
269 | | - column = v.column |
270 | | - format = v.format |
271 | | - else: |
272 | | - raise ConfigError(f"Invalid time_column: '{v}'.") |
273 | | - |
274 | | - column = quote_identifiers( |
275 | | - normalize_identifiers(column, dialect=dialect), dialect=dialect |
276 | | - ) |
277 | | - column.meta["dialect"] = dialect |
278 | | - |
279 | | - return TimeColumn(column=column, format=format) |
| 244 | + return TimeColumn.create(v, get_dialect(info.data)) |
280 | 245 |
|
281 | 246 | return field_validator("time_column", mode="before")(_time_column_validator) |
282 | 247 |
|
@@ -314,6 +279,40 @@ def to_expression(self, dialect: str) -> exp.Expression: |
314 | 279 | def to_property(self, dialect: str = "") -> exp.Property: |
315 | 280 | return exp.Property(this="time_column", value=self.to_expression(dialect)) |
316 | 281 |
|
| 282 | + @classmethod |
| 283 | + def create(cls, v: t.Any, dialect: str) -> Self: |
| 284 | + if isinstance(v, exp.Tuple): |
| 285 | + column_expr = v.expressions[0] |
| 286 | + column = ( |
| 287 | + exp.column(column_expr) if isinstance(column_expr, exp.Identifier) else column_expr |
| 288 | + ) |
| 289 | + format = v.expressions[1].name if len(v.expressions) > 1 else None |
| 290 | + elif isinstance(v, exp.Expression): |
| 291 | + column = exp.column(v) if isinstance(v, exp.Identifier) else v |
| 292 | + format = None |
| 293 | + elif isinstance(v, str): |
| 294 | + column = d.parse_one(v, dialect=dialect) |
| 295 | + column.meta.pop("sql") |
| 296 | + format = None |
| 297 | + elif isinstance(v, dict): |
| 298 | + column_raw = v["column"] |
| 299 | + column = ( |
| 300 | + d.parse_one(column_raw, dialect=dialect) |
| 301 | + if isinstance(column_raw, str) |
| 302 | + else column_raw |
| 303 | + ) |
| 304 | + format = v.get("format") |
| 305 | + elif isinstance(v, TimeColumn): |
| 306 | + column = v.column |
| 307 | + format = v.format |
| 308 | + else: |
| 309 | + raise ConfigError(f"Invalid time_column: '{v}'.") |
| 310 | + |
| 311 | + column = quote_identifiers(normalize_identifiers(column, dialect=dialect), dialect=dialect) |
| 312 | + column.meta["dialect"] = dialect |
| 313 | + |
| 314 | + return cls(column=column, format=format) |
| 315 | + |
317 | 316 |
|
318 | 317 | def _kind_dialect_validator(cls: t.Type, v: t.Optional[str]) -> str: |
319 | 318 | if v is None: |
@@ -836,17 +835,16 @@ class CustomKind(_ModelKind): |
836 | 835 | auto_restatement_cron: t.Optional[SQLGlotCron] = None |
837 | 836 | auto_restatement_intervals: t.Optional[SQLGlotPositiveInt] = None |
838 | 837 |
|
| 838 | + # so that CustomKind subclasses know the dialect when validating / normalizing / interpreting values in `materialization_properties` |
| 839 | + dialect: str = Field(exclude=True) |
| 840 | + |
839 | 841 | _properties_validator = properties_validator |
840 | 842 |
|
841 | 843 | @field_validator("materialization", mode="before") |
842 | 844 | @classmethod |
843 | 845 | def _validate_materialization(cls, v: t.Any) -> str: |
844 | | - from sqlmesh.core.snapshot.evaluator import get_custom_materialization_type |
845 | | - |
846 | | - materialization = validate_string(v) |
847 | | - # The below call fails if a materialization with the given name doesn't exist. |
848 | | - get_custom_materialization_type(materialization) |
849 | | - return materialization |
| 846 | + # note: create_model_kind() validates the custom materialization class |
| 847 | + return validate_string(v) |
850 | 848 |
|
851 | 849 | @property |
852 | 850 | def materialization_properties(self) -> CustomMaterializationProperties: |
@@ -985,11 +983,15 @@ def create_model_kind(v: t.Any, dialect: str, defaults: t.Dict[str, t.Any]) -> M |
985 | 983 | "The 'materialization' property is required for models of the CUSTOM kind" |
986 | 984 | ) |
987 | 985 |
|
988 | | - actual_kind_type, _ = get_custom_materialization_type( |
989 | | - validate_string(props.get("materialization")) |
990 | | - ) |
991 | | - |
992 | | - return actual_kind_type(**props) |
| 986 | + # The below call will print a warning if a materialization with the given name doesn't exist |
| 987 | + # we dont want to throw an error here because we still want Models with a CustomKind to be able |
| 988 | + # to be serialized / deserialized in contexts where the custom materialization class may not be available, |
| 989 | + # such as in HTTP request handlers |
| 990 | + if custom_materialization := get_custom_materialization_type( |
| 991 | + validate_string(props.get("materialization")), raise_errors=False |
| 992 | + ): |
| 993 | + actual_kind_type, _ = custom_materialization |
| 994 | + return actual_kind_type(**props) |
993 | 995 |
|
994 | 996 | return kind_type(**props) |
995 | 997 |
|
|
0 commit comments