Skip to content

Commit 0d86edc

Browse files
authored
feat: add the ability to have audits on external models (#2715)
1 parent 6758651 commit 0d86edc

File tree

11 files changed

+133
-29
lines changed

11 files changed

+133
-29
lines changed

docs/concepts/models/external_models.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,24 @@ Files in the `external_models` directory must be `.yaml` files that follow the s
9191

9292
When SQLMesh loads the definitions, it will first load the models defined in `external_models.yaml` (or `schema.yaml`) and any models found in `external_models/*.yaml`.
9393

94-
Therefore, you can use `sqlmesh create_external_models` to manage the `external_models.yaml` file and then put any models that need to be defined manually inside the `external_models/` directory.
94+
Therefore, you can use `sqlmesh create_external_models` to manage the `external_models.yaml` file and then put any models that need to be defined manually inside the `external_models/` directory.
95+
96+
### External Audits
97+
It is possible to define [audits](../audits.md) on external models. This can be useful to check the data quality of upstream dependencies before your internal models are evaluated.
98+
99+
This example shows an external model with two audits.
100+
101+
```yaml
102+
- name: raw.demographics
103+
description: Table containing demographics information
104+
audits:
105+
- name: not_null
106+
columns: "[customer_id]"
107+
- name: accepted_range
108+
column: zip
109+
min_v: "'00000'"
110+
max_v: "'99999'"
111+
columns:
112+
customer_id: int
113+
zip: text
114+
```

examples/sushi/config.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,15 @@
2424
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
2525

2626

27+
defaults = {"dialect": "duckdb"}
28+
model_defaults = ModelDefaultsConfig(**defaults)
29+
model_defaults_iceberg = ModelDefaultsConfig(**defaults, storage_format="iceberg")
30+
31+
2732
# An in memory DuckDB config.
2833
config = Config(
2934
default_connection=DuckDBConnectionConfig(),
30-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
35+
model_defaults=model_defaults,
3136
)
3237

3338
bigquery_config = Config(
@@ -38,21 +43,21 @@
3843
)
3944
},
4045
default_gateway="bq",
41-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
46+
model_defaults=model_defaults,
4247
)
4348

4449
# A configuration used for SQLMesh tests.
4550
test_config = Config(
4651
gateways={"in_memory": GatewayConfig(connection=DuckDBConnectionConfig())},
4752
default_gateway="in_memory",
4853
plan=PlanConfig(auto_categorize_changes=CategorizerConfig(sql=AutoCategorizationMode.SEMI)),
49-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
54+
model_defaults=model_defaults,
5055
)
5156

5257
# A stateful DuckDB config.
5358
local_config = Config(
5459
default_connection=DuckDBConnectionConfig(database=f"{DATA_DIR}/local.duckdb"),
55-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
60+
model_defaults=model_defaults,
5661
)
5762

5863
airflow_config = Config(
@@ -65,21 +70,21 @@
6570
},
6671
)
6772
),
68-
model_defaults=ModelDefaultsConfig(dialect="duckdb", storage_format="iceberg"),
73+
model_defaults=model_defaults_iceberg,
6974
)
7075

7176

7277
airflow_config_docker = Config(
7378
default_scheduler=AirflowSchedulerConfig(airflow_url="http://airflow-webserver:8080/"),
7479
gateways=GatewayConfig(connection=SparkConnectionConfig()),
75-
model_defaults=ModelDefaultsConfig(dialect="duckdb", storage_format="iceberg"),
80+
model_defaults=model_defaults_iceberg,
7681
)
7782

7883
# A DuckDB config with a physical schema map.
7984
map_config = Config(
8085
default_connection=DuckDBConnectionConfig(),
8186
physical_schema_override={"sushi": "company_internal"},
82-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
87+
model_defaults=model_defaults,
8388
)
8489

8590

@@ -114,13 +119,13 @@
114119
],
115120
),
116121
],
117-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
122+
model_defaults=model_defaults,
118123
)
119124

120125

121126
environment_suffix_config = Config(
122127
default_connection=DuckDBConnectionConfig(),
123-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
128+
model_defaults=model_defaults,
124129
environment_suffix_target=EnvironmentSuffixTarget.TABLE,
125130
)
126131

@@ -133,7 +138,7 @@
133138
local_catalogs = Config(
134139
default_connection=DuckDBConnectionConfig(catalogs=CATALOGS),
135140
default_test_connection=DuckDBConnectionConfig(catalogs=CATALOGS),
136-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
141+
model_defaults=model_defaults,
137142
)
138143

