diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
index 6151a43797..4622a694b7 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanExec.scala
@@ -85,12 +85,8 @@ case class CometScanExec(
 
   private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty
 
-  /**
-   * Send the driver-side metrics. Before calling this function, selectedPartitions has been
-   * initialized. See SPARK-26327 for more details.
-   */
-  private def sendDriverMetrics(): Unit = {
-    driverMetrics.foreach(e => metrics(e._1).add(e._2))
+  @transient private lazy val setDriverMetrics: Unit = {
+    driverMetrics.foreach(e => metrics(e._1).set(e._2))
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     SQLMetrics.postDriverMetricUpdates(
       sparkContext,
@@ -98,6 +94,12 @@
       metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq)
   }
 
+  /**
+   * Send the driver-side metrics. Before calling this function, selectedPartitions has been
+   * initialized. See SPARK-26327 for more details.
+   */
+  private def sendDriverMetrics(): Unit = setDriverMetrics
+
   private def isDynamicPruningFilter(e: Expression): Boolean =
     e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index aff1816265..60497126ff 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -2173,6 +2173,38 @@
     }
   }
 
+  test("Native_datafusion reports correct files and bytes scanned") {
+    withTempDir { dir =>
+      val path = new java.io.File(dir, "test_metrics").getAbsolutePath
+      spark.range(100).repartition(2).write.mode("overwrite").parquet(path)
+
+      withSQLConf(
+        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true") {
+        val df = spark.read.parquet(path)
+
+        // Trigger two different actions to ensure metrics are not duplicated
+        df.count()
+        df.collect()
+
+        val scanNode = stripAQEPlan(df.queryExecution.executedPlan)
+          .collectFirst {
+            case n: org.apache.spark.sql.comet.CometNativeScanExec => n
+            case n: org.apache.spark.sql.comet.CometScanExec => n
+          }
+          .getOrElse {
+            fail(
+              s"Comet scan node not found in the physical plan. Plan: \n${df.queryExecution.executedPlan}")
+          }
+
+        val numFiles = scanNode.metrics("numFiles").value
+        assert(
+          numFiles == 2,
+          s"Expected exactly 2 files to be scanned, but got metrics reporting $numFiles")
+      }
+    }
+  }
+
 }
 
 case class BucketedTableTestSpec(