Skip to content

Commit 0ac5726

Browse files
authored
Fix: Stop leaking temporary tables in the Snowflake adapter (#2865)
1 parent 0057719 commit 0ac5726

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,10 @@ def query_factory() -> Query:
230230
df[column] = pd.to_datetime(df[column]).dt.strftime(
231231
"%Y-%m-%d %H:%M:%S.%f"
232232
) # type: ignore
233-
self.create_table(temp_table, columns_to_types)
233+
234+
# create the table first using our usual method ensure the column datatypes match what we parsed with sqlglot
235+
# otherwise we would be trusting `write_pandas()` from the snowflake lib to do this correctly
236+
self.create_table(temp_table, columns_to_types, table_kind="TEMPORARY TABLE")
234237

235238
write_pandas(
236239
self._connection_pool.get(),
@@ -240,7 +243,7 @@ def query_factory() -> Query:
240243
database=temp_table.catalog or None,
241244
chunk_size=self.DEFAULT_BATCH_SIZE,
242245
overwrite=True,
243-
table_type="temp",
246+
table_type="temp", # if you dont have this, it will convert the table we created above into a normal table and it wont get dropped when the session ends
244247
)
245248
else:
246249
raise SQLMeshError(
@@ -292,6 +295,7 @@ def _get_data_objects(
292295
)
293296
.when(exp.column("TABLE_TYPE").eq("BASE TABLE"), exp.Literal.string("TABLE"))
294297
.when(exp.column("TABLE_TYPE").eq("TEMPORARY TABLE"), exp.Literal.string("TABLE"))
298+
.when(exp.column("TABLE_TYPE").eq("LOCAL TEMPORARY"), exp.Literal.string("TABLE"))
295299
.when(exp.column("TABLE_TYPE").eq("EXTERNAL TABLE"), exp.Literal.string("TABLE"))
296300
.when(exp.column("TABLE_TYPE").eq("EVENT TABLE"), exp.Literal.string("TABLE"))
297301
.when(exp.column("TABLE_TYPE").eq("VIEW"), exp.Literal.string("VIEW"))

tests/core/engine_adapter/test_integration.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,13 +781,22 @@ def test_temp_table(ctx: TestContext):
781781
]
782782
)
783783
table = ctx.table("example")
784+
785+
# The snowflake adapter persists the DataFrame to an intermediate table because we use the `write_pandas()` function from the Snowflake python library
786+
# Other adapters just use SQLGlot to convert the dataframe directly into a SELECT query
787+
expected_tables = 2 if ctx.dialect == "snowflake" and ctx.test_type == "df" else 1
784788
with ctx.engine_adapter.temp_table(ctx.input_data(input_data), table.sql()) as table_name:
785789
results = ctx.get_metadata_results()
786790
assert len(results.views) == 0
787-
assert len(results.tables) == 1
791+
assert len(results.tables) == expected_tables
788792
assert len(results.non_temp_tables) == 0
789793
assert len(results.materialized_views) == 0
790794
ctx.compare_with_current(table_name, input_data)
795+
796+
if ctx.dialect == "snowflake":
797+
# force the next query to create a new connection to prove temp tables have been dropped
798+
ctx.engine_adapter._connection_pool.close()
799+
791800
results = ctx.get_metadata_results()
792801
assert len(results.views) == len(results.tables) == len(results.non_temp_tables) == 0
793802

0 commit comments

Comments
 (0)