Skip to content

Commit d50ee1a

Browse files
authored
feat: add python library versions to the environment (#3275)
1 parent 24f2a9a commit d50ee1a

File tree

6 files changed

+80
-0
lines changed

6 files changed

+80
-0
lines changed

sqlmesh/core/environment.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class Environment(EnvironmentNamingInfo):
101101
promoted_snapshot_ids: The IDs of the snapshots that are promoted in this environment
102102
(i.e. for which the views are created). If not specified, all snapshots are promoted.
103103
previous_finalized_snapshots: Snapshots that were part of this environment last time it was finalized.
104+
requirements: A mapping of library versions for all the snapshots in this environment.
104105
"""
105106

106107
snapshots_: t.List[t.Any] = Field(alias="snapshots")
@@ -116,6 +117,7 @@ class Environment(EnvironmentNamingInfo):
116117
previous_finalized_snapshots_: t.Optional[t.List[t.Any]] = Field(
117118
default=None, alias="previous_finalized_snapshots"
118119
)
120+
requirements: t.Dict[str, str] = {}
119121

120122
@field_validator("snapshots_", "previous_finalized_snapshots_", mode="before")
121123
@classmethod
@@ -135,6 +137,12 @@ def _load_snapshot_ids(cls, v: str | t.List[t.Any] | None) -> t.List[t.Any] | No
135137
raise ValueError("Must be a list of SnapshotId dicts or objects")
136138
return v
137139

140+
@field_validator("requirements", mode="before")
@classmethod
def _load_requirements(cls, v: t.Any) -> t.Any:
    """Normalize the raw ``requirements`` value before field validation.

    Args:
        v: The incoming value. A string is treated as a JSON document and
            decoded; any other value is passed through unchanged.

    Returns:
        The decoded / original value, or an empty dict when that value is
        falsy (e.g. ``None``, ``""`` decoded to nothing, or ``{}``).

    Raises:
        json.JSONDecodeError: If ``v`` is a string that is not valid JSON.
    """
    if isinstance(v, str):
        v = json.loads(v)
    # Falsy inputs collapse to an empty mapping so the field is never None.
    return v or {}
145+
138146
@property
139147
def snapshots(self) -> t.List[SnapshotTableInfo]:
140148
return self._convert_list_to_models_and_store("snapshots_", SnapshotTableInfo)

sqlmesh/core/plan/definition.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import logging
4+
import sys
35
import typing as t
46
from dataclasses import dataclass
57
from datetime import datetime
@@ -26,9 +28,20 @@
2628
from sqlmesh.utils.date import TimeLike, now, to_datetime, to_timestamp
2729
from sqlmesh.utils.pydantic import PydanticModel
2830

31+
logger = logging.getLogger(__name__)
32+
2933
SnapshotMapping = t.Dict[SnapshotId, t.Set[SnapshotId]]
3034

3135

36+
if sys.version_info >= (3, 12):
37+
from importlib import metadata
38+
else:
39+
import importlib_metadata as metadata # type: ignore
40+
41+
42+
IGNORED_PACKAGES = {"sqlmesh", "sqlglot"}
43+
44+
3245
class Plan(PydanticModel, frozen=True):
3346
context_diff: ContextDiff
3447
plan_id: str
@@ -209,6 +222,23 @@ def environment(self) -> Environment:
209222
else self.context_diff.previous_finalized_snapshots
210223
)
211224

225+
requirements = {}
226+
distributions = metadata.packages_distributions()
227+
228+
for snapshot in self.context_diff.snapshots.values():
229+
if snapshot.is_model:
230+
for executable in snapshot.model.python_env.values():
231+
if executable.kind == "import":
232+
try:
233+
start = "from " if executable.payload.startswith("from ") else "import "
234+
lib = executable.payload.split(start)[1].split()[0].split(".")[0]
235+
if lib in distributions:
236+
for dist in distributions[lib]:
237+
if dist not in requirements and dist not in IGNORED_PACKAGES:
238+
requirements[dist] = metadata.version(dist)
239+
except metadata.PackageNotFoundError:
240+
logger.warning("Failed to find package for %s", lib)
241+
212242
return Environment(
213243
snapshots=snapshots,
214244
start_at=self.provided_start or self._earliest_interval_start,
@@ -218,6 +248,7 @@ def environment(self) -> Environment:
218248
expiration_ts=expiration_ts,
219249
promoted_snapshot_ids=promoted_snapshot_ids,
220250
previous_finalized_snapshots=previous_finalized_snapshots,
251+
requirements=requirements,
221252
**self.environment_naming_info.dict(),
222253
)
223254

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def __init__(
152152
"catalog_name_override": exp.DataType.build("text"),
153153
"previous_finalized_snapshots": exp.DataType.build("text"),
154154
"normalize_name": exp.DataType.build("boolean"),
155+
"requirements": exp.DataType.build("text"),
155156
}
156157

157158
self._interval_columns_to_types = {
@@ -1711,6 +1712,7 @@ def _environment_to_df(environment: Environment) -> pd.DataFrame:
17111712
else None
17121713
),
17131714
"normalize_name": environment.normalize_name,
1715+
"requirements": json.dumps(environment.requirements),
17141716
}
17151717
]
17161718
)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Add requirements to environments table"""
2+
3+
from sqlglot import exp
4+
5+
6+
def migrate(state_sync, **kwargs):  # type: ignore
    """Add a text ``requirements`` column to the environments state table.

    Args:
        state_sync: Object exposing ``engine_adapter`` (used to execute the
            statement) and ``schema`` (optional qualifier for the table name).
        **kwargs: Accepted but not used by this migration.
    """
    adapter = state_sync.engine_adapter

    # Qualify the table with the state schema only when one is configured.
    table_name = "_environments"
    if state_sync.schema:
        table_name = f"{state_sync.schema}.{table_name}"

    requirements_column = exp.ColumnDef(
        this=exp.to_column("requirements"),
        kind=exp.DataType.build("text"),
    )
    add_requirements = exp.Alter(
        this=exp.to_table(table_name),
        kind="TABLE",
        actions=[requirements_column],
    )

    adapter.execute(add_requirements)

tests/core/test_plan.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
yesterday_ds,
4242
)
4343
from sqlmesh.utils.errors import PlanError
44+
from sqlmesh.utils.metaprogramming import Executable
4445

