diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala new file mode 100644 index 000000000..78d0e3a98 --- /dev/null +++ b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +class AuronInstrSuite extends QueryTest with SparkQueryTestsBase { + + test("test instr function - basic functionality") { + val data = Seq( + ("hello world", "world"), + ("hello world", "hello"), + ("hello world", "o"), + ("hello world", "z"), + (null, "test"), + ("test", null) + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val rows = df.selectExpr("instr(str, substr)").collect() + + // Check non-null results + assert(rows(0).getInt(0) == 7, "instr('hello world', 'world') should return 7") + assert(rows(1).getInt(0) == 1, "instr('hello world', 'hello') should return 1") + assert(rows(2).getInt(0) == 5, "instr('hello world', 'o') should return 5") + assert(rows(3).getInt(0) == 0, "instr('hello world', 'z') should return 0") + + // Check null results + assert(rows(4).isNullAt(0), "instr(null, 'test') should return null") + assert(rows(5).isNullAt(0), "instr('test', null) should return null") + } + + test("test instr function - multiple occurrences") { + val data = Seq( + ("banana", "a"), + ("testtesttest", "test"), + ("abcabcabc", "abc") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 2, "instr('banana', 'a') should return 2") + assert(result(1) == 1, "instr('testtesttest', 'test') should return 1") + assert(result(2) == 1, "instr('abcabcabc', 'abc') should return 1") + } + + test("test instr function - case sensitive") { + val data = Seq( + ("Hello", "hello"), + ("HELLO", "hello"), + ("Hello", "Hello"), + ("hElLo", "hello") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 0, "instr('Hello', 'hello') should return 0 (case sensitive)") + assert(result(1) == 0, "instr('HELLO', 'hello') should return 0 (case sensitive)") + assert(result(2) == 1, "instr('Hello', 'Hello') 
should return 1") + assert(result(3) == 0, "instr('hElLo', 'hello') should return 0 (case sensitive)") + } + + test("test instr function - with filter") { + val data = Seq( + ("hello world", "world", 1), + ("hello", "world", 0), + ("hello", "hello", 1), + ("test", "abc", 0) + ) + + val df = spark.createDataFrame(data).toDF("str", "substr", "expected") + val result = df + .filter("instr(str, substr) > 0") + .select("str") + .collect() + .map(_.getString(0)) + + assert(result.length == 2, "Should find 2 matching strings") + assert(result.contains("hello world")) + assert(result.contains("hello")) + } + + test("test instr function - in group by") { + val data = Seq( + ("test1", "test"), + ("test2", "test"), + ("hello", "world"), + ("testing", "test") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df + .groupBy("substr") + .count() + .filter("count > 0") + .orderBy("substr") + .collect() + + assert(result.length >= 1) + } + + test("test instr function - in where clause") { + val data = Seq( + ("hello world", "world"), + ("hello", "world"), + ("testing", "test"), + ("abc", "def") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df + .filter("instr(str, substr) = 1") + .select("str") + .collect() + .map(_.getString(0)) + + assert(result.length >= 1) + } +} diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index 1b8a88beb..aa56271f3 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -26,6 +26,7 @@ mod spark_dates; pub mod spark_get_json_object; mod spark_hash; mod spark_initcap; +mod spark_instr; mod spark_isnan; mod spark_make_array; mod spark_make_decimal; @@ -91,6 +92,7 @@ pub fn create_auron_ext_function( Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero) } "Spark_IsNaN" => Arc::new(spark_isnan::spark_isnan), + "Spark_Instr" => 
Arc::new(spark_instr::spark_instr),
         _ => df_unimplemented_err!("spark ext function not implemented: {name}")?,
     })
 }
diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs
new file mode 100644
index 000000000..f0c3d2d3f
--- /dev/null
+++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs
@@ -0,0 +1,276 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayRef, Int32Array, StringArray};
+use datafusion::{
+    common::{
+        Result, ScalarValue,
+        cast::{as_int32_array, as_string_array},
+    },
+    physical_plan::ColumnarValue,
+};
+use datafusion_ext_commons::df_execution_err;
+
+/// instr(str, substr) - Returns the (1-based) index of the first occurrence of
+/// substr in str. Compatible with Spark's instr function.
+/// Returns 0 if substr is not found; returns 1 if substr is empty
+/// (Spark/Hive semantics: instr(str, '') = 1, from UTF8String.indexOf + 1).
+/// Returns null if str is null or substr is null.
+pub fn spark_instr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 2 {
+        df_execution_err!("instr requires exactly 2 arguments")?;
+    }
+
+    // The result is scalar only when both inputs are scalar.
+    let is_scalar = args
+        .iter()
+        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
+    let len = args
+        .iter()
+        .map(|arg| match arg {
+            ColumnarValue::Array(array) => array.len(),
+            ColumnarValue::Scalar(_) => 1,
+        })
+        .max()
+        .unwrap_or(0);
+
+    // Broadcast scalar arguments to arrays of the common length.
+    let arrays = args
+        .iter()
+        .map(|arg| {
+            Ok(match arg {
+                ColumnarValue::Array(array) => array.clone(),
+                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len)?,
+            })
+        })
+        .collect::<Result<Vec<ArrayRef>>>()?;
+
+    let str_array = as_string_array(&arrays[0])?;
+    let substr_array = as_string_array(&arrays[1])?;
+
+    let result_array: ArrayRef = Arc::new(Int32Array::from_iter(
+        str_array
+            .iter()
+            .zip(substr_array.iter())
+            .map(|(s, substr)| match (s, substr) {
+                (Some(_), None) => None, // substr is null
+                (None, _) => None,       // str is null
+                (Some(s), Some(substr)) => Some(find_char_position(s, substr)),
+            }),
+    ));
+
+    if is_scalar {
+        let scalar = as_int32_array(&result_array)?.value(0);
+        Ok(ColumnarValue::Scalar(if result_array.is_null(0) {
+            ScalarValue::Int32(None)
+        } else {
+            ScalarValue::Int32(Some(scalar))
+        }))
+    } else {
+        Ok(ColumnarValue::Array(result_array))
+    }
+}
+
+/// Find the 1-based character (code point) position of substr in s.
+/// Returns 0 if not found; returns 1 for an empty substr, matching Spark's
+/// `UTF8String.indexOf(substr, 0) + 1` behavior.
+fn find_char_position(s: &str, substr: &str) -> i32 {
+    if substr.is_empty() {
+        // Spark semantics: the empty substring matches at position 1.
+        return 1;
+    }
+
+    match s.find(substr) {
+        // `find` returns a byte offset, and the match always starts on a char
+        // boundary; Spark positions are counted in characters, so count the
+        // chars preceding the match.
+        Some(byte_pos) => s[..byte_pos].chars().count() as i32 + 1,
+        None => 0,
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::Arc;
+
+    use arrow::array::{ArrayRef, Int32Array, StringArray};
+    use datafusion::{
+        common::{Result, ScalarValue, cast::as_int32_array},
+        physical_plan::ColumnarValue,
+    };
+
+    use super::spark_instr;
+
+    #[test]
+    fn test_spark_instr() -> Result<()> {
+        // Test basic functionality with scalar substring
+        let r = spark_instr(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("hello world".to_string()),
+                Some("abc".to_string()),
+                Some("abcabc".to_string()),
+                None,
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::from("world")),
+        ])?;
+        let s = r.into_array(4)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![Some(7), Some(0), Some(0), None,]
+        );
+
+        // Empty substring matches at position 1 (Spark semantics)
+        let r = spark_instr(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("hello".to_string()),
+                Some("world".to_string()),
+                None,
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::from("")),
+        ])?;
+        let s = r.into_array(3)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![Some(1), Some(1), None,]
+        );
+
+        // Test with null substring
+        let r = spark_instr(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some(
+                "hello".to_string(),
+            )]))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+        ])?;
+        let s = r.into_array(1)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![None,]
+        );
+
+        // Test with array substring (element-wise)
+        let r = spark_instr(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("hello world".to_string()),
+                Some("hello".to_string()),
+                Some("test".to_string()),
+            ]))),
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
Some("world".to_string()),
+                Some("test".to_string()),
+                Some("test".to_string()),
+            ]))),
+        ])?;
+        let s = r.into_array(3)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![Some(7), Some(0), Some(1),]
+        );
+
+        // Test with both scalars
+        let r = spark_instr(&vec![
+            ColumnarValue::Scalar(ScalarValue::from("hello world")),
+            ColumnarValue::Scalar(ScalarValue::from("world")),
+        ])?;
+        assert!(matches!(
+            r,
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(7)))
+        ));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_spark_instr_multiple_matches() -> Result<()> {
+        let r = spark_instr(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("banana".to_string()),
+                Some("testtesttest".to_string()),
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::from("test")),
+        ])?;
+        let s = r.into_array(2)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![Some(0), Some(1),]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_spark_instr_case_sensitive() -> Result<()> {
+        let r = spark_instr(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("Hello".to_string()),
+                Some("HELLO".to_string()),
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::from("hello")),
+        ])?;
+        let s = r.into_array(2)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![Some(0), Some(0),]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_spark_instr_utf8() -> Result<()> {
+        // Test UTF-8 multi-byte characters
+        // "你好世界" - "世界" should return 3 (character position), not 6 (byte
+        // position)
+        let r = spark_instr(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("你好世界".to_string()),
+                Some("hello世界".to_string()),
+                Some("test".to_string()),
+            ]))),
+            ColumnarValue::Scalar(ScalarValue::from("世界")),
+        ])?;
+        let s = r.into_array(3)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![Some(3), Some(6), Some(0),]
+        );
+
+        // Test with emoji (4-byte UTF-8)
+        let r = spark_instr(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some(
+                "hello😀world".to_string(),
+            )]))),
+            ColumnarValue::Scalar(ScalarValue::from("😀")),
+        ])?;
+        let s = r.into_array(1)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![Some(6),]
+        );
+
+        Ok(())
+    }
+}
diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
index ef0983a58..5b26c9838 100644
--- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
+++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
@@ -868,6 +868,117 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite {
         |""".stripMargin
     checkSparkAnswerAndOperator(query)
   }
+
+  test("instr function - basic functionality") {
+    withTable("t1") {
+      sql("""
+        CREATE TABLE t1(str STRING, substr STRING) USING parquet
+        """)
+      sql("""
+        INSERT INTO t1 VALUES
+        ('hello world', 'world'),
+        ('hello world', 'hello'),
+        ('hello world', 'o'),
+        ('hello world', 'z'),
+        (null, 'test'),
+        ('test', null)
+        """)
+
+      // Test basic instr functionality
+      checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1")
+    }
+  }
+
+  test("instr function - empty substring") {
+    withTable("t1") {
+      sql("CREATE TABLE t1(str STRING) USING parquet")
+      sql("INSERT INTO t1 VALUES ('hello'), ('world'), ('')")
+
+      // Spark's instr returns 1 for an empty substring
+      checkSparkAnswerAndOperator("SELECT instr(str, '') FROM t1")
+    }
+  }
+
+  test("instr function - UTF-8 multi-byte characters") {
+    withTable("t1") {
+      sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
+      sql("""
+        INSERT INTO t1 VALUES
+        ('你好世界', '世界'),
+        ('hello世界', '世界'),
+        ('test', '世界'),
+        ('hello😀world', '😀'),
+        ('test😀', '😀')
+        """)
+
+      // Test UTF-8 character position (not byte position)
+      checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1")
+    }
+  }
+
+  test("instr 
function - with expressions") {
+    withTable("t1") {
+      sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
+      sql("INSERT INTO t1 VALUES ('banana', 'a'), ('testtesttest', 'test'), ('abcabcabc', 'abc')")
+
+      // Substring comes from a column, so instr is evaluated element-wise per row
+      checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1")
+    }
+  }
+
+  test("instr function - case sensitivity") {
+    withTable("t1") {
+      sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
+      sql("""
+        INSERT INTO t1 VALUES
+        ('Hello', 'hello'),
+        ('HELLO', 'hello'),
+        ('Hello', 'Hello'),
+        ('hElLo', 'hello')
+        """)
+
+      // Instr is case-sensitive
+      checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1")
+    }
+  }
+
+  test("instr function - in filter clause") {
+    withTable("t1") {
+      sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
+      sql("""
+        INSERT INTO t1 VALUES
+        ('hello world', 'world'),
+        ('hello', 'world'),
+        ('testing', 'test'),
+        ('abc', 'def')
+        """)
+
+      // Test instr in WHERE clause
+      checkSparkAnswerAndOperator("""
+        SELECT str FROM t1 WHERE instr(str, substr) > 0
+        """)
+    }
+  }
+
+  test("instr function - with grouping") {
+    withTable("t1") {
+      sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
+      sql("""
+        INSERT INTO t1 VALUES
+        ('test1', 'test'),
+        ('test2', 'test'),
+        ('hello', 'world'),
+        ('testing', 'test')
+        """)
+
+      // Test instr in GROUP BY
+      checkSparkAnswerAndOperator("""
+        SELECT substr, COUNT(*) as cnt
+        FROM t1
+        WHERE instr(str, substr) > 0
+        GROUP BY substr
+        ORDER BY substr
+        """)
+    }
+  }
 }
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 6d2ba759f..26c620f4a 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -961,6 +961,8 @@ object NativeConverters extends
Logging { case e: Levenshtein => buildScalarFunction(pb.ScalarFunction.Levenshtein, e.children, e.dataType) + case e: StringInstr => + buildExtScalarFunction("Spark_Instr", e.children, e.dataType) case e: Hour if datetimeExtractEnabled => buildTimePartExt("Spark_Hour", e.children.head, isPruningExpr, fallback)