diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala
index 78ea0f0168..69baf97204 100644
--- a/spark/src/main/scala/org/apache/comet/CometConf.scala
+++ b/spark/src/main/scala/org/apache/comet/CometConf.scala
@@ -226,6 +226,19 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_CACHE_SERIALIZER_ENABLED: ConfigEntry[Boolean] =
+    conf("spark.comet.cache.serializer.enabled")
+      .category(CATEGORY_EXEC)
+      .doc(
+        "When enabled, Comet installs a CachedBatchSerializer that stores Spark's in-memory " +
+          "table cache as compressed Arrow IPC. Repeated scans of cached data are then read " +
+          "natively without a per-read conversion. Schemas Comet cannot handle transparently " +
+          "fall back to Spark's default cache serializer. This sets " +
+          "spark.sql.cache.serializer for the session unless that property is already set " +
+          "to a non-default value. Disabled by default.")
+      .booleanConf
+      .createWithDefault(false)
+
   val COMET_EXEC_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.enabled")
     .category(CATEGORY_EXEC)
     .doc(
diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala
index 7290ab436a..26d77c4f94 100644
--- a/spark/src/main/scala/org/apache/spark/Plugins.scala
+++ b/spark/src/main/scala/org/apache/spark/Plugins.scala
@@ -28,7 +28,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD_FACTOR}
 import org.apache.spark.sql.internal.StaticSQLConf
 
