Skip to content

Commit 8e786fd

Browse files
Feat: Support enabling audits in model defaults (#2947)
1 parent 30071cc commit 8e786fd

File tree

11 files changed

+126
-30
lines changed

11 files changed

+126
-30
lines changed

docs/concepts/audits.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ SELECT * FROM @this_model
8888
WHERE @column >= @threshold;
8989
```
9090

91+
Alternatively, you can apply specific audits globally by including them in the model defaults configuration:
92+
93+
```sql linenums="1"
94+
model_defaults:
95+
audits:
96+
- assert_positive_order_ids
97+
- does_not_exceed_threshold(column := id, threshold := 1000)
98+
```
99+
91100
### Naming
92101
We recommended avoiding SQL keywords when naming audit parameters. Quote any audit argument that is also a SQL keyword.
93102

docs/concepts/overview.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ You create audits by writing SQL queries that should return 0 rows. For example,
6161

6262
Audits are flexible — they can be tied to a specific model's contents, or you can use [macros](./macros/overview.md) to create audits that are usable by multiple models. SQLMesh also includes pre-made audits for common use cases, such as detecting NULL or duplicated values.
6363

64-
You specify which audits should run for a model by including them in the model's metadata properties.
64+
You specify which audits should run for a model by including them in the model's metadata properties. To apply them globally across your project, include them in the model defaults configuration.
6565

6666
SQLMesh automatically runs audits when you apply a `plan` to an environment, or you can run them on demand with the [`audit` command](../reference/cli.md#audit).
6767

docs/reference/model_configuration.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ The SQLMesh project-level `model_defaults` key supports the following options, d
5151
- storage_format
5252
- session_properties (on per key basis)
5353
- on_destructive_change (described [below](#incremental-models))
54+
- audits (described [here](../concepts/audits.md#generic-audits))
5455

5556

5657
### Model Naming

sqlmesh/core/config/model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import typing as t
44

5+
from sqlmesh.core.dialect import parse_one, extract_audit
56
from sqlmesh.core.config.base import BaseConfig
67
from sqlmesh.core.model.kind import (
78
ModelKind,
@@ -10,6 +11,8 @@
1011
on_destructive_change_validator,
1112
)
1213
from sqlmesh.utils.date import TimeLike
14+
from sqlmesh.core.model.meta import AuditReference
15+
from sqlmesh.utils.pydantic import field_validator
1316

1417

1518
class ModelDefaultsConfig(BaseConfig):
@@ -27,6 +30,7 @@ class ModelDefaultsConfig(BaseConfig):
2730
storage_format: The storage format used to store the physical table, only applicable in certain engines.
2831
(eg. 'parquet')
2932
on_destructive_change: What should happen when a forward-only model requires a destructive schema change.
33+
audits: The audits to be applied globally to all models in the project.
3034
"""
3135

3236
kind: t.Optional[ModelKind] = None
@@ -37,6 +41,14 @@ class ModelDefaultsConfig(BaseConfig):
3741
storage_format: t.Optional[str] = None
3842
on_destructive_change: t.Optional[OnDestructiveChange] = None
3943
session_properties: t.Optional[t.Dict[str, t.Any]] = None
44+
audits: t.Optional[t.List[AuditReference]] = None
4045

4146
_model_kind_validator = model_kind_validator
4247
_on_destructive_change_validator = on_destructive_change_validator
48+
49+
@field_validator("audits", mode="before")
50+
def _audits_validator(cls, v: t.Any) -> t.Any:
51+
if isinstance(v, list):
52+
return [extract_audit(parse_one(audit)) for audit in v]
53+
54+
return v

sqlmesh/core/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,6 +1907,7 @@ def _nodes_to_snapshots(nodes: t.Dict[str, Node]) -> t.Dict[str, Snapshot]:
19071907
audits=audits,
19081908
cache=fingerprint_cache,
19091909
ttl=ttl,
1910+
config=self.config_for_node(node),
19101911
)
19111912
snapshots[snapshot.name] = snapshot
19121913
return snapshots

sqlmesh/core/dialect.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sqlglot.tokens import Token
2222

2323
from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE
24-
from sqlmesh.utils.errors import SQLMeshError
24+
from sqlmesh.utils.errors import SQLMeshError, ConfigError
2525
from sqlmesh.utils.pandas import columns_to_types_from_df
2626

2727
if t.TYPE_CHECKING:
@@ -1099,3 +1099,24 @@ def interpret_key_value_pairs(
10991099
e: exp.Tuple,
11001100
) -> t.Dict[str, exp.Expression | str | int | float | bool]:
11011101
return {i.this.name: interpret_expression(i.expression) for i in e.expressions}
1102+
1103+
1104+
def extract_audit(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
1105+
kwargs = {}
1106+
1107+
if isinstance(v, exp.Anonymous):
1108+
func = v.name
1109+
args = v.expressions
1110+
elif isinstance(v, exp.Func):
1111+
func = v.sql_name()
1112+
args = list(v.args.values())
1113+
else:
1114+
return v.name.lower(), {}
1115+
1116+
for arg in args:
1117+
if not isinstance(arg, (exp.PropertyEQ, exp.EQ)):
1118+
raise ConfigError(
1119+
f"Function '{func}' must be called with key-value arguments like {func}(arg := value)."
1120+
)
1121+
kwargs[arg.left.name.lower()] = arg.right
1122+
return func.lower(), kwargs

sqlmesh/core/model/definition.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from sqlmesh.core.macros import MacroRegistry, MacroStrTemplate, macro
2626
from sqlmesh.core.model.common import expression_validator
2727
from sqlmesh.core.model.kind import ModelKindName, SeedKind, ModelKind, FullKind, create_model_kind
28-
from sqlmesh.core.model.meta import ModelMeta
28+
from sqlmesh.core.model.meta import ModelMeta, AuditReference
2929
from sqlmesh.core.model.seed import CsvSeedReader, Seed, create_seed
3030
from sqlmesh.core.renderer import ExpressionRenderer, QueryRenderer
3131
from sqlmesh.utils import columns_to_types_all_known, str_to_bool, UniqueKeyDict
@@ -461,7 +461,11 @@ def ctas_query(self, **render_kwarg: t.Any) -> exp.Query:
461461
)
462462
return query
463463

464-
def referenced_audits(self, audits: t.Dict[str, ModelAudit]) -> t.List[ModelAudit]:
464+
def referenced_audits(
465+
self,
466+
audits: t.Dict[str, ModelAudit],
467+
default_audits: t.List[AuditReference] = [],
468+
) -> t.List[ModelAudit]:
465469
"""Returns audits referenced in this model.
466470
467471
Args:
@@ -471,7 +475,7 @@ def referenced_audits(self, audits: t.Dict[str, ModelAudit]) -> t.List[ModelAudi
471475

472476
referenced_audits = []
473477

474-
for audit_name, _ in self.audits:
478+
for audit_name, _ in self.audits + default_audits:
475479
if audit_name in self.inline_audits:
476480
referenced_audits.append(self.inline_audits[audit_name])
477481
elif audit_name in audits:

sqlmesh/core/model/meta.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1111

1212
from sqlmesh.core import dialect as d
13-
from sqlmesh.core.dialect import normalize_model_name
13+
from sqlmesh.core.dialect import normalize_model_name, extract_audit
1414
from sqlmesh.core.model.common import (
1515
bool_validator,
1616
default_catalog_validator,
@@ -85,32 +85,12 @@ class ModelMeta(_Node):
8585

8686
@field_validator("audits", mode="before")
8787
def _audits_validator(cls, v: t.Any) -> t.Any:
88-
def extract(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
89-
kwargs = {}
90-
91-
if isinstance(v, exp.Anonymous):
92-
func = v.name
93-
args = v.expressions
94-
elif isinstance(v, exp.Func):
95-
func = v.sql_name()
96-
args = list(v.args.values())
97-
else:
98-
return v.name.lower(), {}
99-
100-
for arg in args:
101-
if not isinstance(arg, (exp.PropertyEQ, exp.EQ)):
102-
raise ConfigError(
103-
f"Function '{func}' must be called with key-value arguments like {func}(arg := value)."
104-
)
105-
kwargs[arg.left.name.lower()] = arg.right
106-
return func.lower(), kwargs
107-
10888
if isinstance(v, (exp.Tuple, exp.Array)):
109-
return [extract(i) for i in v.expressions]
89+
return [extract_audit(i) for i in v.expressions]
11090
if isinstance(v, exp.Paren):
111-
return [extract(v.this)]
91+
return [extract_audit(v.this)]
11292
if isinstance(v, exp.Expression):
113-
return [extract(v)]
93+
return [extract_audit(v)]
11494
if isinstance(v, list):
11595
audits = []
11696

sqlmesh/core/snapshot/definition.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
if t.TYPE_CHECKING:
4747
from sqlglot.dialects.dialect import DialectType
4848
from sqlmesh.core.environment import EnvironmentNamingInfo
49+
from sqlmesh.core.config import Config
4950

5051
Interval = t.Tuple[int, int]
5152
Intervals = t.List[Interval]
@@ -617,6 +618,7 @@ def from_node(
617618
version: t.Optional[str] = None,
618619
audits: t.Optional[t.Dict[str, ModelAudit]] = None,
619620
cache: t.Optional[t.Dict[str, SnapshotFingerprint]] = None,
621+
config: t.Optional[Config] = None,
620622
) -> Snapshot:
621623
"""Creates a new snapshot for a node.
622624
@@ -634,8 +636,13 @@ def from_node(
634636
"""
635637
created_ts = now_timestamp()
636638
kwargs = {}
639+
default_audits = (
640+
config.model_defaults.audits if (config and config.model_defaults.audits) else []
641+
)
637642
if node.is_model:
638-
kwargs["audits"] = tuple(t.cast(_Model, node).referenced_audits(audits or {}))
643+
kwargs["audits"] = tuple(
644+
t.cast(_Model, node).referenced_audits(audits or {}, default_audits)
645+
)
639646

640647
return cls(
641648
name=node.fqn,

tests/core/test_config.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,3 +557,30 @@ def test_load_duckdb_attach_config(tmp_path):
557557
assert attach_config_2.type == "postgres"
558558
assert attach_config_2.path == "dbname=postgres user=postgres host=127.0.0.1"
559559
assert attach_config_2.read_only is True
560+
561+
562+
def test_load_model_defaults_audits(tmp_path):
563+
config_path = tmp_path / "config_model_defaults_audits.yaml"
564+
with open(config_path, "w", encoding="utf-8") as fd:
565+
fd.write(
566+
"""
567+
model_defaults:
568+
dialect: ''
569+
audits:
570+
- assert_positive_order_ids
571+
- does_not_exceed_threshold(column := id, threshold := 1000)
572+
"""
573+
)
574+
575+
config = load_config_from_paths(
576+
Config,
577+
project_paths=[config_path],
578+
)
579+
580+
assert len(config.model_defaults.audits) == 2
581+
assert config.model_defaults.audits[0] == ("assert_positive_order_ids", {})
582+
assert config.model_defaults.audits[1][0] == "does_not_exceed_threshold"
583+
assert type(config.model_defaults.audits[1][1]["column"]) == exp.Column
584+
assert config.model_defaults.audits[1][1]["column"].this.this == "id"
585+
assert type(config.model_defaults.audits[1][1]["threshold"]) == exp.Literal
586+
assert config.model_defaults.audits[1][1]["threshold"].this == "1000"

0 commit comments

Comments
 (0)