Skip to content

Commit d76bc0f

Browse files
authored
Fix!: Add the missing migration script for the gateway variable (#2433)
1 parent c2872f6 commit d76bc0f

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Move the gateway variable."""
2+
3+
import ast
4+
import json
5+
6+
import pandas as pd
7+
from sqlglot import exp
8+
9+
from sqlmesh.utils.migration import index_text_type
10+
11+
12+
def migrate(state_sync, **kwargs): # type: ignore
13+
engine_adapter = state_sync.engine_adapter
14+
schema = state_sync.schema
15+
snapshots_table = "_snapshots"
16+
if schema:
17+
snapshots_table = f"{schema}.{snapshots_table}"
18+
19+
migration_needed = False
20+
new_snapshots = []
21+
22+
for name, identifier, version, snapshot, kind_name, expiration_ts in engine_adapter.fetchall(
23+
exp.select("name", "identifier", "version", "snapshot", "kind_name", "expiration_ts").from_(
24+
snapshots_table
25+
),
26+
quote_identifiers=True,
27+
):
28+
parsed_snapshot = json.loads(snapshot)
29+
python_env = parsed_snapshot["node"].get("python_env")
30+
if python_env:
31+
gateway = python_env.pop("gateway", None)
32+
if gateway is not None:
33+
migration_needed = True
34+
sqlmesh_vars = {"gateway": ast.literal_eval(gateway["payload"])}
35+
python_env["__sqlmesh__vars__"] = {
36+
"payload": repr(sqlmesh_vars),
37+
"kind": "value",
38+
}
39+
40+
new_snapshots.append(
41+
{
42+
"name": name,
43+
"identifier": identifier,
44+
"version": version,
45+
"snapshot": json.dumps(parsed_snapshot),
46+
"kind_name": kind_name,
47+
"expiration_ts": expiration_ts,
48+
}
49+
)
50+
51+
if migration_needed and new_snapshots:
52+
engine_adapter.delete_from(snapshots_table, "TRUE")
53+
54+
index_type = index_text_type(engine_adapter.dialect)
55+
56+
engine_adapter.insert_append(
57+
snapshots_table,
58+
pd.DataFrame(new_snapshots),
59+
columns_to_types={
60+
"name": exp.DataType.build(index_type),
61+
"identifier": exp.DataType.build(index_type),
62+
"version": exp.DataType.build(index_type),
63+
"snapshot": exp.DataType.build("text"),
64+
"kind_name": exp.DataType.build(index_type),
65+
"expiration_ts": exp.DataType.build("bigint"),
66+
},
67+
)

0 commit comments

Comments
 (0)