Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1593,14 +1593,14 @@ def _get_data_objects(
tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = (
defaultdict(lambda: defaultdict(set))
)
snapshots_by_table_name: t.Dict[str, Snapshot] = {}
snapshots_by_table_name: t.Dict[exp.Table, t.Dict[str, Snapshot]] = defaultdict(dict)
for snapshot in target_snapshots:
if not snapshot.is_model or snapshot.is_symbolic:
continue
table = table_name_callable(snapshot)
table_schema = d.schema_(table.db, catalog=table.catalog)
tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name)
snapshots_by_table_name[table.name] = snapshot
snapshots_by_table_name[table_schema][table.name] = snapshot

def _get_data_objects_in_schema(
schema: exp.Table,
Expand All @@ -1613,23 +1613,25 @@ def _get_data_objects_in_schema(
)

with self.concurrent_context():
existing_objects: t.List[DataObject] = []
snapshot_id_to_obj: t.Dict[SnapshotId, DataObject] = {}
# A schema can be shared across multiple engines, so we need to group tables by both gateway and schema
for gateway, tables_by_schema in tables_by_gateway_and_schema.items():
objs_for_gateway = [
obj
for objs in concurrent_apply_to_values(
list(tables_by_schema),
lambda s: _get_data_objects_in_schema(
schema=s, object_names=tables_by_schema.get(s), gateway=gateway
),
self.ddl_concurrent_tasks,
)
for obj in objs
]
existing_objects.extend(objs_for_gateway)
schema_list = list(tables_by_schema.keys())
results = concurrent_apply_to_values(
schema_list,
lambda s: _get_data_objects_in_schema(
schema=s, object_names=tables_by_schema.get(s), gateway=gateway
),
self.ddl_concurrent_tasks,
)

for schema, objs in zip(schema_list, results):
snapshots_by_name = snapshots_by_table_name.get(schema, {})
for obj in objs:
if obj.name in snapshots_by_name:
snapshot_id_to_obj[snapshots_by_name[obj.name].snapshot_id] = obj

return {snapshots_by_table_name[obj.name].snapshot_id: obj for obj in existing_objects}
return snapshot_id_to_obj


def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> EvaluationStrategy:
Expand Down
16 changes: 8 additions & 8 deletions tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ def columns(table_name):
"get_data_objects",
return_value=[
DataObject(
schema="test_schema",
schema="sqlmesh__test_schema",
name=f"test_schema__test_model__{snapshot.version}",
type="table",
)
Expand Down Expand Up @@ -1500,7 +1500,7 @@ def test_migrate_view(
"sqlmesh.core.engine_adapter.base.EngineAdapter.get_data_objects",
return_value=[
DataObject(
schema="test_schema",
schema="sqlmesh__test_schema",
name=f"test_schema__test_model__{snapshot.version}",
type="view",
)
Expand Down Expand Up @@ -1950,7 +1950,7 @@ def columns(table_name):
"sqlmesh.core.engine_adapter.base.EngineAdapter.get_data_objects",
return_value=[
DataObject(
schema="test_schema",
schema="sqlmesh__test_schema",
name=f"test_schema__test_model__{snapshot.version}",
type=DataObjectType.TABLE,
)
Expand Down Expand Up @@ -2037,7 +2037,7 @@ def columns(table_name):
"sqlmesh.core.engine_adapter.base.EngineAdapter.get_data_objects",
return_value=[
DataObject(
schema="test_schema",
schema="sqlmesh__test_schema",
name=f"test_schema__test_model__{snapshot.version}",
type=DataObjectType.TABLE,
)
Expand Down Expand Up @@ -4016,7 +4016,7 @@ def test_migrate_snapshot(snapshot: Snapshot, mocker: MockerFixture, adapter_moc

adapter_mock.get_data_objects.return_value = [
DataObject(
schema="test_schema",
schema="sqlmesh__db",
name=f"db__model__{new_snapshot.version}",
type=DataObjectType.TABLE,
)
Expand Down Expand Up @@ -4154,7 +4154,7 @@ def test_migrate_managed(adapter_mock, make_snapshot, mocker: MockerFixture):

adapter_mock.get_data_objects.return_value = [
DataObject(
schema="test_schema",
schema="sqlmesh__test_schema",
name=f"test_schema__test_model__{snapshot.version}",
type=DataObjectType.MANAGED_TABLE,
)
Expand Down Expand Up @@ -4380,12 +4380,12 @@ def columns(table_name):
"sqlmesh.core.engine_adapter.base.EngineAdapter.get_data_objects",
return_value=[
DataObject(
schema="test_schema",
schema="sqlmesh__test_schema",
name=f"test_schema__test_model__{snapshot_1.version}",
type=DataObjectType.TABLE,
),
DataObject(
schema="test_schema",
schema="sqlmesh__test_schema",
name=f"test_schema__test_model_2__{snapshot_2.version}",
type=DataObjectType.TABLE,
),
Expand Down