Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def state_sync(self) -> StateSync:

if self._state_sync.get_versions(validate=False).schema_version == 0:
self.console.log_status_update("Initializing new project state...")
self._state_sync.migrate(default_catalog=self.default_catalog)
self._state_sync.migrate()
self._state_sync.get_versions()
self._state_sync = CachingStateSync(self._state_sync) # type: ignore
return self._state_sync
Expand Down Expand Up @@ -2356,7 +2356,6 @@ def migrate(self) -> None:
self._load_materializations()
try:
self._new_state_sync().migrate(
default_catalog=self.default_catalog,
promoted_snapshots_only=self.config.migration.promoted_snapshots_only,
)
except Exception as e:
Expand Down
6 changes: 4 additions & 2 deletions sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,14 @@ def _schema_version_validator(cls, v: t.Any) -> int:
return 0 if v is None else int(v)


MIN_SCHEMA_VERSION = 60
MIN_SQLMESH_VERSION = "0.134.0"
MIGRATIONS = [
importlib.import_module(f"sqlmesh.migrations.{migration}")
for migration in sorted(info.name for info in pkgutil.iter_modules(migrations.__path__))
]
SCHEMA_VERSION: int = len(MIGRATIONS)
# -1 to account for the baseline script
SCHEMA_VERSION: int = MIN_SCHEMA_VERSION + len(MIGRATIONS) - 1


class PromotionResult(PydanticModel):
Expand Down Expand Up @@ -469,7 +472,6 @@ def compact_intervals(self) -> None:
@abc.abstractmethod
def migrate(
self,
default_catalog: t.Optional[str],
skip_backup: bool = False,
promoted_snapshots_only: bool = True,
) -> None:
Expand Down
6 changes: 0 additions & 6 deletions sqlmesh/core/state_sync/db/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from pathlib import Path
from datetime import datetime

from sqlglot import exp

from sqlmesh.core.console import Console, get_console
from sqlmesh.core.engine_adapter import EngineAdapter
Expand Down Expand Up @@ -90,7 +89,6 @@ def __init__(
console: t.Optional[Console] = None,
cache_dir: Path = Path(),
):
self.plan_dags_table = exp.table_("_plan_dags", db=schema)
self.interval_state = IntervalState(engine_adapter, schema=schema)
self.environment_state = EnvironmentState(engine_adapter, schema=schema)
self.snapshot_state = SnapshotState(engine_adapter, schema=schema, cache_dir=cache_dir)
Expand All @@ -101,7 +99,6 @@ def __init__(
snapshot_state=self.snapshot_state,
environment_state=self.environment_state,
interval_state=self.interval_state,
plan_dags_table=self.plan_dags_table,
console=console,
)
# Make sure that if an empty string is provided that we treat it as None
Expand Down Expand Up @@ -308,7 +305,6 @@ def remove_state(self, including_backup: bool = False) -> None:
self.environment_state.environments_table,
self.environment_state.environment_statements_table,
self.interval_state.intervals_table,
self.plan_dags_table,
self.version_state.versions_table,
):
self.engine_adapter.drop_table(table)
Expand Down Expand Up @@ -453,14 +449,12 @@ def close(self) -> None:
@transactional()
def migrate(
self,
default_catalog: t.Optional[str],
skip_backup: bool = False,
promoted_snapshots_only: bool = True,
) -> None:
"""Migrate the state sync to the latest SQLMesh / SQLGlot version."""
self.migrator.migrate(
self,
default_catalog,
skip_backup=skip_backup,
promoted_snapshots_only=promoted_snapshots_only,
)
Expand Down
31 changes: 19 additions & 12 deletions sqlmesh/core/state_sync/db/migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
)
from sqlmesh.core.state_sync.base import (
MIGRATIONS,
MIN_SCHEMA_VERSION,
MIN_SQLMESH_VERSION,
)
from sqlmesh.core.state_sync.base import StateSync
from sqlmesh.core.state_sync.db.environment import EnvironmentState
Expand All @@ -41,7 +43,7 @@
from sqlmesh.utils import major_minor
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import now_timestamp
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.errors import SQLMeshError, StateMigrationError

logger = logging.getLogger(__name__)