-import org.apache.comet.CometConf.{COMET_METRICS_ENABLED, COMET_ONHEAP_ENABLED}
+import org.apache.comet.CometConf.{COMET_CACHE_SERIALIZER_ENABLED, COMET_METRICS_ENABLED, COMET_ONHEAP_ENABLED}
 import org.apache.comet.CometSparkSessionExtensions
 
 /**
@@ -57,6 +57,9 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
     // register CometSparkSessionExtensions if it isn't already registered
     CometDriverPlugin.registerCometSessionExtension(sc.conf)
 
+    // Install the Comet cache serializer if requested
+    CometDriverPlugin.setCacheSerializerIfEnabled(sc.conf)
+
     // Register Comet metrics
     CometDriverPlugin.registerCometMetrics(sc)
 
@@ -104,6 +107,38 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl
 }
 
 object CometDriverPlugin extends Logging {
+  private[spark] val COMET_CACHE_SERIALIZER =
+    "org.apache.spark.sql.comet.CometCachedBatchSerializer"
+
+  /**
+   * If the Comet cache serializer is enabled, install it as Spark's cache serializer. This is a
+   * static SQL config, so it must be set on the SparkConf before the session is created. A
+   * user-provided non-default serializer is respected and not overridden; a warning is logged
+   * in that case so the ignored Comet setting does not go unnoticed.
+   *
+   * @param conf
+   *   the SparkConf that is read and, when the serializer is installed, updated in place
+   */
+  private[spark] def setCacheSerializerIfEnabled(conf: SparkConf): Unit = {
+    if (conf.getBoolean(COMET_CACHE_SERIALIZER_ENABLED.key, defaultValue = false)) {
+      val key = StaticSQLConf.SPARK_CACHE_SERIALIZER.key
+      val default = StaticSQLConf.SPARK_CACHE_SERIALIZER.defaultValueString
+      conf.get(key, default) match {
+        case `default` =>
+          logInfo(s"Setting $key=$COMET_CACHE_SERIALIZER")
+          conf.set(key, COMET_CACHE_SERIALIZER)
+        case COMET_CACHE_SERIALIZER =>
+          // The Comet serializer was already configured explicitly: nothing to do.
+          ()
+        case other =>
+          // Without this warning the enabled flag would silently have no effect.
+          logWarning(
+            s"${COMET_CACHE_SERIALIZER_ENABLED.key} is enabled but $key is already set to " +
+              s"$other; the Comet cache serializer will not be installed")
+      }
+    }
+  }
+
   def registerCometMetrics(sc: SparkContext): Unit = {
     if (sc.getConf.getBoolean(
         COMET_METRICS_ENABLED.key,
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala
new file mode 100644
index 0000000000..e0eec845a4
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometCachedBatchSerializer.scala
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import java.nio.ByteBuffer + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow} +import org.apache.spark.sql.columnar.{CachedBatch, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer} +import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters +import org.apache.spark.sql.comet.util.{Utils => CometUtils} +import org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.ByteArray +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.io.ChunkedByteBuffer + +import org.apache.comet.CometConf + +/** + * A cached batch holding one compressed Arrow IPC message plus Spark-format column stats. 
+ *
+ * @param numRows
+ *   number of rows in this batch
+ * @param bytes
+ *   compressed Arrow IPC bytes for a single record batch
+ * @param stats
+ *   InternalRow laid out as ColumnStatisticsSchema expects: per column [lowerBound, upperBound,
+ *   nullCount, count, sizeInBytes]
+ */
+case class CometCachedBatch(numRows: Int, bytes: Array[Byte], stats: InternalRow)
+    extends SimpleMetricsCachedBatch {
+  // Used by InMemoryRelation to estimate the cached relation size; must reflect real bytes.
+  override def sizeInBytes: Long = bytes.length.toLong
+}
+
+/**
+ * Accumulates per-column min/max/null/count for a set of rows and emits the stats InternalRow in
+ * the exact layout Spark's ColumnStatisticsSchema / SimpleMetricsCachedBatchSerializer expects.
+ *
+ * The lower/upper bounds of a column stay null when no ordering is implemented here for its data
+ * type, and also when every value seen for the column was null.
+ *
+ * NOTE(review): a null bound is not by itself a "cannot prune" marker. A bound comparison built
+ * by Spark's stats filter evaluates to null against a null bound, and a null filter result is
+ * presumably treated as "skip this batch". That is harmless for an all-null column (no value
+ * can satisfy a comparison), but for an untracked data type it is only safe if Spark never
+ * builds bound comparisons for that type (believed to hold for BinaryType literals). Confirm
+ * this before letting the serializer accept further data types that are not tracked here.
+ *
+ * @param attributes
+ *   the cached attributes, in the column order used for the `ordinal` arguments below
+ */
+class CometCacheColumnStats(attributes: Seq[Attribute]) {
+  private val numCols = attributes.length
+  private val lower = new Array[Any](numCols)
+  private val upper = new Array[Any](numCols)
+  private val nulls = new Array[Int](numCols)
+  // Number of `update` calls per column; used for `count` when setRowCount was never called.
+  private val seen = new Array[Int](numCols)
+  // Explicitly provided row count; a negative value means "not set".
+  private var rowCount = -1
+  private val tracksBounds: Array[Boolean] = attributes.map(a => ordered(a.dataType)).toArray
+
+  /**
+   * Update column `ordinal` with one value. `value` is in Catalyst internal form (or null).
+   * The value is retained by reference when it becomes a bound, so the caller must pass a value
+   * that stays valid for the lifetime of this accumulator.
+   */
+  def update(ordinal: Int, dt: DataType, isNull: Boolean, value: Any): Unit = {
+    seen(ordinal) += 1
+    if (isNull) {
+      nulls(ordinal) += 1
+    } else if (tracksBounds(ordinal)) {
+      // Bounds stay null for data types without an ordering implemented in `compare`.
+      if (lower(ordinal) == null || compare(dt, value, lower(ordinal)) < 0) lower(ordinal) = value
+      if (upper(ordinal) == null || compare(dt, value, upper(ordinal)) > 0) upper(ordinal) = value
+    }
+  }
+
+  /**
+   * Sets the total row count for this batch (the `count` stat field). When this is never called,
+   * `count` falls back to the number of `update` calls observed for each column, so a non-empty
+   * batch is not reported with `count` 0 (which could let predicates like IsNotNull prune it).
+   */
+  def setRowCount(n: Int): Unit = rowCount = n
+
+  /** Emits the accumulated stats as one row with five fields per column, in column order. */
+  def toInternalRow: InternalRow = {
+    val values = new Array[Any](numCols * 5)
+    var i = 0
+    while (i < numCols) {
+      val base = i * 5
+      values(base) = lower(i) // lowerBound (column data type or null)
+      values(base + 1) = upper(i) // upperBound
+      values(base + 2) = nulls(i) // nullCount (Int)
+      values(base + 3) = if (rowCount >= 0) rowCount else seen(i) // count (Int)
+      values(base + 4) = 0L // sizeInBytes (Long); not used by buildFilter
+      i += 1
+    }
+    new GenericInternalRow(values)
+  }
+
+  // Data types for which `compare` implements an ordering; must stay in sync with `compare`.
+  private def ordered(dt: DataType): Boolean = dt match {
+    case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
+        _: DecimalType | StringType | DateType | TimestampType | TimestampNTZType =>
+      true
+    case _ => false
+  }
+
+  // Compares two non-null values of data type `dt` given in Catalyst internal form.
+  private def compare(dt: DataType, x: Any, y: Any): Int = dt match {
+    case BooleanType =>
+      java.lang.Boolean.compare(x.asInstanceOf[Boolean], y.asInstanceOf[Boolean])
+    case ByteType => java.lang.Byte.compare(x.asInstanceOf[Byte], y.asInstanceOf[Byte])
+    case ShortType => java.lang.Short.compare(x.asInstanceOf[Short], y.asInstanceOf[Short])
+    case IntegerType | DateType =>
+      java.lang.Integer.compare(x.asInstanceOf[Int], y.asInstanceOf[Int])
+    case LongType | TimestampType | TimestampNTZType =>
+      java.lang.Long.compare(x.asInstanceOf[Long], y.asInstanceOf[Long])
+    case FloatType => java.lang.Float.compare(x.asInstanceOf[Float], y.asInstanceOf[Float])
+    case DoubleType => java.lang.Double.compare(x.asInstanceOf[Double], y.asInstanceOf[Double])
+    case _: DecimalType =>
+      x.asInstanceOf[Decimal].compare(y.asInstanceOf[Decimal])
+    case StringType =>
+      ByteArray.compareBinary(
+        x.asInstanceOf[UTF8String].getBytes,
+        y.asInstanceOf[UTF8String].getBytes)
+    case other => throw new IllegalStateException(s"compare called for unordered type $other")
+  }
+}
+
+class CometCachedBatchSerializer extends SimpleMetricsCachedBatchSerializer {
+
+  // Delegate target for schemas Comet does not handle. Serializable (no-arg constructor).
+  private val fallback = new DefaultCachedBatchSerializer
+
+  /** Comet handles flat schemas of the data types its Arrow conversion supports. */
+  private def isCometSchema(dataTypes: Seq[DataType]): Boolean =
+    dataTypes.forall(isCometType)
+
+  private def isCometType(dt: DataType): Boolean = dt match {
+    case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
+        _: DecimalType | StringType | BinaryType | DateType | TimestampType | TimestampNTZType =>
+      true
+    // Nested/complex types are out of scope for v1; delegate to the default serializer.
+    case _ => false
+  }
+
+  // Force the row build path for Comet schemas (single code path for encode + stats); delegate
+  // otherwise so the default serializer's columnar-input optimization still applies.
+  override def supportsColumnarInput(schema: Seq[Attribute]): Boolean =
+    !isCometSchema(schema.map(_.dataType)) && fallback.supportsColumnarInput(schema)
+
+  // Comet schemas are always read back as columnar batches; other schemas follow the default.
+  override def supportsColumnarOutput(schema: StructType): Boolean =
+    isCometSchema(schema.map(_.dataType)) || fallback.supportsColumnarOutput(schema)
+
+  // Let Spark use generic ColumnVector access; our columns are heterogeneous CometVector subtypes.
+  override def vectorTypes(attributes: Seq[Attribute], conf: SQLConf): Option[Seq[String]] = None
+
+  // Build the StructType that mirrors `attrs` (name, data type, nullability, metadata).
+  private def toStructType(attrs: Seq[Attribute]): StructType =
+    StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
+
+  // Compute stats from an already-built Arrow ColumnarBatch (columns are CometVector).
+  private def computeStats(batch: ColumnarBatch, attrs: Seq[Attribute]): InternalRow = {
+    val rowCount = batch.numRows()
+    val collector = new CometCacheColumnStats(attrs)
+    // Column-major walk: feed every row of one column to the collector, then move on.
+    for (ordinal <- attrs.indices) {
+      val dataType = attrs(ordinal).dataType
+      val vector = batch.column(ordinal)
+      for (row <- 0 until rowCount) {
+        val missing = vector.isNullAt(row)
+        val value: Any = if (missing) null else readValue(vector, dataType, row)
+        collector.update(ordinal, dataType, missing, value)
+      }
+    }
+    collector.setRowCount(rowCount)
+    collector.toInternalRow
+  }
+
+  // Fetch the value at row `r` of `col` in Catalyst internal form, dispatching on `dt`.
+  private def readValue(col: ColumnVector, dt: DataType, r: Int): Any = dt match {
+    // copy(): presumably so a retained bound does not alias the batch's buffers (TODO confirm).
+    case StringType => col.getUTF8String(r).copy()
+    case decimal: DecimalType => col.getDecimal(r, decimal.precision, decimal.scale)
+    case DoubleType => col.getDouble(r)
+    case FloatType => col.getFloat(r)
+    case LongType | TimestampType | TimestampNTZType => col.getLong(r)
+    case IntegerType | DateType => col.getInt(r)
+    case ShortType => col.getShort(r)
+    case ByteType => col.getByte(r)
+    case BooleanType => col.getBoolean(r)
+    case _ => null // BinaryType etc.: no stats bounds
+  }
+
+  // INVARIANT: compute stats BEFORE calling this. serializeBatches internally clears the
+  // VectorSchemaRoot wrapping the batch's field vectors, so the batch must not be read after
+  // this call. The row/columnar Arrow iterators reset those vectors before producing the next
+  // batch, so the clear is safe as long as we never touch this batch again.
+  // Serialize a single batch and return its bytes as one array. The batch must not be read
+  // again once this returns.
+  private def encodeBytes(batch: ColumnarBatch): Array[Byte] = {
+    val it = CometUtils.serializeBatches(Iterator.single(batch))
+    val (_, cbb) = it.next()
+    cbb.toArray
+  }
+
+  // Turn each Arrow batch into one CometCachedBatch (row count + encoded bytes + column stats).
+  private def encode(
+      arrowBatches: Iterator[ColumnarBatch],
+      attrs: Seq[Attribute]): Iterator[CachedBatch] =
+    arrowBatches.map { batch =>
+      // Everything that reads the batch (row count, stats) happens before encodeBytes, so the
+      // batch is never touched after it has been serialized.
+      val numRows = batch.numRows()
+      val stats = computeStats(batch, attrs)
+      val bytes = encodeBytes(batch)
+      CometCachedBatch(numRows, bytes, stats)
+    }
+
+  // Row build path: non-Comet schemas are delegated to the default serializer; Comet schemas
+  // are converted to Arrow batches of at most COMET_BATCH_SIZE rows and encoded with stats.
+  override def convertInternalRowToCachedBatch(
+      input: RDD[InternalRow],
+      schema: Seq[Attribute],
+      storageLevel: StorageLevel,
+      conf: SQLConf): RDD[CachedBatch] = {
+    if (!isCometSchema(schema.map(_.dataType))) {
+      fallback.convertInternalRowToCachedBatch(input, schema, storageLevel, conf)
+    } else {
+      val structType = toStructType(schema)
+      val maxRecords = CometConf.COMET_BATCH_SIZE.get(conf).toLong
+      input.mapPartitions { rowIter =>
+        val ctx = TaskContext.get()
+        val arrowBatches =
+          CometArrowConverters.rowToArrowBatchIter(rowIter, structType, maxRecords, "UTC", ctx)
+        encode(arrowBatches, schema)
+      }
+    }
+  }
+
+  override def convertColumnarBatchToCachedBatch(
+      input: RDD[ColumnarBatch],
+      schema: Seq[Attribute],
+      storageLevel: StorageLevel,
+      conf: SQLConf): RDD[CachedBatch] = {
+    if (!isCometSchema(schema.map(_.dataType))) {
+      fallback.convertColumnarBatchToCachedBatch(input, schema, storageLevel, conf)
+    } else {
+      // Only Comet schemas reach this point. supportsColumnarInput returns false for them, so
+      // Spark is expected to build the cache through the row path above; this columnar path is
+      // kept as a defensive implementation in case it is invoked directly.
+      val structType = toStructType(schema)
+      val maxRecords = CometConf.COMET_BATCH_SIZE.get(conf)
+      input.mapPartitions { batchIter =>
+        val ctx = TaskContext.get()
+        val arrowBatches = batchIter.flatMap { b =>
+          CometArrowConverters.columnarBatchToArrowBatchIter(b, structType, maxRecords, "UTC", ctx)
+        }
+        encode(arrowBatches, schema)
+      }
+    }
+  }
+
+  // Map selected attributes to their column indices within cacheAttributes by exprId.
+  // Throws IllegalStateException when a selected attribute is not among the cached attributes.
+  private def selectedIndices(
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute]): Array[Int] = {
+    val byId = cacheAttributes.map(_.exprId).zipWithIndex.toMap
+    selectedAttributes.map { a =>
+      byId.getOrElse(
+        a.exprId,
+        throw new IllegalStateException(
+          s"Selected attribute $a (exprId ${a.exprId}) not found in cached attributes"))
+    }.toArray
+  }
+
+  // True when `indices` selects every column in order: length == numCols and indices(i) == i.
+  private def isIdentityProjection(indices: Array[Int], numCols: Int): Boolean =
+    indices.length == numCols && indices.indices.forall(i => indices(i) == i)
+
+  // Decode one CometCachedBatch into a ColumnarBatch projected to the selected columns.
+  // NOTE(review): a projected batch shares its column vectors with the decoded batch; whether
+  // the unselected columns are released depends on decodeBatches' batch lifecycle - confirm.
+  private def decodeOne(b: CometCachedBatch, indices: Array[Int]): Iterator[ColumnarBatch] = {
+    val chunked = new ChunkedByteBuffer(ByteBuffer.wrap(b.bytes))
+    CometUtils.decodeBatches(chunked, "CometCachedBatch").map { full =>
+      if (isIdentityProjection(indices, full.numCols())) {
+        // All columns requested in their stored order: hand the decoded batch on unchanged.
+        full
+      } else {
+        val cols = indices.map(full.column)
+        new ColumnarBatch(cols, full.numRows())
+      }
+    }
+  }
+
+  // Version-safe conversion of a ColumnarBatch's java row iterator to copied Scala InternalRows.
+  private def rowsOf(batch: ColumnarBatch): Iterator[InternalRow] = {
+    val javaRows = batch.rowIterator()
+    new Iterator[InternalRow] {
+      override def hasNext: Boolean = javaRows.hasNext
+      // Each row is copied before it is handed out.
+      override def next(): InternalRow = javaRows.next().copy()
+    }
+  }
+
+  // Decode every cached batch of `input` projected to `indices`; any batch that is not a
+  // CometCachedBatch is rejected with an IllegalStateException.
+  private def decodeCometBatches(
+      input: RDD[CachedBatch],
+      indices: Array[Int]): RDD[ColumnarBatch] =
+    input.mapPartitions { cachedBatches =>
+      cachedBatches.flatMap { cached =>
+        cached match {
+          case cometBatch: CometCachedBatch => decodeOne(cometBatch, indices)
+          case unexpected =>
+            throw new IllegalStateException(
+              s"Expected CometCachedBatch but got ${unexpected.getClass.getName}")
+        }
+      }
+    }
+
+  // Columnar read path: Comet schemas are decoded from Arrow IPC, all others are delegated.
+  override def convertCachedBatchToColumnarBatch(
+      input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute],
+      conf: SQLConf): RDD[ColumnarBatch] = {
+    if (isCometSchema(cacheAttributes.map(_.dataType))) {
+      decodeCometBatches(input, selectedIndices(cacheAttributes, selectedAttributes))
+    } else {
+      fallback.convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf)
+    }
+  }
+
+  // Row read path: same decoding as the columnar path, then each batch is expanded into rows.
+  override def convertCachedBatchToInternalRow(
+      input: RDD[CachedBatch],
+      cacheAttributes: Seq[Attribute],
+      selectedAttributes: Seq[Attribute],
+      conf: SQLConf): RDD[InternalRow] = {
+    if (isCometSchema(cacheAttributes.map(_.dataType))) {
+      val projection = selectedIndices(cacheAttributes, selectedAttributes)
+      decodeCometBatches(input, projection).flatMap(batch => rowsOf(batch))
+    } else {
+      fallback.convertCachedBatchToInternalRow(input, cacheAttributes, selectedAttributes, conf)
+    }
+  }
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
index efe6a97d40..f03ba18d36 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.comet.{CometConf,
DataTypeSupport} import org.apache.comet.serde.OperatorOuterClass import org.apache.comet.serde.operator.CometSink +import org.apache.comet.vector.CometVector case class CometSparkToColumnarExec(child: SparkPlan) extends RowToColumnarTransition @@ -67,7 +68,10 @@ case class CometSparkToColumnarExec(child: SparkPlan) "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"), "conversionTime" -> SQLMetrics.createNanoTimingMetric( sparkContext, - "time converting Spark batches to Arrow batches")) + "time converting Spark batches to Arrow batches"), + "numPassthroughBatches" -> SQLMetrics.createMetric( + sparkContext, + "number of passthrough Arrow batches")) // The conversion happens in next(), so wrap the call to measure time spent. private def createTimingIter( @@ -96,6 +100,7 @@ case class CometSparkToColumnarExec(child: SparkPlan) val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val conversionTime = longMetric("conversionTime") + val numPassthroughBatches = longMetric("numPassthroughBatches") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) // Use UTC for Arrow schema timezone to match the native side, which always // deserializes Timestamp as Timestamp(Microsecond, Some("UTC")). Spark's internal @@ -111,13 +116,19 @@ case class CometSparkToColumnarExec(child: SparkPlan) .mapPartitionsInternal { sparkBatches => val arrowBatches = sparkBatches.flatMap { sparkBatch => - val context = TaskContext.get() - CometArrowConverters.columnarBatchToArrowBatchIter( - sparkBatch, - schema, - maxRecordsPerBatch, - timeZoneId, - context) + if (isAllCometVectors(sparkBatch)) { + // Already Arrow (e.g. from CometCachedBatchSerializer): pass through, no copy. 
+                numPassthroughBatches += 1
+                Iterator.single(sparkBatch)
+              } else {
+                val context = TaskContext.get()
+                CometArrowConverters.columnarBatchToArrowBatchIter(
+                  sparkBatch,
+                  schema,
+                  maxRecordsPerBatch,
+                  timeZoneId,
+                  context)
+              }
             }
           createTimingIter(arrowBatches, numInputRows, numOutputBatches, conversionTime)
         }
