Skip to content

Commit 75f825e

Browse files
authored
Fix!: Avoid using rendered query when computing the data hash (#5256)
1 parent 454e942 commit 75f825e

19 files changed

+533
-191
lines changed

sqlmesh/core/audit/definition.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
bool_validator,
1616
default_catalog_validator,
1717
depends_on_validator,
18-
expression_validator,
1918
sort_python_env,
2019
sorted_python_env_payloads,
2120
)
22-
from sqlmesh.core.model.common import make_python_env, single_value_or_tuple
21+
from sqlmesh.core.model.common import make_python_env, single_value_or_tuple, ParsableSql
2322
from sqlmesh.core.node import _Node
2423
from sqlmesh.core.renderer import QueryRenderer
2524
from sqlmesh.utils.date import TimeLike
@@ -67,15 +66,26 @@ class AuditMixin(AuditCommonMetaMixin):
6766
jinja_macros: A registry of jinja macros to use when rendering the audit query.
6867
"""
6968

70-
query: t.Union[exp.Query, d.JinjaQuery]
69+
query_: ParsableSql
7170
defaults: t.Dict[str, exp.Expression]
72-
expressions_: t.Optional[t.List[exp.Expression]]
71+
expressions_: t.Optional[t.List[ParsableSql]]
7372
jinja_macros: JinjaMacroRegistry
7473
formatting: t.Optional[bool]
7574

75+
@property
76+
def query(self) -> t.Union[exp.Query, d.JinjaQuery]:
77+
return t.cast(t.Union[exp.Query, d.JinjaQuery], self.query_.parse(self.dialect))
78+
7679
@property
7780
def expressions(self) -> t.List[exp.Expression]:
78-
return self.expressions_ or []
81+
if not self.expressions_:
82+
return []
83+
result = []
84+
for e in self.expressions_:
85+
parsed = e.parse(self.dialect)
86+
if not isinstance(parsed, exp.Semicolon):
87+
result.append(parsed)
88+
return result
7989

8090
@property
8191
def macro_definitions(self) -> t.List[d.MacroDef]:
@@ -122,16 +132,16 @@ class ModelAudit(PydanticModel, AuditMixin, frozen=True):
122132
skip: bool = False
123133
blocking: bool = True
124134
standalone: t.Literal[False] = False
125-
query: t.Union[exp.Query, d.JinjaQuery]
135+
query_: ParsableSql = Field(alias="query")
126136
defaults: t.Dict[str, exp.Expression] = {}
127-
expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions")
137+
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
128138
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
129139
formatting: t.Optional[bool] = Field(default=None, exclude=True)
130140

131141
_path: t.Optional[Path] = None
132142

133143
# Validators
134-
_query_validator = expression_validator
144+
_query_validator = ParsableSql.validator()
135145
_bool_validator = bool_validator
136146
_string_validator = audit_string_validator
137147
_map_validator = audit_map_validator
@@ -153,9 +163,9 @@ class StandaloneAudit(_Node, AuditMixin):
153163
skip: bool = False
154164
blocking: bool = False
155165
standalone: t.Literal[True] = True
156-
query: t.Union[exp.Query, d.JinjaQuery]
166+
query_: ParsableSql = Field(alias="query")
157167
defaults: t.Dict[str, exp.Expression] = {}
158-
expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions")
168+
expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions")
159169
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
160170
default_catalog: t.Optional[str] = None
161171
depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on")
@@ -165,7 +175,7 @@ class StandaloneAudit(_Node, AuditMixin):
165175
source_type: t.Literal["audit"] = "audit"
166176

167177
# Validators
168-
_query_validator = expression_validator
178+
_query_validator = ParsableSql.validator()
169179
_bool_validator = bool_validator
170180
_string_validator = audit_string_validator
171181
_map_validator = audit_map_validator
@@ -276,8 +286,8 @@ def metadata_hash(self) -> str:
276286
self.cron_tz.key if self.cron_tz else None,
277287
]
278288

279-
query = self.render_audit_query() or self.query
280-
data.append(gen(query))
289+
data.append(self.query_.sql)
290+
data.extend([e.sql for e in self.expressions_ or []])
281291
self._metadata_hash = hash_data(data)
282292
return self._metadata_hash
283293

@@ -461,11 +471,17 @@ def load_audit(
461471
if project is not None:
462472
extra_kwargs["project"] = project
463473

464-
dialect = meta_fields.pop("dialect", dialect)
474+
dialect = meta_fields.pop("dialect", dialect) or ""
475+
476+
parsable_query = ParsableSql.from_parsed_expression(query, dialect, use_meta_sql=True)
477+
parsable_statements = [
478+
ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=True) for s in statements
479+
]
480+
465481
try:
466482
audit = audit_class(
467-
query=query,
468-
expressions=statements,
483+
query=parsable_query,
484+
expressions=parsable_statements,
469485
dialect=dialect,
470486
**extra_kwargs,
471487
**meta_fields,

sqlmesh/core/context_diff.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def directly_modified(self, name: str) -> bool:
435435
return False
436436

437437
current, previous = self.modified_snapshots[name]
438-
return current.fingerprint.data_hash != previous.fingerprint.data_hash
438+
return current.is_directly_modified(previous)
439439

440440
def indirectly_modified(self, name: str) -> bool:
441441
"""Returns whether or not a node was indirectly modified in this context.
@@ -451,10 +451,7 @@ def indirectly_modified(self, name: str) -> bool:
451451
return False
452452

