diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs
index 1b8a88beb..ae80ef2df 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -66,6 +66,7 @@ pub fn create_auron_ext_function(
         "Spark_MakeArray" => Arc::new(spark_make_array::array),
         "Spark_MapConcat" => Arc::new(spark_map::map_concat),
         "Spark_MapFromArrays" => Arc::new(spark_map::map_from_arrays),
+        "Spark_MapFromEntries" => Arc::new(spark_map::map_from_entries),
         "Spark_StringSpace" => Arc::new(spark_strings::string_space),
         "Spark_StringRepeat" => Arc::new(spark_strings::string_repeat),
         "Spark_StringSplit" => Arc::new(spark_strings::string_split),
diff --git a/native-engine/datafusion-ext-functions/src/spark_map.rs b/native-engine/datafusion-ext-functions/src/spark_map.rs
index 72aa36f04..aa3cbda16 100644
--- a/native-engine/datafusion-ext-functions/src/spark_map.rs
+++ b/native-engine/datafusion-ext-functions/src/spark_map.rs
@@ -13,18 +13,29 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::{collections::HashSet, sync::Arc};
+use std::{
+    collections::{HashMap, HashSet},
+    sync::Arc,
+};
 
 use arrow::{
     array::{Array, ArrayRef, ListArray, MapArray, StructArray, new_empty_array},
     buffer::{NullBuffer, OffsetBuffer, ScalarBuffer},
-    datatypes::{DataType, Field},
+    datatypes::{DataType, Field, Fields},
 };
 use datafusion::{
     common::{Result, ScalarValue},
     logical_expr::ColumnarValue,
 };
-use datafusion_ext_commons::{df_execution_err, scalar_value::compacted_scalar_value_from_array};
+use datafusion_ext_commons::{
+    df_execution_err, downcast_any, scalar_value::compacted_scalar_value_from_array,
+};
+
+/// Duplicate map-key handling, mirroring Spark's spark.sql.mapKeyDedupPolicy values.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+enum MapKeyDedupPolicy {
+    Exception,
+    LastWin,
+}
 
 fn get_map_type(args: &[ColumnarValue]) -> Result<(Arc<Field>, bool)> {
     if args.is_empty() {
@@ -216,6 +227,189 @@ fn columnar_value_to_list_array(arg: &ColumnarValue, arg_name: &str) -> Result<
+fn extract_list_entry_fields(
+    list_field: &Arc<Field>,
+    func_name: &str,
+) -> Result<(Arc<Field>, Arc<Field>)> {
+    let fields = match list_field.data_type() {
+        DataType::Struct(fields) => fields,
+        _ => {
+            return df_execution_err!(
+                "{func_name} array entries must be struct, found {:?}",
+                list_field.data_type()
+            );
+        }
+    };
+
+    if fields.len() != 2 {
+        return df_execution_err!(
+            "{func_name} array entries struct must contain exactly 2 fields, found {}",
+            fields.len()
+        );
+    }
+
+    Ok((fields[0].clone(), fields[1].clone()))
+}
+
+fn parse_map_key_dedup_policy(args: &[ColumnarValue], idx: usize) -> Result<MapKeyDedupPolicy> {
+    if args.len() <= idx {
+        return Ok(MapKeyDedupPolicy::Exception);
+    }
+
+    match &args[idx] {
+        ColumnarValue::Scalar(ScalarValue::Utf8(Some(policy)))
+        | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(policy))) => match policy.as_str() {
+            "EXCEPTION" => Ok(MapKeyDedupPolicy::Exception),
+            "LAST_WIN" => Ok(MapKeyDedupPolicy::LastWin),
+            _ => df_execution_err!("unsupported map key dedup policy: {policy}"),
+        },
+        ColumnarValue::Scalar(ScalarValue::Utf8(None))
+        | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => Ok(MapKeyDedupPolicy::Exception),
+        _ => df_execution_err!("map key dedup policy argument must be a string scalar"),
+    }
+}
+
+/// Returns a map created from the given array of entries.
+///
+/// This follows Spark semantics:
+/// - null input array => null map
+/// - null array entry => null map
+/// - null key => error
+/// - duplicate key => error by default, or last-wins when configured
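+///
+/// For example, in Spark SQL (the example from Spark's `map_from_entries` docs):
+///
+/// ```sql
+/// SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b')));
+/// -- => {1:"a",2:"b"}
+/// ```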
+pub fn map_from_entries(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.is_empty() {
+        return df_execution_err!("map_from_entries requires at least one argument");
+    }
+
+    let entry_arrays = columnar_value_to_list_array(&args[0], "map_from_entries")?;
+    let list_field = get_list_array_field(&entry_arrays, "map_from_entries")?;
+    let (input_key_field, input_value_field) =
+        extract_list_entry_fields(&list_field, "map_from_entries")?;
+    let key_field = Arc::new(Field::new(
+        "key",
+        input_key_field.data_type().clone(),
+        false,
+    ));
+    let value_field = Arc::new(Field::new(
+        "value",
+        input_value_field.data_type().clone(),
+        input_value_field.is_nullable(),
+    ));
+    let entries_field = Arc::new(Field::new(
+        "entries",
+        DataType::Struct(Fields::from(vec![
+            key_field.as_ref().clone(),
+            value_field.as_ref().clone(),
+        ])),
+        false,
+    ));
+
+    let dedup_policy = parse_map_key_dedup_policy(args, 1)?;
+    let num_rows = entry_arrays.len();
+
+    let mut all_keys = Vec::new();
+    let mut all_values = Vec::new();
+    let mut offsets = Vec::with_capacity(num_rows + 1);
+    let mut valids = Vec::with_capacity(num_rows);
+    let mut next_offset = 0i32;
+
+    offsets.push(next_offset);
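+
+    // Build one map per input row: `offsets` records the running number of
+    // emitted entries (row i's entries span offsets[i]..offsets[i + 1]), and
+    // `valids` marks which rows become null maps.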
+    for row_idx in 0..num_rows {
+        if entry_arrays.is_null(row_idx) {
+            valids.push(false);
+            offsets.push(next_offset);
+            continue;
+        }
+
+        let entries = entry_arrays.value(row_idx);
+        let entries = entries
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .ok_or_else(|| {
+                datafusion::error::DataFusionError::Execution(
+                    "map_from_entries expects array entries to be struct".to_string(),
+                )
+            })?;
+
+        let keys = entries.column(0);
+        let values = entries.column(1);
+        let mut row_entries: Vec<(ScalarValue, ScalarValue)> = Vec::new();
+        let mut row_key_to_index: HashMap<ScalarValue, usize> = HashMap::new();
+        let mut row_is_null = false;
+
+        for i in 0..entries.len() {
+            if entries.is_null(i) {
+                row_is_null = true;
+                break;
+            }
+
+            if keys.is_null(i) {
+                return df_execution_err!("map_from_entries does not support null map keys");
+            }
+
+            let key = compacted_scalar_value_from_array(keys.as_ref(), i)?;
+            let value = compacted_scalar_value_from_array(values.as_ref(), i)?;
+
+            if let Some(idx) = row_key_to_index.get(&key).copied() {
+                match dedup_policy {
+                    MapKeyDedupPolicy::Exception => {
+                        return df_execution_err!("map_from_entries duplicate key found: {key}");
+                    }
+                    MapKeyDedupPolicy::LastWin => {
+                        row_entries[idx].1 = value;
+                    }
+                }
+            } else {
+                row_key_to_index.insert(key.clone(), row_entries.len());
+                row_entries.push((key, value));
+            }
+        }
+
+        if row_is_null {
+            valids.push(false);
+            offsets.push(next_offset);
+            continue;
+        }
+
+        valids.push(true);
+        next_offset += row_entries.len() as i32;
+        offsets.push(next_offset);
+
+        for (key, value) in row_entries {
+            all_keys.push(key);
+            all_values.push(value);
+        }
+    }
+
+    let keys = if all_keys.is_empty() {
+        new_empty_array(key_field.data_type())
+    } else {
+        ScalarValue::iter_to_array(all_keys.into_iter())?
+    };
+
+    let values = if all_values.is_empty() {
+        new_empty_array(value_field.data_type())
+    } else {
+        ScalarValue::iter_to_array(all_values.into_iter())?
+    };
+
+    let entries = StructArray::from(vec![(key_field, keys), (value_field, values)]);
+    let nulls = if valids.iter().all(|valid| *valid) {
+        None
+    } else {
+        Some(NullBuffer::from(valids))
+    };
+
+    Ok(ColumnarValue::Array(Arc::new(MapArray::new(
+        entries_field,
+        OffsetBuffer::new(ScalarBuffer::from(offsets)),
+        entries,
+        nulls,
+        false,
+    ))))
+}
+
 /// Returns the union of all given maps.
 ///
 /// This follows Spark's default duplicate-key behavior by raising an error,
@@ -453,7 +647,10 @@ pub fn map_from_arrays(args: &[ColumnarValue]) -> Result<ColumnarValue> {
 #[cfg(test)]
 mod test {
     use arrow::{
-        array::{Int32Array, Int32Builder, ListBuilder, NullArray, StringArray, StringBuilder},
+        array::{
+            Int32Array, Int32Builder, ListBuilder, NullArray, StringArray, StringBuilder,
+            StructBuilder,
+        },
         datatypes::Fields,
     };
 
@@ -463,6 +660,8 @@ mod test {
     type StringIntMapRow = Option<StringIntMapEntries>;
     type StringStringMapEntries = Vec<(&'static str, Option<&'static str>)>;
     type StringStringMapRow = Option<StringStringMapEntries>;
+    type StringIntEntry = Option<(Option<&'static str>, Option<i32>)>;
+    type StringIntEntryRow = Option<Vec<StringIntEntry>>;
 
     fn build_string_int_map_array(rows: Vec<StringIntMapRow>) -> MapArray {
         let key_field = Arc::new(Field::new("key", DataType::Utf8, false));
@@ -628,6 +827,75 @@ mod test {
         builder.finish()
     }
 
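+    // Builds a ListArray of List<Struct<k: Utf8, v: Int32>> with nullable
+    // rows, entries, keys and values, mirroring the array<struct<k, v>>
+    // input type exercised by the Scala suite below.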
+    fn build_string_int_entry_array(rows: Vec<StringIntEntryRow>) -> ListArray {
+        let struct_builder = StructBuilder::new(
+            vec![
+                Field::new("k", DataType::Utf8, true),
+                Field::new("v", DataType::Int32, true),
+            ],
+            vec![
+                Box::new(StringBuilder::new()),
+                Box::new(Int32Builder::new()),
+            ],
+        );
+        let mut builder = ListBuilder::new(struct_builder);
+
+        for row in rows {
+            match row {
+                Some(entries) => {
+                    for entry in entries {
+                        match entry {
+                            Some((key, value)) => {
+                                match key {
+                                    Some(key) => builder
+                                        .values()
+                                        .field_builder::<StringBuilder>(0)
+                                        .expect("string builder")
+                                        .append_value(key),
+                                    None => builder
+                                        .values()
+                                        .field_builder::<StringBuilder>(0)
+                                        .expect("string builder")
+                                        .append_null(),
+                                }
+                                match value {
+                                    Some(value) => builder
+                                        .values()
+                                        .field_builder::<Int32Builder>(1)
+                                        .expect("int builder")
+                                        .append_value(value),
+                                    None => builder
+                                        .values()
+                                        .field_builder::<Int32Builder>(1)
+                                        .expect("int builder")
+                                        .append_null(),
+                                }
+                                builder.values().append(true);
+                            }
+                            None => {
+                                builder
+                                    .values()
+                                    .field_builder::<StringBuilder>(0)
+                                    .expect("string builder")
+                                    .append_null();
+                                builder
+                                    .values()
+                                    .field_builder::<Int32Builder>(1)
+                                    .expect("int builder")
+                                    .append_null();
+                                builder.values().append(false);
+                            }
+                        }
+                    }
+                    builder.append(true);
+                }
+                None => builder.append(false),
+            }
+        }
+
+        builder.finish()
+    }
+
     #[test]
     fn test_map_concat() -> Result<()> {
         let left = build_string_int_map_array(vec![
@@ -791,4 +1059,87 @@ mod test {
         assert!(actual.is_null(0));
         Ok(())
     }
+
+    #[test]
+    fn test_map_from_entries() -> Result<()> {
+        let entries = build_string_int_entry_array(vec![
+            Some(vec![Some((Some("a"), Some(1))), Some((Some("b"), Some(2)))]),
+            Some(vec![Some((Some("x"), Some(10)))]),
+            None,
+            Some(vec![None, Some((Some("z"), Some(30)))]),
+            Some(vec![Some((Some("m"), None))]),
+        ]);
+
+        let actual = map_from_entries(&[
+            ColumnarValue::Array(Arc::new(entries)),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("EXCEPTION".to_string()))),
+        ])?
+        .into_array(5)?;
+
+        let expected = Arc::new(build_string_int_map_array(vec![
+            Some(vec![("a", Some(1)), ("b", Some(2))]),
+            Some(vec![("x", Some(10))]),
+            None,
+            None,
+            Some(vec![("m", None)]),
+        ])) as ArrayRef;
+
+        assert_eq!(&actual, &expected);
+        Ok(())
+    }
+
+    #[test]
+    fn test_map_from_entries_rejects_null_keys() {
+        let entries = build_string_int_entry_array(vec![Some(vec![
+            Some((Some("a"), Some(1))),
+            Some((None, Some(2))),
+        ])]);
+
+        let err = map_from_entries(&[
+            ColumnarValue::Array(Arc::new(entries)),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("EXCEPTION".to_string()))),
+        ])
+        .expect_err("map_from_entries should fail when null keys exist");
+
+        assert!(err.to_string().contains("null map keys"));
+    }
+
+    #[test]
+    fn test_map_from_entries_duplicate_keys() {
+        let entries = build_string_int_entry_array(vec![Some(vec![
+            Some((Some("a"), Some(1))),
+            Some((Some("a"), Some(2))),
+        ])]);
+
+        let err = map_from_entries(&[
+            ColumnarValue::Array(Arc::new(entries)),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("EXCEPTION".to_string()))),
+        ])
+        .expect_err("map_from_entries should fail when duplicate keys exist");
+
+        assert!(err.to_string().contains("duplicate key"));
+    }
+
+    #[test]
+    fn test_map_from_entries_last_win() -> Result<()> {
+        let entries = build_string_int_entry_array(vec![Some(vec![
+            Some((Some("a"), Some(1))),
+            Some((Some("b"), Some(2))),
+            Some((Some("a"), Some(3))),
+        ])]);
+
+        let actual = map_from_entries(&[
+            ColumnarValue::Array(Arc::new(entries)),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("LAST_WIN".to_string()))),
+        ])?
+        .into_array(1)?;
+
+        let expected = Arc::new(build_string_int_map_array(vec![Some(vec![
+            ("a", Some(3)),
+            ("b", Some(2)),
+        ])])) as ArrayRef;
+
+        assert_eq!(&actual, &expected);
+        Ok(())
+    }
 }
diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
index ef0983a58..6b6c0adf9 100644
--- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
+++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
@@ -371,6 +371,86 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite {
     }
   }
 
+  test("map_from_entries function") {
+    withTable("t1") {
+      sql("create table t1(c1 array<struct<k: string, v: int>>) using parquet")
+      sql("""
+           |insert into t1 values
+           | (array(named_struct('k', 'a', 'v', 1), named_struct('k', 'b', 'v', 2))),
+           | (array(named_struct('k', 'x', 'v', 10))),
+           | (cast(null as array<struct<k: string, v: int>>)),
+           | (array(cast(null as struct<k: string, v: int>), named_struct('k', 'z', 'v', 30))),
+           | (array(named_struct('k', 'm', 'v', cast(null as int))))
+           |""".stripMargin)
+      checkSparkAnswerAndOperator("select map_from_entries(c1) from t1")
+    }
+  }
+
+  test("map_from_entries rejects null keys") {
+    withTable("t1") {
+      sql("create table t1(c1 array<struct<k: string, v: int>>) using parquet")
+      sql("""
+           |insert into t1 values
+           | (array(named_struct('k', 'a', 'v', 1), named_struct('k', cast(null as string), 'v', 2)))
+           |""".stripMargin)
+      val df = sql("select map_from_entries(c1) from t1")
+      val err = intercept[Exception] {
+        df.collect()
+      }
+      val plan = stripAQEPlan(df.queryExecution.executedPlan)
+      plan
+        .collectFirst { case op if !isNativeOrPassThrough(op) => op }
+        .foreach { op =>
+          fail(s"""
+               |Found non-native operator: ${op.nodeName}
+               |plan:
+               |${plan}""".stripMargin)
+        }
+      assert(err.getMessage.toLowerCase.contains("null map key"))
+    }
+  }
+
+  test("map_from_entries duplicate keys") {
+    withTable("t1") {
+      sql("create table t1(c1 array<struct<k: string, v: int>>) using parquet")
+      sql("""
+           |insert into t1 values
+           | (array(named_struct('k', 'a', 'v', 1), named_struct('k', 'a', 'v', 2)))
+           |""".stripMargin)
+      val df = sql("select map_from_entries(c1) from t1")
+      val err = intercept[Exception] {
+        df.collect()
+      }
+      val plan = stripAQEPlan(df.queryExecution.executedPlan)
+      plan
+        .collectFirst { case op if !isNativeOrPassThrough(op) => op }
+        .foreach { op =>
+          fail(s"""
+               |Found non-native operator: ${op.nodeName}
+               |plan:
+               |${plan}""".stripMargin)
+        }
+      assert(err.getMessage.toLowerCase.contains("duplicate key"))
+    }
+  }
+
+  test("map_from_entries last win dedup policy") {
+    withTable("t1") {
+      sql("create table t1(c1 array<struct<k: string, v: int>>) using parquet")
+      sql("""
+           |insert into t1 values
+           | (array(
+           |   named_struct('k', 'a', 'v', 1),
+           |   named_struct('k', 'b', 'v', 2),
+           |   named_struct('k', 'a', 'v', 3)))
+           |""".stripMargin)
+      withSQLConf(
+        SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+        checkSparkAnswerAndOperator("select map_from_entries(c1) from t1")
+      }
+    }
+  }
+
   test("acosh null propagation") {
     withTable("t1") {
       sql("create table t1(c1 double) using parquet")
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 6d2ba759f..f4bf67fdb 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -1107,6 +1107,14 @@ object NativeConverters extends Logging {
         buildExtScalarFunction("Spark_NormalizeNanAndZero", e.children, e.dataType)
       case e: CreateArray =>
         buildExtScalarFunction("Spark_MakeArray", e.children, e.dataType)
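+      // The session's spark.sql.mapKeyDedupPolicy is resolved at conversion
+      // time and passed to the native kernel as a constant string argument,
+      // so native execution applies the same EXCEPTION/LAST_WIN semantics.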
+      case e: MapFromEntries =>
+        buildExtScalarFunction(
+          "Spark_MapFromEntries",
+          e.child :: Literal
+            .create(
+              SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY).toString,
+              StringType) :: Nil,
+          e.dataType)
       case e: MapConcat =>
         buildExtScalarFunction("Spark_MapConcat", e.children, e.dataType)
       case e: CreateNamedStruct =>