Skip to content

Commit ec0618d

Browse files
Feat: Add support for defaults in user audits (#2901)
1 parent fb9d39d commit ec0618d

File tree

3 files changed

+67
-2
lines changed

3 files changed

+67
-2
lines changed

docs/concepts/audits.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ Notice how `column` and `threshold` parameters have been set. These values will
7575

7676
Note that the same audit can be applied more than once to the a model using different sets of parameters.
7777

78+
Generic audits can define default values as follows:
79+
```sql linenums="1"
80+
AUDIT (
81+
name does_not_exceed_threshold,
82+
defaults (
83+
threshold = 10,
84+
column = id
85+
)
86+
);
87+
SELECT * FROM @this_model
88+
WHERE @column >= @threshold;
89+
```
90+
7891
### Naming
7992
We recommended avoiding SQL keywords when naming audit parameters. Quote any audit argument that is also a SQL keyword.
8093

sqlmesh/core/audit/definition.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sqlmesh.utils.pydantic import (
3535
PydanticModel,
3636
field_validator,
37+
get_dialect,
3738
model_validator,
3839
model_validator_v1_args,
3940
)
@@ -148,11 +149,19 @@ def audit_string_validator(cls: t.Type, v: t.Any) -> t.Optional[str]:
148149

149150

150151
@field_validator("defaults", mode="before", check_fields=False)
151-
def audit_map_validator(cls: t.Type, v: t.Any) -> t.Dict[str, t.Any]:
152+
def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.Any]:
153+
if isinstance(v, exp.Paren):
154+
return dict([_maybe_parse_arg_pair(v.unnest())])
152155
if isinstance(v, (exp.Tuple, exp.Array)):
153156
return dict(map(_maybe_parse_arg_pair, v.expressions))
154157
elif isinstance(v, dict):
155-
return v
158+
dialect = get_dialect(values)
159+
return {
160+
key: value
161+
if isinstance(value, exp.Expression)
162+
else d.parse_one(str(value), dialect=dialect)
163+
for key, value in v.items()
164+
}
156165
else:
157166
raise_config_error(
158167
"Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError

tests/core/test_audit.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,49 @@ def test_load_multiple(assert_exp_eq):
303303
)
304304

305305

306+
def test_load_with_dictionary_defaults():
307+
expressions = parse(
308+
"""
309+
AUDIT (
310+
name my_audit,
311+
dialect spark,
312+
defaults (
313+
field1 = some_column,
314+
field2 = 3
315+
),
316+
);
317+
318+
SELECT 1
319+
"""
320+
)
321+
322+
audit = load_audit(expressions, dialect="spark")
323+
assert audit.defaults.keys() == {"field1", "field2"}
324+
for value in audit.defaults.values():
325+
assert isinstance(value, exp.Expression)
326+
327+
328+
def test_load_with_single_defaults():
329+
# testing it also works with a single default with no trailing comma
330+
expressions = parse(
331+
"""
332+
AUDIT (
333+
name my_audit,
334+
defaults (
335+
field1 = some_column
336+
),
337+
);
338+
339+
SELECT 1
340+
"""
341+
)
342+
343+
audit = load_audit(expressions, dialect="duckdb")
344+
assert audit.defaults.keys() == {"field1"}
345+
for value in audit.defaults.values():
346+
assert isinstance(value, exp.Expression)
347+
348+
306349
def test_no_audit_statement():
307350
expressions = parse(
308351
"""

0 commit comments

Comments
 (0)