Skip to content

Commit e68fd6c

Browse files
committed
fix: bigquery snowflake source columns support
1 parent 34dc9fd commit e68fd6c

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,18 @@ def _df_to_source_queries(
169169
)
170170

171171
def query_factory() -> Query:
172-
if bigframes_pd and isinstance(df, bigframes_pd.DataFrame):
173-
df.to_gbq(
172+
ordered_df = df[list(source_columns_to_types)]
173+
if bigframes_pd and isinstance(ordered_df, bigframes_pd.DataFrame):
174+
ordered_df.to_gbq(
174175
f"{temp_bq_table.project}.{temp_bq_table.dataset_id}.{temp_bq_table.table_id}",
175176
if_exists="replace",
176177
)
177178
elif not self.table_exists(temp_table):
178179
# Make mypy happy
179-
assert isinstance(df, pd.DataFrame)
180+
assert isinstance(ordered_df, pd.DataFrame)
180181
self._db_call(self.client.create_table, table=temp_bq_table, exists_ok=False)
181182
result = self.__load_pandas_to_table(
182-
temp_bq_table, df, source_columns_to_types, replace=False
183+
temp_bq_table, ordered_df, source_columns_to_types, replace=False
183184
)
184185
if result.errors:
185186
raise SQLMeshError(result.errors)

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ def query_factory() -> Query:
354354
else None
355355
)
356356

357+
ordered_df = df[list(source_columns_to_types)]
358+
357359
if is_snowpark_dataframe:
358360
temp_table.set("catalog", database)
359361

@@ -362,11 +364,11 @@ def query_factory() -> Query:
362364
# then they will be quoted already. But if the Snowpark dataframe was created manually by the user, then the
363365
# columns may not be quoted
364366
columns_already_quoted = all(
365-
col.startswith('"') and col.endswith('"') for col in df.columns
367+
col.startswith('"') and col.endswith('"') for col in ordered_df.columns
366368
)
367-
local_df = df
369+
local_df = ordered_df
368370
if not columns_already_quoted:
369-
local_df = df.rename(
371+
local_df = ordered_df.rename(
370372
{
371373
col: exp.to_identifier(col).sql(dialect=self.dialect, identify=True)
372374
for col in source_columns_to_types
@@ -375,7 +377,7 @@ def query_factory() -> Query:
375377
local_df.createOrReplaceTempView(
376378
temp_table.sql(dialect=self.dialect, identify=True)
377379
) # type: ignore
378-
elif isinstance(df, pd.DataFrame):
380+
elif isinstance(ordered_df, pd.DataFrame):
379381
from snowflake.connector.pandas_tools import write_pandas
380382

381383
# Workaround for https://github.com/snowflakedb/snowflake-connector-python/issues/1034
@@ -388,16 +390,16 @@ def query_factory() -> Query:
388390

389391
# See: https://stackoverflow.com/a/75627721
390392
for column, kind in source_columns_to_types.items():
391-
if is_datetime64_any_dtype(df.dtypes[column]):
393+
if is_datetime64_any_dtype(ordered_df.dtypes[column]):
392394
if kind.is_type("date"): # type: ignore
393-
df[column] = pd.to_datetime(df[column]).dt.date # type: ignore
394-
elif getattr(df.dtypes[column], "tz", None) is not None: # type: ignore
395-
df[column] = pd.to_datetime(df[column]).dt.strftime(
395+
ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.date # type: ignore
396+
elif getattr(ordered_df.dtypes[column], "tz", None) is not None: # type: ignore
397+
ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.strftime(
396398
"%Y-%m-%d %H:%M:%S.%f%z"
397399
) # type: ignore
398400
# https://github.com/snowflakedb/snowflake-connector-python/issues/1677
399401
else: # type: ignore
400-
df[column] = pd.to_datetime(df[column]).dt.strftime(
402+
ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.strftime(
401403
"%Y-%m-%d %H:%M:%S.%f"
402404
) # type: ignore
403405

@@ -407,7 +409,7 @@ def query_factory() -> Query:
407409

408410
write_pandas(
409411
self._connection_pool.get(),
410-
df,
412+
ordered_df,
411413
temp_table.name,
412414
schema=temp_table.db or None,
413415
database=database.sql(dialect=self.dialect) if database else None,
@@ -417,7 +419,7 @@ def query_factory() -> Query:
417419
)
418420
else:
419421
raise SQLMeshError(
420-
f"Unknown dataframe type: {type(df)} for {target_table}. Expecting pandas or snowpark."
422+
f"Unknown dataframe type: {type(ordered_df)} for {target_table}. Expecting pandas or snowpark."
421423
)
422424

423425
return exp.select(

tests/core/engine_adapter/integration/config.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,6 @@ gateways:
154154
inttest_bigquery:
155155
connection:
156156
type: bigquery
157-
method: service-account
158-
keyfile: {{ env_var('BIGQUERY_KEYFILE') }}
159-
check_import: false
160157
state_connection:
161158
type: duckdb
162159

0 commit comments

Comments
 (0)