4546

4647
def test_forward_only_plan_sets_version(make_snapshot, mocker: MockerFixture):
@@ -2562,3 +2563,17 @@ def test_interval_end_per_model(make_snapshot):
25622563
is_dev=True,
25632564
)
25642565
assert plan_builder.build().interval_end_per_model is None
2566+
2567+
2568+
def test_plan_requirements():
    """A plan's environment lists the distributions behind the model's imports."""
    sushi_context = Context(paths="examples/sushi")
    items_model = sushi_context.get_model("sushi.items")

    # Inject two extra import executables into the model's python environment.
    extra_imports = {
        "ruamel": "import ruamel",
        "Image": "from ipywidgets.widgets.widget_media import Image",
    }
    for alias, statement in extra_imports.items():
        items_model.python_env[alias] = Executable(payload=statement, kind="import")

    dev_plan = sushi_context.plan("dev", no_prompts=True, skip_tests=True, skip_backfill=True)
    requirements = dev_plan.environment.requirements

    expected_distributions = {
        "ipywidgets",
        "numpy",
        "pandas",
        "ruamel.yaml",
        "ruamel.yaml.clib",
    }
    assert set(requirements) == expected_distributions

tests/schedulers/airflow/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot):
172172
],
173173
"suffix_target": "schema",
174174
"normalize_name": True,
175+
"requirements": {},
175176
},
176177
"no_gaps": False,
177178
"skip_backfill": False,

0 commit comments

Comments
 (0)