diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 86e70cb330..3a22370ad8 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -186,10 +186,10 @@ The tables below list every Spark built-in expression with its current status. | Function | Status | Notes | | --- | --- | --- | | `array_size` | ✅ | | -| `cardinality` | ✅ | MapType input falls back | +| `cardinality` | ✅ | | | `concat` | ✅ | Binary/array children fall back | | `reverse` | ✅ | Binary-element arrays fall back (Incompatible) ([details](compatibility/expressions/array.md)) | -| `size` | ✅ | MapType input falls back | +| `size` | ✅ | | --- diff --git a/native/spark-expr/src/array_funcs/size.rs b/native/spark-expr/src/array_funcs/size.rs index 9777553341..9adcc29800 100644 --- a/native/spark-expr/src/array_funcs/size.rs +++ b/native/spark-expr/src/array_funcs/size.rs @@ -198,6 +198,14 @@ fn spark_size_scalar(scalar: &ScalarValue) -> Result { + if array.is_null(0) { + Ok(ScalarValue::Int32(Some(-1))) + } else { + let len = array.value_length(0); + Ok(ScalarValue::Int32(Some(len))) + } + } ScalarValue::Null => { Ok(ScalarValue::Int32(Some(-1))) // Spark behavior: return -1 for null } @@ -276,78 +284,130 @@ mod tests { assert_eq!(result, ScalarValue::Int32(Some(-1))); } - // TODO: Add map array test once Arrow MapArray API constraints are resolved - // Currently MapArray doesn't allow nulls in entries which makes testing complex - // The core size() implementation supports maps correctly - #[ignore] #[test] fn test_spark_size_map_array() { - use arrow::array::{MapArray, StringArray}; - - // Create a simpler test with maps: - // [{"key1": "value1", "key2": "value2"}, {"key3": "value3"}, {}, null] + use arrow::array::{Int32Array, MapArray, StringArray}; - // Create keys array for all entries (no nulls) - let keys = StringArray::from(vec!["key1", "key2", "key3"]); + let keys = StringArray::from(vec![Some("key1"), Some("key2"), Some("key3")]); + let values = Int32Array::from(vec![Some(1), Some(2), Some(3)]); - // Create values array for all entries (no nulls) - let values = StringArray::from(vec!["value1", "value2", "value3"]); - - // Create entry offsets: [0, 2, 3, 3] representing: - // - Map 1: entries 0-1 (2 key-value pairs) - // - Map 2: entries 2-2 (1 key-value pair) - // - Map 3: entries 3-2 (0 key-value pairs, empty map) - // - Map 4: null (handled by null buffer) - let entry_offsets = arrow::buffer::OffsetBuffer::new(vec![0, 2, 3, 3, 3].into()); + let entry_offsets = arrow::buffer::OffsetBuffer::new(vec![0i32, 2, 3, 3, 3].into()); let key_field = Arc::new(Field::new("key", DataType::Utf8, false)); - let value_field = Arc::new(Field::new("value", DataType::Utf8, false)); // Make values non-nullable too + let value_field = Arc::new(Field::new("value", DataType::Int32, true)); - // Create the entries struct array let entries = arrow::array::StructArray::new( arrow::datatypes::Fields::from(vec![key_field, value_field]), vec![Arc::new(keys), Arc::new(values)], - None, // No nulls in the entries struct array itself + None, ); - // Create null buffer for the map array (fourth map is null) let mut null_buffer = NullBufferBuilder::new(4); - null_buffer.append(true); // Map with 2 entries - not null - null_buffer.append(true); // Map with 1 entry - not null - null_buffer.append(true); // Empty map - not null - null_buffer.append(false); // null map - - let map_data_type = DataType::Map( - Arc::new(Field::new( - "entries", - DataType::Struct(arrow::datatypes::Fields::from(vec![ - Field::new("key", DataType::Utf8, false), - Field::new("value", DataType::Utf8, false), // Make values non-nullable too - ])), - false, - )), - false, // keys are not sorted - ); - - let map_field = Arc::new(Field::new("map", map_data_type, true)); - - let map_array = MapArray::new( + null_buffer.append(true); + null_buffer.append(true); + null_buffer.append(true); + null_buffer.append(false); + + let map_field = Arc::new(Field::new( + "entries", + DataType::Struct(arrow::datatypes::Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )); + + let map_array = MapArray::try_new( map_field, entry_offsets, entries, null_buffer.finish(), - false, // keys are not sorted - ); + false, + ) + .unwrap(); let array_ref: ArrayRef = Arc::new(map_array); let result = spark_size_array(&array_ref).unwrap(); let result = result.as_any().downcast_ref::().unwrap(); - // Expected: [2, 1, 0, -1] - assert_eq!(result.value(0), 2); // Map with 2 key-value pairs - assert_eq!(result.value(1), 1); // Map with 1 key-value pair - assert_eq!(result.value(2), 0); // empty map has 0 pairs - assert_eq!(result.value(3), -1); // null map returns -1 + assert_eq!(result.value(0), 2); + assert_eq!(result.value(1), 1); + assert_eq!(result.value(2), 0); + assert_eq!(result.value(3), -1); + } + + #[test] + fn test_spark_size_scalar_map() { + use arrow::array::{Int32Array, MapArray, StringArray}; + + let keys = StringArray::from(vec![Some("a"), Some("b")]); + let values = Int32Array::from(vec![Some(1), Some(2)]); + let entry_offsets = arrow::buffer::OffsetBuffer::new(vec![0i32, 2].into()); + + let key_field = Arc::new(Field::new("key", DataType::Utf8, false)); + let value_field = Arc::new(Field::new("value", DataType::Int32, true)); + + let entries = arrow::array::StructArray::new( + arrow::datatypes::Fields::from(vec![key_field, value_field]), + vec![Arc::new(keys), Arc::new(values)], + None, + ); + + let map_field = Arc::new(Field::new( + "entries", + DataType::Struct(arrow::datatypes::Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )); + + let map_array = MapArray::try_new(map_field, entry_offsets, entries, None, false).unwrap(); + let scalar = ScalarValue::Map(Arc::new(map_array)); + let result = spark_size_scalar(&scalar).unwrap(); + assert_eq!(result, ScalarValue::Int32(Some(2))); + } + + #[test] + fn test_spark_size_scalar_null_map() { + use arrow::array::{Int32Array, MapArray, StringArray}; + + let keys = StringArray::from(vec![Some("a")]); + let values = Int32Array::from(vec![Some(1)]); + let entry_offsets = arrow::buffer::OffsetBuffer::new(vec![0i32, 1].into()); + + let key_field = Arc::new(Field::new("key", DataType::Utf8, false)); + let value_field = Arc::new(Field::new("value", DataType::Int32, true)); + + let entries = arrow::array::StructArray::new( + arrow::datatypes::Fields::from(vec![key_field, value_field]), + vec![Arc::new(keys), Arc::new(values)], + None, + ); + + let map_field = Arc::new(Field::new( + "entries", + DataType::Struct(arrow::datatypes::Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )); + + let mut null_buffer = NullBufferBuilder::new(1); + null_buffer.append(false); + + let map_array = MapArray::try_new( + map_field, + entry_offsets, + entries, + null_buffer.finish(), + false, + ) + .unwrap(); + let scalar = ScalarValue::Map(Arc::new(map_array)); + let result = spark_size_scalar(&scalar).unwrap(); + assert_eq!(result, ScalarValue::Int32(Some(-1))); } #[test] diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 2f89b0e2e3..bcda62ee09 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -667,15 +667,11 @@ object CometArrayFilter extends CometExpressionSerde[ArrayFilter] { object CometSize extends CometExpressionSerde[Size] { - override def getUnsupportedReasons(): Seq[String] = Seq( - "Only supports `ArrayType` input; `MapType` input is not supported") - override def getSupportLevel(expr: Size): SupportLevel = { expr.child.dataType match { case _: ArrayType => Compatible() - case _: MapType => Unsupported(Some("size does not support map inputs")) + case _: MapType => Compatible() case other => - // this should be unreachable because Spark only supports map and array inputs Unsupported(Some(s"Unsupported child data type: $other")) } } diff --git a/spark/src/test/resources/sql-tests/expressions/array/posexplode.sql b/spark/src/test/resources/sql-tests/expressions/array/posexplode.sql index 20b0547ebe..b7ba70d45a 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/posexplode.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/posexplode.sql @@ -95,6 +95,6 @@ INSERT INTO test_posexplode_map VALUES (1, map('a', 1, 'b', 2)), (2, map('c', 3)) --- posexplode over a map falls back to Spark (Comet only supports array inputs) -query expect_fallback(size does not support map inputs) +-- posexplode over a map falls back to Spark (Comet only supports array inputs, not maps) +query expect_fallback(Comet only supports explode/explode_outer for arrays, not maps) SELECT id, posexplode(m) FROM test_posexplode_map diff --git a/spark/src/test/resources/sql-tests/expressions/array/size.sql b/spark/src/test/resources/sql-tests/expressions/array/size.sql index b006a4da0d..fb2bedef1e 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/size.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/size.sql @@ -15,15 +15,29 @@ -- specific language governing permissions and limitations -- under the License. +-- ConfigMatrix: spark.sql.legacy.sizeOfNull=true,false + statement CREATE TABLE test_size(arr array, m map) USING parquet statement INSERT INTO test_size VALUES (array(1, 2, 3), map('a', 1, 'b', 2)), (array(), map()), (NULL, NULL) -query spark_answer_only +query SELECT size(arr), size(m) FROM test_size --- literal arguments +-- literal array arguments query SELECT size(array(1, 2, 3)), size(array()), size(cast(NULL as array)) + +-- literal map via CreateMap (falls back: Comet has no CreateMap serde; +-- cast(NULL as map) avoids CreateMap and goes through CometLiteral instead) +query spark_answer_only +SELECT size(map('a', 1, 'b', 2)), size(map()) + +query +SELECT size(cast(NULL as map)) + +-- cardinality is a SQL alias for size +query +SELECT cardinality(arr), cardinality(m) FROM test_size diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala index f3c7d9f23e..86980fffe6 100644 --- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala @@ -133,27 +133,43 @@ class CometMapExpressionSuite extends CometTestBase { makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100) spark.read.parquet(path.toString).createOrReplaceTempView("t1") - // Use column references in maps to avoid constant folding checkSparkAnswerAndFallbackReason( sql("SELECT size(case when _2 < 0 then map(_8, _9) else map() end) from t1"), - "size does not support map inputs") + "map is not supported") } } } - // fails with "map is not supported" - ignore("size with map input") { - withTempDir { dir => - withTempView("t1") { - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100) - spark.read.parquet(path.toString).createOrReplaceTempView("t1") + test("size with map input - v2 reader") { + withTempPath { dir => + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val df = spark + .range(100) + .select( + col("id"), + when(col("id") > 1, map(col("id"), when(col("id") > 2, col("id")))) + .alias("map1"), + when(col("id") > 5, map(lit("a"), col("id"), lit("b"), col("id") + 1)) + .alias("map2")) + df.write.parquet(dir.toString()) + } - // Use column references in maps to avoid constant folding - checkSparkAnswerAndOperator( - sql("SELECT size(map(_8, _9, _10, _11)) from t1 where _8 is not null")) - checkSparkAnswerAndOperator( - sql("SELECT size(case when _2 < 0 then map(_8, _9) else map() end) from t1")) + Seq("", "parquet").foreach { v1List => + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) { + val df = spark.read.parquet(dir.toString()) + df.createOrReplaceTempView("t1") + if (v1List.isEmpty) { + checkSparkAnswer(df.select(size(col("map1")))) + checkSparkAnswer(df.select(size(col("map2")))) + checkSparkAnswer( + sql("SELECT size(CASE WHEN id < 50 THEN map1 ELSE map2 END) FROM t1")) + } else { + checkSparkAnswerAndOperator(df.select(size(col("map1")))) + checkSparkAnswerAndOperator(df.select(size(col("map2")))) + checkSparkAnswerAndOperator( + sql("SELECT size(CASE WHEN id < 50 THEN map1 ELSE map2 END) FROM t1")) + } + } } } }