Skip to content

Commit 85f1d76

Browse files
committed
Remove time travel test for cloud engines, handle pyspark DFs in dbx
1 parent a000962 commit 85f1d76

File tree

2 files changed

+72
-52
lines changed

2 files changed

+72
-52
lines changed

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
SourceQuery,
1515
)
1616
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
17+
from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor
1718
from sqlmesh.core.node import IntervalUnit
1819
from sqlmesh.core.schema_diff import NestedSupport
1920
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
@@ -379,38 +380,59 @@ def _record_execution_stats(
379380
except:
380381
return
381382

382-
history = self.cursor.fetchall_arrow()
383-
if history.num_rows:
384-
history_df = history.to_pandas()
385-
write_df = history_df[history_df["operation"] == "WRITE"]
386-
write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
387-
if not write_df.empty:
388-
metrics = write_df["operationMetrics"][0]
389-
if metrics:
390-
rowcount = None
391-
rowcount_str = [
392-
metric[1] for metric in metrics if metric[0] == "numOutputRows"
393-
]
394-
if rowcount_str:
395-
try:
396-
rowcount = int(rowcount_str[0])
397-
except (TypeError, ValueError):
398-
pass
399-
400-
bytes_processed = None
401-
bytes_str = [
402-
metric[1] for metric in metrics if metric[0] == "numOutputBytes"
403-
]
404-
if bytes_str:
405-
try:
406-
bytes_processed = int(bytes_str[0])
407-
except (TypeError, ValueError):
408-
pass
409-
410-
if rowcount is not None or bytes_processed is not None:
411-
# if no rows were written, df contains 0 for bytes but no value for rows
412-
rowcount = (
413-
0 if rowcount is None and bytes_processed is not None else rowcount
414-
)
415-
416-
QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)
383+
history = (
384+
self.cursor.fetchdf()
385+
if isinstance(self.cursor, SparkSessionCursor)
386+
else self.cursor.fetchall_arrow()
387+
)
388+
if history is not None:
389+
from pandas import DataFrame as PandasDataFrame
390+
from pyspark.sql import DataFrame as PySparkDataFrame
391+
from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
392+
393+
history_df = None
394+
if isinstance(history, PandasDataFrame):
395+
history_df = history
396+
elif isinstance(history, (PySparkDataFrame, PySparkConnectDataFrame)):
397+
history_df = history.toPandas()
398+
else:
399+
# arrow table
400+
history_df = history.to_pandas()
401+
402+
if history_df is not None and not history_df.empty:
403+
write_df = history_df[history_df["operation"] == "WRITE"]
404+
write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
405+
if not write_df.empty:
406+
metrics = write_df["operationMetrics"][0]
407+
if metrics:
408+
rowcount = None
409+
rowcount_str = [
410+
metric[1] for metric in metrics if metric[0] == "numOutputRows"
411+
]
412+
if rowcount_str:
413+
try:
414+
rowcount = int(rowcount_str[0])
415+
except (TypeError, ValueError):
416+
pass
417+
418+
bytes_processed = None
419+
bytes_str = [
420+
metric[1] for metric in metrics if metric[0] == "numOutputBytes"
421+
]
422+
if bytes_str:
423+
try:
424+
bytes_processed = int(bytes_str[0])
425+
except (TypeError, ValueError):
426+
pass
427+
428+
if rowcount is not None or bytes_processed is not None:
429+
# if no rows were written, df contains 0 for bytes but no value for rows
430+
rowcount = (
431+
0
432+
if rowcount is None and bytes_processed is not None
433+
else rowcount
434+
)
435+
436+
QueryExecutionTracker.record_execution(
437+
sql, rowcount, bytes_processed
438+
)

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,23 +2470,21 @@ def capture_execution_stats(
24702470
assert actual_execution_stats["full_model"].total_bytes_processed is not None
24712471

24722472
# run that loads 0 rows in incremental model
2473-
actual_execution_stats = {}
2474-
with patch.object(
2475-
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
2476-
):
2477-
with time_machine.travel(date.today() + timedelta(days=1)):
2478-
context.run()
2479-
2480-
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2481-
assert actual_execution_stats["incremental_model"].total_rows_processed == 0
2482-
# snowflake doesn't track rows for CTAS
2483-
assert actual_execution_stats["full_model"].total_rows_processed == (
2484-
None if ctx.mark.startswith("snowflake") else 3
2485-
)
2486-
2487-
if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
2488-
assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
2489-
assert actual_execution_stats["full_model"].total_bytes_processed is not None
2473+
# - some cloud DBs error because time travel messes up token expiration
2474+
if not ctx.is_remote:
2475+
actual_execution_stats = {}
2476+
with patch.object(
2477+
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
2478+
):
2479+
with time_machine.travel(date.today() + timedelta(days=1)):
2480+
context.run()
2481+
2482+
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2483+
assert actual_execution_stats["incremental_model"].total_rows_processed == 0
2484+
# snowflake doesn't track rows for CTAS
2485+
assert actual_execution_stats["full_model"].total_rows_processed == (
2486+
None if ctx.mark.startswith("snowflake") else 3
2487+
)
24902488

24912489
# make and validate unmodified dev environment
24922490
no_change_plan: Plan = context.plan_builder(

0 commit comments

Comments
 (0)