@@ -141,6 +152,16 @@ case class CometSparkToColumnarExec(child: SparkPlan)
   override protected def withNewChildInternal(newChild: SparkPlan): CometSparkToColumnarExec =
     copy(child = newChild)
 
+  /**
+   * True when the batch has at least one column and every column is a CometVector, i.e. the
+   * batch can be handed on as-is. A batch with no columns is never treated as passthrough, so
+   * it still goes through the regular conversion.
+   */
+  private def isAllCometVectors(batch: ColumnarBatch): Boolean = {
+    val numCols = batch.numCols()
+    numCols > 0 && (0 until numCols).forall(i => batch.column(i).isInstanceOf[CometVector])
+  }
+
 }
 
 object CometSparkToColumnarExec extends CometSink[SparkPlan] with DataTypeSupport {
diff --git a/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala
new file mode 100644
index 0000000000..e3c75f4e68
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometCachedBatchSerializerSuite.scala
@@ -0,0 +1,390 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.comet + +import java.time.LocalDateTime + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.comet.{CometCacheColumnStats, CometCachedBatch, CometCachedBatchSerializer, CometSparkToColumnarExec} +import org.apache.spark.sql.types._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String + +import org.apache.comet.CometConf + +class CometCachedBatchSerializerSuite extends CometTestBase { + + import testImplicits._ + + // spark.sql.cache.serializer is a STATIC SQL config (cannot be set via withSQLConf at + // runtime), so it must be set at session creation. This makes the whole suite use the Comet + // cache serializer; the pure-unit tests construct a serializer directly and are unaffected. + override protected def sparkConf: org.apache.spark.SparkConf = { + super.sparkConf + .set("spark.sql.cache.serializer", "org.apache.spark.sql.comet.CometCachedBatchSerializer") + } + + test("stats row has 5 fields per column in cachedAttributes order") { + val a = AttributeReference("a", IntegerType, nullable = true)() + val b = AttributeReference("b", StringType, nullable = true)() + val acc = new CometCacheColumnStats(Seq(a, b)) + // column 0: values 5, null, 3 ; column 1: "y", "a", null + acc.update(0, IntegerType, isNull = false, 5) + acc.update(0, IntegerType, isNull = true, null) + acc.update(0, IntegerType, isNull = false, 3) + acc.update(1, StringType, isNull = false, UTF8String.fromString("y")) + acc.update(1, StringType, isNull = false, UTF8String.fromString("a")) + acc.update(1, StringType, isNull = true, null) + acc.setRowCount(3) + val stats = acc.toInternalRow + + assert(stats.numFields == 10) // 5 fields * 2 columns + // column 0: [lower=3, upper=5, nullCount=1, count=3, sizeInBytes=0] + assert(stats.getInt(0) == 3) + assert(stats.getInt(1) == 5) + assert(stats.getInt(2) == 1) + assert(stats.getInt(3) == 3) + 
// column 1: [lower="a", upper="y", nullCount=1, count=3, sizeInBytes=0] + assert(stats.getUTF8String(5) == UTF8String.fromString("a")) + assert(stats.getUTF8String(6) == UTF8String.fromString("y")) + assert(stats.getInt(7) == 1) + assert(stats.getInt(8) == 3) + // sizeInBytes stat slots (positions 4 and 9) are 0L; they are not used by buildFilter + assert(stats.getLong(4) == 0L) + assert(stats.getLong(9) == 0L) + // CometCachedBatch.sizeInBytes reflects the IPC byte length + val cb = CometCachedBatch(numRows = 3, bytes = Array[Byte](1, 2, 3, 4, 5), stats = stats) + assert(cb.sizeInBytes == 5L) + assert(cb.numRows == 3) + } + + test("supportsColumnarOutput: true for flat supported schema, delegated for nested") { + val ser = new CometCachedBatchSerializer + val flat = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))) + val nested = StructType(Seq(StructField("a", ArrayType(IntegerType)))) + assert(ser.supportsColumnarOutput(flat)) + // nested delegates to DefaultCachedBatchSerializer, which does not support columnar output + assert(!ser.supportsColumnarOutput(nested)) + } + + test("build path produces one CometCachedBatch per Arrow batch with stats") { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "100") { + val ser = new CometCachedBatchSerializer + // coalesce(1) makes the batch chunking deterministic: 250 rows / 100 = 3 batches + val df = spark.range(250).coalesce(1).selectExpr("id", "cast(id as string) as s") + val attrs = df.queryExecution.analyzed.output + val rdd = df.queryExecution.toRdd + val cached = ser + .convertInternalRowToCachedBatch( + rdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + .collect() + assert(cached.length == 3) + assert(cached.forall(_.isInstanceOf[CometCachedBatch])) + assert(cached.map(_.numRows).sum == 250) + cached.foreach { b => + assert(b.sizeInBytes > 0) + assert(b.asInstanceOf[CometCachedBatch].stats.numFields == attrs.length * 5) + } + // column 0 is the 
bigint id; verify real (non-null) stats were computed + val statRows = cached.map(_.asInstanceOf[CometCachedBatch].stats) + // lowerBound of col 0 lives at field 0 (LongType); min across batches must be 0 + assert(statRows.map(_.getLong(0)).min == 0L) + // nullCount of col 0 lives at field 2; range() has no nulls + assert(statRows.forall(_.getInt(2) == 0)) + } + } + + test("round-trip: build then decode all columns matches input") { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "64") { + val ser = new CometCachedBatchSerializer + val df = spark.range(200).coalesce(1).selectExpr("id", "cast(id * 2 as string) as s") + val attrs = df.queryExecution.analyzed.output + val cached = ser.convertInternalRowToCachedBatch( + df.queryExecution.toRdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + val decodedRows = ser + .convertCachedBatchToInternalRow(cached, attrs, attrs, spark.sessionState.conf) + .map(r => (r.getLong(0), r.getUTF8String(1).toString)) + .collect() + .toSet + val expected = (0 until 200).map(i => (i.toLong, (i * 2).toString)).toSet + assert(decodedRows == expected) + } + } + + test("read path prunes to selected columns") { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "64") { + val ser = new CometCachedBatchSerializer + val df = spark.range(200).coalesce(1).selectExpr("id", "cast(id * 2 as string) as s") + val attrs = df.queryExecution.analyzed.output + val cached = ser.convertInternalRowToCachedBatch( + df.queryExecution.toRdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + // select only the string column (index 1) + val onlyS = Seq(attrs(1)) + val pruned = ser + .convertCachedBatchToInternalRow(cached, attrs, onlyS, spark.sessionState.conf) + .map(_.getUTF8String(0).toString) + .collect() + .toSet + assert(pruned == (0 until 200).map(i => (i * 2).toString).toSet) + } + } + + test("columnar read path: full and pruned projection") { + 
withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "64") { + val ser = new CometCachedBatchSerializer + val df = spark.range(100).coalesce(1).selectExpr("id", "cast(id * 2 as string) as s") + val attrs = df.queryExecution.analyzed.output + val cached = ser.convertInternalRowToCachedBatch( + df.queryExecution.toRdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + + // Full projection (identity passthrough): 2 columns, values match. + val fullColCounts = + ser + .convertCachedBatchToColumnarBatch(cached, attrs, attrs, spark.sessionState.conf) + .map(_.numCols()) + .collect() + assert(fullColCounts.forall(_ == 2)) + val fullVals = + ser + .convertCachedBatchToColumnarBatch(cached, attrs, attrs, spark.sessionState.conf) + .mapPartitions { batches => + batches.flatMap { b => + val rows = new scala.collection.mutable.ArrayBuffer[(Long, String)] + var i = 0 + while (i < b.numRows()) { + rows += ((b.column(0).getLong(i), b.column(1).getUTF8String(i).toString)) + i += 1 + } + rows.iterator + } + } + .collect() + .toSet + assert(fullVals == (0 until 100).map(i => (i.toLong, (i * 2).toString)).toSet) + + // Pruned projection: only the string column (index 1) -> 1 column, correct values. 
+ val onlyS = Seq(attrs(1)) + val prunedColCounts = + ser + .convertCachedBatchToColumnarBatch(cached, attrs, onlyS, spark.sessionState.conf) + .map(_.numCols()) + .collect() + assert(prunedColCounts.forall(_ == 1)) + val prunedVals = + ser + .convertCachedBatchToColumnarBatch(cached, attrs, onlyS, spark.sessionState.conf) + .mapPartitions { batches => + batches.flatMap { b => + val rows = new scala.collection.mutable.ArrayBuffer[String] + var i = 0 + while (i < b.numRows()) { + rows += b.column(0).getUTF8String(i).toString + i += 1 + } + rows.iterator + } + } + .collect() + .toSet + assert(prunedVals == (0 until 100).map(i => (i * 2).toString).toSet) + } + } + + test("cached scan passes already-Arrow batches through CometSparkToColumnarExec") { + withSQLConf( + org.apache.spark.sql.internal.SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { // jvm shuffle keeps a CometSparkToColumnarExec in the plan over the cached scan + spark + .range(1000) + .selectExpr("id as key", "id % 8 as value") + .createOrReplaceTempView("comet_cache_c1") + spark.catalog.cacheTable("comet_cache_c1") + try { + // groupBy forces a CometSparkToColumnarExec to appear above the cached InMemoryTableScan. + val df = spark.sql("SELECT value, count(*) FROM comet_cache_c1 GROUP BY value") + val rows = df.collect() + assert(rows.length == 8) + val s2c = collectFirst(df.queryExecution.executedPlan) { + case s: CometSparkToColumnarExec => s + } + // CometSparkToColumnarExec must appear above the cached scan and must have taken the + // passthrough fast-path (batches already Arrow, no re-copy needed). 
+ assert(s2c.isDefined, "expected CometSparkToColumnarExec in plan over cached scan") + assert(s2c.get.metrics("numPassthroughBatches").value > 0L) + } finally { + spark.catalog.uncacheTable("comet_cache_c1") + } + } + } + + test("cached query result matches uncached") { + val base = spark + .range(2000) + .selectExpr("id as k", "id % 10 as v", "cast(id as string) as s") + val expected = + base + .groupBy("v") + .count() + .orderBy("v") + .collect() + .toSeq + .map(r => (r.getLong(0), r.getLong(1))) + base.createOrReplaceTempView("comet_cache_t8") + spark.catalog.cacheTable("comet_cache_t8") + try { + val df = spark.sql("SELECT v, count(*) AS c FROM comet_cache_t8 GROUP BY v ORDER BY v") + checkSparkAnswer(df) + val actual = df.collect().toSeq.map(r => (r.getLong(0), r.getLong(1))) + assert(actual == expected) + } finally { + spark.catalog.uncacheTable("comet_cache_t8") + } + } + + test("filtered cached scan returns correct rows with stats pruning") { + spark.range(5000).selectExpr("id as k").createOrReplaceTempView("comet_cache_t8f") + spark.catalog.cacheTable("comet_cache_t8f") + try { + val df = spark.sql("SELECT k FROM comet_cache_t8f WHERE k >= 4990 ORDER BY k") + checkSparkAnswer(df) + val actual = df.collect().map(_.getLong(0)).toSeq + assert(actual == (4990L until 5000L).toSeq) + } finally { + spark.catalog.uncacheTable("comet_cache_t8f") + } + } + + test("cached table with MEMORY_AND_DISK round-trips") { + val cachedDf = spark + .range(3000) + .selectExpr("id as k", "cast(id as string) as s") + .persist(StorageLevel.MEMORY_AND_DISK) + try { + assert(cachedDf.count() == 3000) + checkSparkAnswer(cachedDf.filter("k % 2 = 0")) + } finally { + cachedDf.unpersist() + } + } + + test("array-typed cached relation delegates to default serializer and is correct") { + val df0 = spark.range(100).selectExpr("id as k", "array(id, id + 1) as a") + df0.createOrReplaceTempView("comet_cache_t8a") + spark.catalog.cacheTable("comet_cache_t8a") + try { + val df = 
spark.sql("SELECT k, a FROM comet_cache_t8a WHERE k < 5 ORDER BY k") + checkSparkAnswer(df) + assert(df.count() == 5) + } finally { + spark.catalog.uncacheTable("comet_cache_t8a") + } + } + + test("string column stats survive encode (no buffer use-after-free)") { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "100") { + val ser = new CometCachedBatchSerializer + // zero-padded so lexicographic order is well-defined and stable + val df = spark + .range(250) + .coalesce(1) + .selectExpr("id", "lpad(cast(id as string), 5, '0') as s") + val attrs = df.queryExecution.analyzed.output + val cached = ser + .convertInternalRowToCachedBatch( + df.queryExecution.toRdd, + attrs, + org.apache.spark.storage.StorageLevel.MEMORY_ONLY, + spark.sessionState.conf) + .collect() + assert(cached.length == 3) + // column 1 is the string column; its stats live at fields [5..9]: + // field 5 = lowerBound, field 6 = upperBound + cached.zipWithIndex.foreach { case (b, batchIdx) => + val stats = b.asInstanceOf[CometCachedBatch].stats + val lo = stats.getUTF8String(5).toString + val hi = stats.getUTF8String(6).toString + val start = batchIdx * 100 + val end = math.min(start + 100, 250) - 1 + assert( + lo == f"$start%05d", + s"batch $batchIdx lowerBound was '$lo', expected ${f"$start%05d"}") + assert( + hi == f"$end%05d", + s"batch $batchIdx upperBound was '$hi', expected ${f"$end%05d"}") + } + } + } + + test("filtered cached scan on a string column returns correct rows") { + spark + .range(2000) + .selectExpr("lpad(cast(id as string), 5, '0') as s") + .createOrReplaceTempView("comet_cache_str") + spark.catalog.cacheTable("comet_cache_str") + try { + val df = spark.sql("SELECT s FROM comet_cache_str WHERE s = '01999'") + checkSparkAnswer(df) + val rows = df.collect().map(_.getString(0)).toSeq + assert(rows == Seq("01999")) + } finally { + spark.catalog.uncacheTable("comet_cache_str") + } + } + + test("timestamp_ntz cached scan is correct") { + // A Seq[LocalDateTime] maps to TimestampNTZType, 
which the Comet serializer supports. + val data = (0 until 50).map(i => (i.toLong, LocalDateTime.of(2020, 1, 1, 0, 0, i % 60))) + val df0 = data.toDF("id", "ts") + // Expected values from the uncached DataFrame (before caching). + val expected = df0 + .where("id < 10") + .orderBy("id") + .collect() + .map(r => (r.getLong(0), r.getAs[java.time.LocalDateTime](1))) + .toSeq + df0.createOrReplaceTempView("comet_cache_ntz") + spark.catalog.cacheTable("comet_cache_ntz") + try { + val df = spark.sql("SELECT id, ts FROM comet_cache_ntz WHERE id < 10 ORDER BY id") + checkSparkAnswer(df) + val actual = df + .collect() + .map(r => (r.getLong(0), r.getAs[java.time.LocalDateTime](1))) + .toSeq + assert(actual == expected) + assert(actual.size == 10) + } finally { + spark.catalog.uncacheTable("comet_cache_ntz") + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala index fa5f368e33..9b00e42593 100644 --- a/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/CometPluginsSuite.scala @@ -24,6 +24,8 @@ import java.io.File import org.apache.spark.sql.{CometTestBase, SaveMode} import org.apache.spark.sql.internal.StaticSQLConf +import org.apache.comet.CometConf + class CometPluginsSuite extends CometTestBase { override protected def sparkConf: SparkConf = { val conf = new SparkConf() @@ -83,6 +85,41 @@ class CometPluginsSuite extends CometTestBase { } } + test("setCacheSerializerIfEnabled installs Comet serializer when enabled and unset") { + val conf = new SparkConf().set(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, "true") + CometDriverPlugin.setCacheSerializerIfEnabled(conf) + assert( + conf.get(StaticSQLConf.SPARK_CACHE_SERIALIZER.key) == + CometDriverPlugin.COMET_CACHE_SERIALIZER) + } + + test("setCacheSerializerIfEnabled replaces the default serializer when enabled") { + val conf = new SparkConf() + 
.set(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, "true") + .set( + StaticSQLConf.SPARK_CACHE_SERIALIZER.key, + StaticSQLConf.SPARK_CACHE_SERIALIZER.defaultValueString) + CometDriverPlugin.setCacheSerializerIfEnabled(conf) + assert( + conf.get(StaticSQLConf.SPARK_CACHE_SERIALIZER.key) == + CometDriverPlugin.COMET_CACHE_SERIALIZER) + } + + test("setCacheSerializerIfEnabled respects a user-provided serializer") { + val conf = new SparkConf() + .set(CometConf.COMET_CACHE_SERIALIZER_ENABLED.key, "true") + .set(StaticSQLConf.SPARK_CACHE_SERIALIZER.key, "com.example.MyCachedBatchSerializer") + CometDriverPlugin.setCacheSerializerIfEnabled(conf) + assert( + conf.get(StaticSQLConf.SPARK_CACHE_SERIALIZER.key) == "com.example.MyCachedBatchSerializer") + } + + test("setCacheSerializerIfEnabled does nothing when disabled") { + val conf = new SparkConf() + CometDriverPlugin.setCacheSerializerIfEnabled(conf) + assert(conf.getOption(StaticSQLConf.SPARK_CACHE_SERIALIZER.key).isEmpty) + } + test("CometSource metrics are recorded") { val nativeBefore = CometSource.NATIVE_OPERATORS.getCount val queriesBefore = CometSource.QUERIES_PLANNED.getCount