453453
current, previous = self.modified_snapshots[name]
454-
return (
455-
current.fingerprint.data_hash == previous.fingerprint.data_hash
456-
and current.fingerprint.parent_data_hash != previous.fingerprint.parent_data_hash
457-
)
454+
return current.is_indirectly_modified(previous)
458455

459456
def metadata_updated(self, name: str) -> bool:
460457
"""Returns whether or not the given node's metadata has been updated.
@@ -470,7 +467,7 @@ def metadata_updated(self, name: str) -> bool:
470467
return False
471468

472469
current, previous = self.modified_snapshots[name]
473-
return current.fingerprint.metadata_hash != previous.fingerprint.metadata_hash
470+
return current.is_metadata_updated(previous)
474471

475472
def text_diff(self, name: str) -> str:
476473
"""Finds the difference of a node between the current and remote environment.

sqlmesh/core/model/common.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
prepare_env,
2222
serialize_env,
2323
)
24-
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator
24+
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, get_dialect
2525

2626
if t.TYPE_CHECKING:
2727
from sqlglot.dialects.dialect import DialectType
@@ -616,11 +616,6 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any:
616616

617617

618618
expression_validator: t.Callable = field_validator(
619-
"query",
620-
"expressions_",
621-
"pre_statements_",
622-
"post_statements_",
623-
"on_virtual_update_",
624619
"unique_key",
625620
mode="before",
626621
check_fields=False,
@@ -663,3 +658,65 @@ def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any:
663658
mode="before",
664659
check_fields=False,
665660
)(depends_on)
661+
662+
663+
class ParsableSql(PydanticModel):
664+
sql: str
665+
666+
_parsed: t.Optional[exp.Expression] = None
667+
_parsed_dialect: t.Optional[str] = None
668+
669+
def parse(self, dialect: str) -> exp.Expression:
670+
if self._parsed is None or self._parsed_dialect != dialect:
671+
self._parsed = d.parse_one(self.sql, dialect=dialect)
672+
self._parsed_dialect = dialect
673+
return self._parsed
674+
675+
@classmethod
676+
def from_parsed_expression(
677+
cls, parsed_expression: exp.Expression, dialect: str, use_meta_sql: bool = False
678+
) -> ParsableSql:
679+
sql = (
680+
parsed_expression.meta.get("sql") or parsed_expression.sql(dialect=dialect)
681+
if use_meta_sql
682+
else parsed_expression.sql(dialect=dialect)
683+
)
684+
result = cls(sql=sql)
685+
result._parsed = parsed_expression
686+
result._parsed_dialect = dialect
687+
return result
688+
689+
@classmethod
690+
def validator(cls) -> classmethod:
691+
def _validate_parsable_sql(
692+
v: t.Any, info: ValidationInfo
693+
) -> t.Optional[t.Union[ParsableSql, t.List[ParsableSql]]]:
694+
if v is None:
695+
return v
696+
if isinstance(v, str):
697+
return ParsableSql(sql=v)
698+
if isinstance(v, exp.Expression):
699+
return ParsableSql.from_parsed_expression(
700+
v, get_dialect(info.data), use_meta_sql=False
701+
)
702+
if isinstance(v, list):
703+
dialect = get_dialect(info.data)
704+
return [
705+
ParsableSql(sql=s)
706+
if isinstance(s, str)
707+
else ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=False)
708+
if isinstance(s, exp.Expression)
709+
else ParsableSql.parse_obj(s)
710+
for s in v
711+
]
712+
return ParsableSql.parse_obj(v)
713+
714+
return field_validator(
715+
"query_",
716+
"expressions_",
717+
"pre_statements_",
718+
"post_statements_",
719+
"on_virtual_update_",
720+
mode="before",
721+
check_fields=False,
722+
)(_validate_parsable_sql)

0 commit comments

Comments
 (0)