Expand All @@ -61,7 +63,6 @@ def __init__(
snapshot_state: SnapshotState,
environment_state: EnvironmentState,
interval_state: IntervalState,
plan_dags_table: TableName,
console: t.Optional[Console] = None,
):
self.engine_adapter = engine_adapter
Expand All @@ -70,7 +71,6 @@ def __init__(
self.snapshot_state = snapshot_state
self.environment_state = environment_state
self.interval_state = interval_state
self.plan_dags_table = plan_dags_table

self._state_tables = [
self.snapshot_state.snapshots_table,
Expand All @@ -79,15 +79,13 @@ def __init__(
]
self._optional_state_tables = [
self.interval_state.intervals_table,
self.plan_dags_table,
self.snapshot_state.auto_restatements_table,
self.environment_state.environment_statements_table,
]

def migrate(
self,
state_sync: StateSync,
default_catalog: t.Optional[str],
skip_backup: bool = False,
promoted_snapshots_only: bool = True,
) -> None:
Expand All @@ -96,15 +94,13 @@ def migrate(
migration_start_ts = time.perf_counter()

try:
migrate_rows = self._apply_migrations(state_sync, default_catalog, skip_backup)
migrate_rows = self._apply_migrations(state_sync, skip_backup)

if not migrate_rows and major_minor(SQLMESH_VERSION) == versions.minor_sqlmesh_version:
return

if migrate_rows:
self._migrate_rows(promoted_snapshots_only)
# Cleanup plan DAGs since we currently don't migrate snapshot records that are in there.
self.engine_adapter.delete_from(self.plan_dags_table, "TRUE")
self.version_state.update_versions()

analytics.collector.on_migration_end(
Expand All @@ -126,6 +122,8 @@ def migrate(
)

self.console.log_migration_status(success=False)
if isinstance(e, StateMigrationError):
raise
raise SQLMeshError("SQLMesh migration failed.") from e

self.console.log_migration_status()
Expand Down Expand Up @@ -156,11 +154,20 @@ def rollback(self) -> None:
def _apply_migrations(
self,
state_sync: StateSync,
default_catalog: t.Optional[str],
skip_backup: bool,
) -> bool:
versions = self.version_state.get_versions()
migrations = MIGRATIONS[versions.schema_version :]
first_script_index = 0
if versions.schema_version and versions.schema_version < MIN_SCHEMA_VERSION:
raise StateMigrationError(
"The current state belongs to an old version of SQLMesh that is no longer supported. "
f"Please upgrade to {MIN_SQLMESH_VERSION} first before upgrading to {SQLMESH_VERSION}."
)
elif versions.schema_version > 0:
# -1 to skip the baseline migration script
first_script_index = versions.schema_version - (MIN_SCHEMA_VERSION - 1)

migrations = MIGRATIONS[first_script_index:]
should_backup = any(
[
migrations,
Expand All @@ -177,10 +184,10 @@ def _apply_migrations(

for migration in migrations:
logger.info(f"Applying migration {migration}")
migration.migrate_schemas(state_sync, default_catalog=default_catalog)
migration.migrate_schemas(state_sync)
if state_table_exist:
# No need to run DML for the initial migration since all tables are empty
migration.migrate_rows(state_sync, default_catalog=default_catalog)
migration.migrate_rows(state_sync)

snapshot_count_after = self.snapshot_state.count()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
"""Readds indexes and primary keys in case tables were restored from a backup."""
"""The baseline migration script that sets up the initial state tables."""

from sqlglot import exp
from sqlmesh.utils import random_id
from sqlmesh.utils.migration import index_text_type
from sqlmesh.utils.migration import blob_text_type
from sqlmesh.utils.migration import blob_text_type, index_text_type


def migrate_schemas(state_sync, **kwargs): # type: ignore
schema = state_sync.schema
engine_adapter = state_sync.engine_adapter
if not engine_adapter.SUPPORTS_INDEXES:
return

intervals_table = "_intervals"
snapshots_table = "_snapshots"
environments_table = "_environments"
versions_table = "_versions"
if state_sync.schema:
engine_adapter.create_schema(schema)
intervals_table = f"{schema}.{intervals_table}"
snapshots_table = f"{schema}.{snapshots_table}"
environments_table = f"{schema}.{environments_table}"

table_suffix = random_id(short=True)
versions_table = f"{schema}.{versions_table}"

index_type = index_text_type(engine_adapter.dialect)
blob_type = blob_text_type(engine_adapter.dialect)

new_snapshots_table = f"{snapshots_table}__{table_suffix}"
snapshots_columns_to_types = {
"name": exp.DataType.build(index_type),
"identifier": exp.DataType.build(index_type),
Expand All @@ -38,7 +34,6 @@ def migrate_schemas(state_sync, **kwargs): # type: ignore
"unrestorable": exp.DataType.build("boolean"),
}

new_environments_table = f"{environments_table}__{table_suffix}"
environments_columns_to_types = {
"name": exp.DataType.build(index_type),
"snapshots": exp.DataType.build(blob_type),
Expand All @@ -53,9 +48,9 @@ def migrate_schemas(state_sync, **kwargs): # type: ignore
"catalog_name_override": exp.DataType.build("text"),
"previous_finalized_snapshots": exp.DataType.build(blob_type),
"normalize_name": exp.DataType.build("boolean"),
"requirements": exp.DataType.build(blob_type),
}

new_intervals_table = f"{intervals_table}__{table_suffix}"
intervals_columns_to_types = {
"id": exp.DataType.build(index_type),
"created_ts": exp.DataType.build("bigint"),
Expand All @@ -69,53 +64,34 @@ def migrate_schemas(state_sync, **kwargs): # type: ignore
"is_compacted": exp.DataType.build("boolean"),
}

# Recreate the snapshots table and its indexes.
engine_adapter.create_table(
new_snapshots_table, snapshots_columns_to_types, primary_key=("name", "identifier")
)
engine_adapter.create_index(
new_snapshots_table, "_snapshots_name_version_idx", ("name", "version")
)
engine_adapter.insert_append(
new_snapshots_table,
exp.select("*").from_(snapshots_table),
target_columns_to_types=snapshots_columns_to_types,
)
versions_columns_to_types = {
"schema_version": exp.DataType.build("int"),
"sqlglot_version": exp.DataType.build(index_type),
"sqlmesh_version": exp.DataType.build(index_type),
}

# Recreate the environments table and its indexes.
engine_adapter.create_table(
new_environments_table, environments_columns_to_types, primary_key=("name",)
)
engine_adapter.insert_append(
new_environments_table,
exp.select("*").from_(environments_table),
target_columns_to_types=environments_columns_to_types,
# Create the versions table.
engine_adapter.create_state_table(versions_table, versions_columns_to_types)

# Create the snapshots table and its indexes.
engine_adapter.create_state_table(
snapshots_table, snapshots_columns_to_types, primary_key=("name", "identifier")
)
engine_adapter.create_index(snapshots_table, "_snapshots_name_version_idx", ("name", "version"))

# Recreate the intervals table and its indexes.
engine_adapter.create_table(
new_intervals_table, intervals_columns_to_types, primary_key=("id",)
# Create the environments table and its indexes.
engine_adapter.create_state_table(
environments_table, environments_columns_to_types, primary_key=("name",)
)
engine_adapter.create_index(
new_intervals_table, "_intervals_name_identifier_idx", ("name", "identifier")

# Create the intervals table and its indexes.
engine_adapter.create_state_table(
intervals_table, intervals_columns_to_types, primary_key=("id",)
)
engine_adapter.create_index(
new_intervals_table, "_intervals_name_version_idx", ("name", "version")
intervals_table, "_intervals_name_identifier_idx", ("name", "identifier")
)
engine_adapter.insert_append(
new_intervals_table,
exp.select("*").from_(intervals_table),
target_columns_to_types=intervals_columns_to_types,
)

# Drop old tables.
for table in (snapshots_table, environments_table, intervals_table):
engine_adapter.drop_table(table)

# Replace old tables with new ones.
engine_adapter.rename_table(new_snapshots_table, snapshots_table)
engine_adapter.rename_table(new_environments_table, environments_table)
engine_adapter.rename_table(new_intervals_table, intervals_table)
engine_adapter.create_index(intervals_table, "_intervals_name_version_idx", ("name", "version"))


def migrate_rows(state_sync, **kwargs): # type: ignore
Expand Down
64 changes: 0 additions & 64 deletions sqlmesh/migrations/v0001_init.py

This file was deleted.

Loading