139144
environment_catalog_mapping_config = Config(
@@ -144,7 +149,7 @@
144149
"dev_catalog": ":memory:",
145150
}
146151
),
147-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
152+
model_defaults=model_defaults,
148153
environment_suffix_target=EnvironmentSuffixTarget.TABLE,
149154
environment_catalog_mapping={
150155
"^prod$": "prod_catalog",
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
- name: raw.demographics
22
description: Table containing demographics information
3+
start: 1 week ago
4+
audits:
5+
- name: not_null
6+
columns: "[customer_id]"
7+
- name: accepted_range
8+
column: zip
9+
min_v: "'00000'"
10+
max_v: "'99999'"
311
columns:
412
customer_id: int
513
zip: text

sqlmesh/core/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def _load_external_models(self) -> UniqueKeyDict[str, Model]:
203203
model = create_external_model(
204204
**row,
205205
dialect=config.model_defaults.dialect,
206+
defaults=config.model_defaults.dict(),
206207
path=path,
207208
project=config.project,
208209
default_catalog=self._context.default_catalog,

sqlmesh/core/model/definition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,6 +1791,7 @@ def create_external_model(
17911791
*,
17921792
dialect: t.Optional[str] = None,
17931793
path: Path = Path(),
1794+
defaults: t.Optional[t.Dict[str, t.Any]] = None,
17941795
**kwargs: t.Any,
17951796
) -> Model:
17961797
"""Creates an external model.
@@ -1804,6 +1805,7 @@ def create_external_model(
18041805
return _create_model(
18051806
ExternalModel,
18061807
name,
1808+
defaults=defaults,
18071809
dialect=dialect,
18081810
path=path,
18091811
kind=ModelKindName.EXTERNAL.value,

sqlmesh/core/model/meta.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,29 @@ def extract(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]]:
110110
if isinstance(v, exp.Expression):
111111
return [extract(v)]
112112
if isinstance(v, list):
113-
return [
114-
(
115-
entry[0].lower(),
116-
{
117-
key: d.parse(value)[0] if isinstance(value, str) else value
118-
for key, value in entry[1].items()
119-
},
113+
audits = []
114+
115+
for entry in v:
116+
if isinstance(entry, dict):
117+
args = entry
118+
name = entry.pop("name")
119+
elif isinstance(entry, (tuple, list)):
120+
name, args = entry
121+
else:
122+
raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")
123+
124+
audits.append(
125+
(
126+
name.lower(),
127+
{
128+
key: d.parse_one(value) if isinstance(value, str) else value
129+
for key, value in args.items()
130+
},
131+
)
120132
)
121-
for entry in v
122-
]
133+
134+
return audits
135+
123136
return v
124137

125138
@field_validator("tags", mode="before")

sqlmesh/core/snapshot/definition.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,9 @@ def _table_name(self, version: str, is_deployable: bool) -> str:
319319
version: The snapshot version.
320320
is_deployable: Indicates whether to return the table name for deployment to production.
321321
"""
322+
if self.is_external:
323+
return self.name
324+
322325
is_dev_table = not is_deployable
323326
if is_dev_table:
324327
version = self.temp_version_get_or_generate()
@@ -742,6 +745,11 @@ def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None:
742745
for start, end in other.dev_intervals:
743746
self.add_interval(start, end, is_dev=True)
744747

748+
@property
749+
def evaluatable(self) -> bool:
750+
"""Whether or not a snapshot should be evaluated and have intervals."""
751+
return bool(not self.is_symbolic or self.model.audits)
752+
745753
def missing_intervals(
746754
self,
747755
start: TimeLike,
@@ -761,7 +769,6 @@ def missing_intervals(
761769
start: The start date/time of the interval (inclusive)
762770
end: The end date/time of the interval (inclusive if the type is date, exclusive otherwise)
763771
execution_time: The date/time time reference to use for execution time. Defaults to now.
764-
restatements: A set of snapshot names being restated
765772
deployability_index: Determines snapshots that are deployable in the context of this evaluation.
766773
ignore_cron: Whether to ignore the node's cron schedule.
767774
end_bounded: If set to true, the returned intervals will be bounded by the target end date, disregarding lookback,
@@ -795,7 +802,7 @@ def missing_intervals(
795802
self.intervals if deployability_index.is_representative(self) else self.dev_intervals
796803
)
797804

798-
if self.is_symbolic or (self.is_seed and intervals):
805+
if not self.evaluatable or (self.is_seed and intervals):
799806
return []
800807

801808
allow_partials = not end_bounded and self.is_model and self.model.allow_partials
@@ -1502,7 +1509,7 @@ def missing_intervals(
15021509
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
15031510

15041511
for snapshot in snapshots:
1505-
if snapshot.is_symbolic:
1512+
if not snapshot.evaluatable:
15061513
continue
15071514
interval = restatements.get(snapshot.snapshot_id)
15081515
snapshot_start_date = start_dt

tests/core/test_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,9 +2084,9 @@ def validate_state_sync_environment(
20842084
def validate_tables(snapshots: t.Iterable[Snapshot], context: Context) -> None:
20852085
adapter = context.engine_adapter
20862086
for snapshot in snapshots:
2087-
if not snapshot.is_model:
2087+
if not snapshot.is_model or snapshot.is_external:
20882088
continue
2089-
table_should_exist = not snapshot.is_symbolic
2089+
table_should_exist = not snapshot.is_embedded
20902090
assert adapter.table_exists(snapshot.table_name()) == table_should_exist
20912091
if table_should_exist:
20922092
assert select_all(snapshot.table_name(), adapter)

tests/core/test_scheduler.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import pytest
44
from pytest_mock.plugin import MockerFixture
5-
from sqlglot import parse_one
5+
from sqlglot import parse_one, parse
66

77
from sqlmesh.core.context import Context
88
from sqlmesh.core.environment import EnvironmentNamingInfo
9+
from sqlmesh.core.model import load_sql_based_model
910
from sqlmesh.core.model.definition import SqlModel
1011
from sqlmesh.core.model.kind import (
1112
IncrementalByTimeRangeKind,
@@ -14,7 +15,7 @@
1415
)
1516
from sqlmesh.core.node import IntervalUnit
1617
from sqlmesh.core.scheduler import Scheduler, compute_interval_params
17-
from sqlmesh.core.snapshot import Snapshot, SnapshotEvaluator
18+
from sqlmesh.core.snapshot import Snapshot, SnapshotEvaluator, SnapshotChangeCategory
1819
from sqlmesh.utils.date import to_datetime
1920
from sqlmesh.utils.errors import CircuitBreakerError
2021

@@ -417,3 +418,43 @@ def test_intervals_with_end_date_on_model(mocker: MockerFixture, make_snapshot):
417418
# generate for future days to ensure no future batches are loaded
418419
snapshot_to_batches = scheduler.batches(start="2023-02-01", end="2023-02-28")
419420
assert len(snapshot_to_batches) == 0
421+
422+
423+
def test_external_model_audit(mocker, make_snapshot):
424+
model = load_sql_based_model(
425+
parse( # type: ignore
426+
"""
427+
MODEL (
428+
name test_schema.test_model,
429+
kind EXTERNAL,
430+
columns (id int),
431+
audits not_null(columns := id)
432+
);
433+
434+
SELECT 1;
435+
"""
436+
),
437+
)
438+
439+
snapshot = make_snapshot(model)
440+
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
441+
442+
evaluator = SnapshotEvaluator(adapter=mocker.MagicMock())
443+
spy = mocker.spy(evaluator, "_audit")
444+
445+
scheduler = Scheduler(
446+
snapshots=[snapshot],
447+
snapshot_evaluator=evaluator,
448+
state_sync=mocker.MagicMock(),
449+
max_workers=2,
450+
default_catalog=None,
451+
)
452+
453+
scheduler.run(
454+
EnvironmentNamingInfo(),
455+
"2022-01-01",
456+
"2022-01-01",
457+
"2022-01-30",
458+
)
459+
460+
spy.assert_called_once()

tests/core/test_snapshot.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,3 +1746,10 @@ def test_missing_intervals_node_start_end(make_snapshot):
17461746
assert missing_intervals([snapshot], start="2024-03-01", end=to_datetime("2024-03-10")) == {}
17471747
assert missing_intervals([snapshot], start="2024-03-13", end="2024-03-14") == {}
17481748
assert missing_intervals([snapshot], start="2024-03-14", end="2024-03-30") == {}
1749+
1750+
1751+
def test_external_model_audits(sushi_context):
1752+
snapshot = sushi_context.get_snapshot("raw.demographics")
1753+
assert snapshot.evaluatable
1754+
assert len(snapshot.model.audits) == 2
1755+
assert snapshot.intervals

0 commit comments

Comments (0)