Skip to content

Commit d2c4025

Browse files
authored
Fix!: Snowflake adapter (#2870)
1 parent 7c1ed83 commit d2c4025

File tree

6 files changed

+224
-39
lines changed

6 files changed

+224
-39
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2055,13 +2055,15 @@ def _get_data_objects(
20552055
"""
20562056
raise NotImplementedError()
20572057

2058-
def _get_temp_table(self, table: TableName, table_only: bool = False) -> exp.Table:
2058+
def _get_temp_table(
2059+
self, table: TableName, table_only: bool = False, quoted: bool = True
2060+
) -> exp.Table:
20592061
"""
20602062
Returns the name of the temp table that should be used for the given table name.
20612063
"""
20622064
table = t.cast(exp.Table, exp.to_table(table).copy())
20632065
table.set(
2064-
"this", exp.to_identifier(f"__temp_{table.name}_{random_id(short=True)}", quoted=True)
2066+
"this", exp.to_identifier(f"__temp_{table.name}_{random_id(short=True)}", quoted=quoted)
20652067
)
20662068

20672069
if table_only:

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ def _df_to_source_queries(
200200
batch_size: int,
201201
target_table: TableName,
202202
) -> t.List[SourceQuery]:
203-
temp_table = self._get_temp_table(target_table or "pandas")
203+
temp_table = self._get_temp_table(
204+
target_table or "pandas", quoted=False
205+
) # write_pandas() re-quotes everything without checking if it's already quoted
204206

205207
def query_factory() -> Query:
206208
if snowpark and isinstance(df, snowpark.dataframe.DataFrame):
@@ -211,10 +213,10 @@ def query_factory() -> Query:
211213
# Workaround for https://github.com/snowflakedb/snowflake-connector-python/issues/1034
212214
# The above issue has already been fixed upstream, but we keep the following
213215
# line anyway in order to support a wider range of Snowflake versions.
214-
schema = f'"{temp_table.db}"'
216+
schema = temp_table.db
215217
if temp_table.catalog:
216-
schema = f'"{temp_table.catalog}".{schema}'
217-
self.cursor.execute(f"USE SCHEMA {schema}")
218+
schema = f"{temp_table.catalog}.{schema}"
219+
self.set_current_schema(schema)
218220

219221
# See: https://stackoverflow.com/a/75627721
220222
for column, kind in columns_to_types.items():
@@ -240,10 +242,14 @@ def query_factory() -> Query:
240242
df,
241243
temp_table.name,
242244
schema=temp_table.db or None,
243-
database=temp_table.catalog or None,
245+
database=normalize_identifiers(temp_table.catalog, dialect=self.dialect).sql(
246+
dialect=self.dialect
247+
)
248+
if temp_table.catalog
249+
else None,
244250
chunk_size=self.DEFAULT_BATCH_SIZE,
245251
overwrite=True,
246-
table_type="temp", # if you dont have this, it will convert the table we created above into a normal table and it wont get dropped when the session ends
252+
table_type="temp",
247253
)
248254
else:
249255
raise SQLMeshError(
@@ -252,7 +258,13 @@ def query_factory() -> Query:
252258

253259
return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table)
254260

255-
return [SourceQuery(query_factory=query_factory)]
261+
# the cleanup_func technically isn't needed because the temp table gets dropped when the session ends
262+
# but boy does it make our multi-adapter integration tests easier to write
263+
return [
264+
SourceQuery(
265+
query_factory=query_factory, cleanup_func=lambda: self.drop_table(temp_table)
266+
)
267+
]
256268

257269
def _fetch_native_df(
258270
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
@@ -280,6 +292,7 @@ def _get_data_objects(
280292

281293
schema = to_schema(schema_name)
282294
catalog_name = schema.catalog or self.get_current_catalog()
295+
283296
query = (
284297
exp.select(
285298
exp.column("TABLE_CATALOG").as_("catalog"),
@@ -308,6 +321,8 @@ def _get_data_objects(
308321
)
309322
.from_(exp.table_("TABLES", db="INFORMATION_SCHEMA", catalog=catalog_name))
310323
.where(exp.column("TABLE_SCHEMA").eq(schema.db))
324+
# Snowflake seems to have delayed internal metadata updates and will sometimes return duplicates
325+
.distinct()
311326
)
312327
if object_names:
313328
query = query.where(exp.column("TABLE_NAME").isin(*object_names))
@@ -328,10 +343,49 @@ def _get_data_objects(
328343
def set_current_catalog(self, catalog: str) -> None:
329344
self.execute(exp.Use(this=exp.to_identifier(catalog)))
330345

346+
def set_current_schema(self, schema: str) -> None:
347+
self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema)))
348+
349+
def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str:
350+
# note: important to use self._default_catalog instead of the self.default_catalog property
351+
# otherwise we get RecursionError: maximum recursion depth exceeded
352+
# because it calls get_current_catalog(), which executes a query, which needs the default catalog, which calls get_current_catalog()... etc
353+
if self._default_catalog:
354+
# the purpose of this function is to identify instances where the default catalog is being used
355+
# (so that we can replace it with the actual catalog as specified in the gateway)
356+
#
357+
# we can't do a direct string comparison because the catalog value on the model
358+
# gets changed when it's normalized as part of generating `model.fqn`
359+
def unquote_and_lower(identifier: str) -> str:
360+
return exp.parse_identifier(identifier).name.lower()
361+
362+
default_catalog_unquoted = unquote_and_lower(self._default_catalog)
363+
default_catalog_normalized = normalize_identifiers(
364+
self._default_catalog, dialect=self.dialect
365+
)
366+
367+
def catalog_rewriter(node: exp.Expression) -> exp.Expression:
368+
if isinstance(node, exp.Table):
369+
if node.catalog:
370+
# only replace the catalog on the model with the target catalog if the two are functionally equivalent
371+
if unquote_and_lower(node.catalog) == default_catalog_unquoted:
372+
node.set("catalog", default_catalog_normalized)
373+
elif isinstance(node, exp.Use) and isinstance(node.this, exp.Identifier):
374+
if unquote_and_lower(node.this.output_name) == default_catalog_unquoted:
375+
node.set("this", default_catalog_normalized)
376+
return node
377+
378+
# Rewrite whatever default catalog is present on the query to be compatible with what the user supplied in the
379+
# Snowflake connection config. This is because the catalog present on the model gets normalized and quoted to match
380+
# the source dialect, which isn't always compatible with Snowflake
381+
expression = expression.transform(catalog_rewriter)
382+
383+
return super()._to_sql(expression=expression, quote=quote, **kwargs)
384+
331385
def _build_create_comment_column_exp(
332386
self, table: exp.Table, column_name: str, column_comment: str, table_kind: str = "TABLE"
333387
) -> exp.Comment | str:
334-
table_sql = table.sql(dialect=self.dialect, identify=True)
388+
table_sql = self._to_sql(table) # so that catalog replacement happens
335389
column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True)
336390

337391
truncated_comment = self._truncate_column_comment(column_comment)

sqlmesh/core/engine_adapter/spark.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,7 @@ def _ensure_pyspark_df(
284284
return self.spark.createDataFrame(df, **kwargs) # type: ignore
285285

286286
def _get_temp_table(
287-
self,
288-
table: TableName,
289-
table_only: bool = False,
287+
self, table: TableName, table_only: bool = False, quoted: bool = True
290288
) -> exp.Table:
291289
"""
292290
Returns the name of the temp table that should be used for the given table name.

tests/conftest.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
from sqlglot import exp, maybe_parse, parse_one
1818
from sqlglot.dialects.dialect import DialectType
1919
from sqlglot.helper import ensure_list
20+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
2021

2122
from sqlmesh.core.config import DuckDBConnectionConfig
2223
from sqlmesh.core.context import Context
2324
from sqlmesh.core.engine_adapter import SparkEngineAdapter
2425
from sqlmesh.core.engine_adapter.base import EngineAdapter
26+
from sqlmesh.core.environment import EnvironmentNamingInfo
2527
from sqlmesh.core.macros import macro
2628
from sqlmesh.core.model import IncrementalByTimeRangeKind, SqlModel, model
2729
from sqlmesh.core.model.kind import OnDestructiveChange
@@ -126,7 +128,28 @@ def validate(
126128
*,
127129
env_name: t.Optional[str] = None,
128130
dialect: t.Optional[str] = None,
131+
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
129132
) -> t.Dict[t.Any, t.Any]:
133+
if (
134+
env_name
135+
and dialect
136+
and environment_naming_info
137+
and environment_naming_info.normalize_name
138+
):
139+
# if the environment_naming_info was configured to normalize names, then Snapshot.qualified_view_name.table_for_environment()
140+
# returns schemas that contain the environment_name normalised for that engine
141+
#
142+
# in practice, this means "test_prod" becomes "TEST_PROD" on some engines so the final views are named like:
143+
#
144+
# "sushi__TEST_PROD"."waiter_as_customer_by_day"
145+
#
146+
# instead of:
147+
#
148+
# "sushi__test_prod"."waiter_as_customer_by_day"
149+
#
150+
# this matters for reading the data back below to validate it
151+
env_name = normalize_identifiers(env_name, dialect=dialect).name
152+
130153
"""
131154
Both start and end are inclusive.
132155
"""

tests/core/engine_adapter/test_integration.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pandas as pd
1212
import pytest
1313
from sqlglot import exp, parse_one
14+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1415

1516
from sqlmesh import Config, Context, EngineAdapter
1617
from sqlmesh.cli.example_project import init_example_project
@@ -51,6 +52,7 @@ def __init__(
5152
self.gateway = gateway
5253
self._columns_to_types = columns_to_types
5354
self.test_id = random_id(short=True)
55+
self._context = None
5456

5557
@property
5658
def columns_to_types(self):
@@ -411,11 +413,14 @@ def create_context(
411413
self._context = Context(paths=".", config=config, gateway=self.gateway)
412414
return self._context
413415

414-
def cleanup(self, ctx: Context):
415-
schemas = []
416-
for _, model in ctx.models.items():
417-
schemas.append(model.schema_name)
418-
schemas.append(model.physical_schema)
416+
def cleanup(self, ctx: t.Optional[Context] = None):
417+
schemas = [self.schema(TEST_SCHEMA)]
418+
419+
ctx = ctx or self._context
420+
if ctx and ctx.models:
421+
for _, model in ctx.models.items():
422+
schemas.append(model.schema_name)
423+
schemas.append(model.physical_schema)
419424

420425
for schema_name in set(schemas):
421426
self.engine_adapter.drop_schema(
@@ -662,6 +667,14 @@ def ctx(engine_adapter, test_type, mark_gateway):
662667
return TestContext(test_type, engine_adapter, gateway)
663668

664669

670+
@pytest.fixture(autouse=True)
671+
def cleanup(ctx: TestContext):
672+
yield # run test
673+
674+
if ctx:
675+
ctx.cleanup()
676+
677+
665678
def test_catalog_operations(ctx: TestContext):
666679
if (
667680
ctx.engine_adapter.CATALOG_SUPPORT.is_unsupported
@@ -691,11 +704,11 @@ def test_catalog_operations(ctx: TestContext):
691704
ctx.engine_adapter.execute(f'CREATE DATABASE IF NOT EXISTS "{catalog_name}"')
692705
except Exception:
693706
pass
694-
current_catalog = ctx.engine_adapter.get_current_catalog()
707+
current_catalog = ctx.engine_adapter.get_current_catalog().lower()
695708
ctx.engine_adapter.set_current_catalog(catalog_name)
696-
assert ctx.engine_adapter.get_current_catalog() == catalog_name
709+
assert ctx.engine_adapter.get_current_catalog().lower() == catalog_name
697710
ctx.engine_adapter.set_current_catalog(current_catalog)
698-
assert ctx.engine_adapter.get_current_catalog() == current_catalog
711+
assert ctx.engine_adapter.get_current_catalog().lower() == current_catalog
699712

700713

701714
def test_drop_schema_catalog(ctx: TestContext, caplog):
@@ -782,21 +795,14 @@ def test_temp_table(ctx: TestContext):
782795
)
783796
table = ctx.table("example")
784797

785-
# The snowflake adapter persists the DataFrame to an intermediate table because we use the `write_pandas()` function from the Snowflake python library
786-
# Other adapters just use SQLGlot to convert the dataframe directly into a SELECT query
787-
expected_tables = 2 if ctx.dialect == "snowflake" and ctx.test_type == "df" else 1
788798
with ctx.engine_adapter.temp_table(ctx.input_data(input_data), table.sql()) as table_name:
789799
results = ctx.get_metadata_results()
790800
assert len(results.views) == 0
791-
assert len(results.tables) == expected_tables
801+
assert len(results.tables) == 1
792802
assert len(results.non_temp_tables) == 0
793803
assert len(results.materialized_views) == 0
794804
ctx.compare_with_current(table_name, input_data)
795805

796-
if ctx.dialect == "snowflake":
797-
# force the next query to create a new connection to prove temp tables have been dropped
798-
ctx.engine_adapter._connection_pool.close()
799-
800806
results = ctx.get_metadata_results()
801807
assert len(results.views) == len(results.tables) == len(results.non_temp_tables) == 0
802808

@@ -1735,6 +1741,14 @@ def test_sushi(mark_gateway: t.Tuple[str, str], ctx: TestContext):
17351741
personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()],
17361742
)
17371743
_, gateway = mark_gateway
1744+
1745+
# clear cache from prior runs
1746+
cache_dir = pathlib.Path("./examples/sushi/.cache")
1747+
if cache_dir.exists():
1748+
import shutil
1749+
1750+
shutil.rmtree(cache_dir)
1751+
17381752
context = Context(paths="./examples/sushi", config=config, gateway=gateway)
17391753

17401754
# clean up any leftover schemas from previous runs (requires context)
@@ -1769,7 +1783,7 @@ def test_sushi(mark_gateway: t.Tuple[str, str], ctx: TestContext):
17691783

17701784
context._models.update({cust_rev_by_day_key: cust_rev_by_day_model_tbl_props})
17711785

1772-
context.plan(
1786+
plan: Plan = context.plan(
17731787
environment="test_prod",
17741788
start=start,
17751789
end=end,
@@ -1785,6 +1799,7 @@ def test_sushi(mark_gateway: t.Tuple[str, str], ctx: TestContext):
17851799
yesterday(),
17861800
env_name="test_prod",
17871801
dialect=ctx.dialect,
1802+
environment_naming_info=plan.environment_naming_info,
17881803
)
17891804

17901805
# Ensure table and column comments were correctly registered with engine
@@ -1977,10 +1992,13 @@ def validate_no_comments(
19771992
# confirm physical temp table comments are not registered
19781993
validate_no_comments("sqlmesh__sushi", table_name_suffix="__temp", check_temp_tables=True)
19791994
# confirm view layer comments are not registered in non-PROD environment
1980-
validate_no_comments("sushi__test_prod", is_physical_layer=False)
1995+
env_name = "test_prod"
1996+
if plan.environment_naming_info and plan.environment_naming_info.normalize_name:
1997+
env_name = normalize_identifiers(env_name, dialect=ctx.dialect).name
1998+
validate_no_comments(f"sushi__{env_name}", is_physical_layer=False)
19811999

19822000
# Ensure that the plan has been applied successfully.
1983-
no_change_plan = context.plan(
2001+
no_change_plan: Plan = context.plan(
19842002
environment="test_dev",
19852003
start=start,
19862004
end=end,
@@ -2000,6 +2018,7 @@ def validate_no_comments(
20002018
yesterday(),
20012019
env_name="test_dev",
20022020
dialect=ctx.dialect,
2021+
environment_naming_info=no_change_plan.environment_naming_info,
20032022
)
20042023

20052024
# confirm view layer comments are registered in PROD
@@ -2051,7 +2070,7 @@ def test_init_project(ctx: TestContext, mark_gateway: t.Tuple[str, str], tmp_pat
20512070
assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 6
20522071

20532072
# make and validate unmodified dev environment
2054-
no_change_plan = context.plan(
2073+
no_change_plan: Plan = context.plan(
20552074
environment="test_dev",
20562075
skip_tests=True,
20572076
no_prompts=True,
@@ -2062,7 +2081,12 @@ def test_init_project(ctx: TestContext, mark_gateway: t.Tuple[str, str], tmp_pat
20622081

20632082
context.apply(no_change_plan)
20642083

2065-
dev_schema_results = ctx.get_metadata_results("sqlmesh_example__test_dev")
2084+
environment = no_change_plan.environment
2085+
first_snapshot = no_change_plan.environment.snapshots[0]
2086+
schema_name = first_snapshot.qualified_view_name.schema_for_environment(
2087+
environment, dialect=ctx.dialect
2088+
)
2089+
dev_schema_results = ctx.get_metadata_results(schema_name)
20662090
assert sorted(dev_schema_results.views) == [
20672091
"full_model",
20682092
"incremental_model",
@@ -2234,6 +2258,7 @@ def _mutate_config(current_gateway_name: str, config: Config):
22342258
connection.concurrent_tasks = 1
22352259

22362260
context = ctx.create_context(_mutate_config)
2261+
assert context.default_dialect == "duckdb"
22372262

22382263
schema = ctx.schema(TEST_SCHEMA)
22392264
seed_query = ctx.input_data(
@@ -2278,13 +2303,13 @@ def _mutate_config(current_gateway_name: str, config: Config):
22782303
try:
22792304
context.plan(auto_apply=True, no_prompts=True)
22802305

2281-
results = ctx.get_metadata_results(schema)
2306+
test_model = context.get_model(f"{schema}.test_model")
2307+
normalized_schema_name = test_model.fully_qualified_table.db
2308+
results = ctx.get_metadata_results(normalized_schema_name)
22822309
assert "test_model" in results.views
22832310

22842311
actual_df = (
2285-
ctx.get_current_data(f"{schema}.test_model")
2286-
.sort_values(by="event_date")
2287-
.reset_index(drop=True)
2312+
ctx.get_current_data(test_model.fqn).sort_values(by="event_date").reset_index(drop=True)
22882313
)
22892314
actual_df["event_date"] = actual_df["event_date"].astype(str)
22902315
assert actual_df.count()[0] == 3

0 commit comments

Comments
 (0)