diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index bcc9ee0692..e4f341efa8 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -21,7 +21,7 @@ package org.apache.comet.expressions
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Literal}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType, TimestampNTZType, TimestampType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, MapType, NullType, StructType, TimestampNTZType, TimestampType}
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.{isSpark40Plus, withFallbackReason}
@@ -200,6 +200,22 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
           }
         }
         Compatible()
+      case (from_map: MapType, to_map: MapType) =>
+        // Native cast_map_to_map recursively casts keys and values, so support is
+        // determined by whether both inner casts are individually supported.
+        val keySupport = isSupported(from_map.keyType, to_map.keyType, timeZoneId, evalMode)
+        keySupport match {
+          case Compatible(_) =>
+            isSupported(from_map.valueType, to_map.valueType, timeZoneId, evalMode)
+          case _: Incompatible =>
+            // An Incompatible key must not hide an Unsupported value, so the
+            // value verdict is still consulted before answering Incompatible.
+            isSupported(from_map.valueType, to_map.valueType, timeZoneId, evalMode) match {
+              case valueSupport: Unsupported => valueSupport
+              case _ => keySupport
+            }
+          case other => other
+        }
       case (DataTypes.DateType, toType) => canCastFromDate(toType, evalMode)
       case _ => unsupported(fromType, toType)
     }
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index c527858507..aac1bc0081 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions.{col, monotonically_increasing_id}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DataTypes, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DataTypes, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructField, StructType, TimestampType}
 
 import org.apache.comet.expressions.{CometCast, CometEvalMode}
 import org.apache.comet.rules.CometScanTypeChecker
@@ -1632,6 +1632,76 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     testArrayCastMatrix(types, ArrayType(_), generateArrays(100, _))
   }
 
+  test("cast MapType to MapType") {
+    // https://github.com/apache/datafusion-comet/issues/4491
+    // Native cast_map_to_map already handles the Parquet `key_value` vs
+    // Spark `entries` field-name difference, so we only need to verify that
+    // the planner routes Map→Map casts into it. The map column must be read
+    // natively for the cast to be exercised by Comet, which only happens
+    // under the V1 Parquet scan, so we pin USE_V1_SOURCE_LIST=parquet.
+    import scala.collection.JavaConverters._
+    val schema =
+      StructType(Seq(StructField("a", MapType(IntegerType, IntegerType), nullable = true)))
+    val rows = Range(0, 100).map { i =>
+      if (i % 10 == 0) Row(null)
+      else if (i % 7 == 0) Row(Map.empty[Int, Int])
+      else Row(Map(i -> (i + 1), (i + 2) -> (i + 3)))
+    }
+    val input = spark.createDataFrame(rows.asJava, schema)
+
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") {
+      Seq(
+        MapType(LongType, LongType),
+        MapType(IntegerType, StringType),
+        MapType(StringType, DoubleType)).foreach { toType =>
+        castTest(input, toType)
+      }
+    }
+  }
+
+  test("cast MapType propagates Incompatible from inner value cast") {
+    // Float → Decimal is Incompatible due to rounding (see canCastFromFloat).
+    // The Map arm must propagate that Incompatible up rather than silently
+    // marking the whole Map → Map cast Compatible.
+    assert(
+      CometCast.isSupported(
+        MapType(IntegerType, FloatType),
+        MapType(IntegerType, DecimalType(10, 2)),
+        None,
+        CometEvalMode.LEGACY) ==
+        Incompatible(Some("There can be rounding differences")))
+  }
+
+  test("cast MapType propagates Unsupported from nested value cast") {
+    // Map<Int, Map<Int, Int>> → Map<Int, String>: the inner Map<Int, Int> → String
+    // cast is Unsupported, and that must propagate through the outer Map
+    // arm rather than being silently swallowed.
+    val innerFrom = MapType(IntegerType, IntegerType)
+    val expectedMessage = s"Cast from $innerFrom to ${DataTypes.StringType} is not supported"
+    assert(
+      CometCast.isSupported(
+        MapType(IntegerType, innerFrom),
+        MapType(IntegerType, StringType),
+        None,
+        CometEvalMode.LEGACY) ==
+        Unsupported(Some(expectedMessage)))
+  }
+
+  test("cast MapType reports Unsupported value even when key cast is Incompatible") {
+    // Key Float → Decimal is Incompatible and value Map<Int, Int> → String is
+    // Unsupported. The Unsupported verdict must win so the whole cast is not
+    // reported as merely Incompatible.
+    val innerFrom = MapType(IntegerType, IntegerType)
+    val expectedMessage = s"Cast from $innerFrom to ${DataTypes.StringType} is not supported"
+    assert(
+      CometCast.isSupported(
+        MapType(FloatType, innerFrom),
+        MapType(DecimalType(10, 2), StringType),
+        None,
+        CometEvalMode.LEGACY) ==
+        Unsupported(Some(expectedMessage)))
+  }
+
   test("cast ArrayType(DateType) to unsupported ArrayType falls back") {
     val fromType = ArrayType(DateType)
     val unsupportedElementTypes =