From a876716fb4b0bb6696519b312fbb5b28129e8120 Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 16:39:46 +0800 Subject: [PATCH 1/7] [AURON #2067] Implement native function of instr --- .../apache/spark/sql/AuronInstrSuite.scala | 130 +++++++++++++++++ native-engine/auron/src/exec.rs | 9 +- .../datafusion-ext-functions/src/lib.rs | 2 + .../src/spark_instr.rs | 136 ++++++++++++++++++ .../spark/sql/auron/NativeConverters.scala | 2 + 5 files changed, 278 insertions(+), 1 deletion(-) create mode 100644 auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala create mode 100644 native-engine/datafusion-ext-functions/src/spark_instr.rs diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala new file mode 100644 index 000000000..cc62e5b68 --- /dev/null +++ b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +class AuronInstrSuite extends QueryTest with SparkQueryTestsBase { + + test("test instr function - basic functionality") { + val data = Seq( + ("hello world", "world"), + ("hello world", "hello"), + ("hello world", "o"), + ("hello world", "z"), + (null, "test"), + ("test", null) + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 7, "instr('hello world', 'world') should return 7") + assert(result(1) == 1, "instr('hello world', 'hello') should return 1") + assert(result(2) == 5, "instr('hello world', 'o') should return 5") + assert(result(3) == 0, "instr('hello world', 'z') should return 0") + assert(result(4) == 0, "instr(null, 'test') should return null") + assert(result(5) == 0, "instr('test', null) should return null") + } + + test("test instr function - multiple occurrences") { + val data = Seq( + ("banana", "a"), + ("testtesttest", "test"), + ("abcabcabc", "abc") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 2, "instr('banana', 'a') should return 2") + assert(result(1) == 1, "instr('testtesttest', 'test') should return 1") + assert(result(2) == 1, "instr('abcabcabc', 'abc') should return 1") + } + + test("test instr function - case sensitive") { + val data = Seq( + ("Hello", "hello"), + ("HELLO", "hello"), + ("Hello", "Hello"), + ("hElLo", "hello") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 0, "instr('Hello', 'hello') should return 0 (case sensitive)") + assert(result(1) == 0, "instr('HELLO', 'hello') should return 0 (case sensitive)") + assert(result(2) == 1, "instr('Hello', 'Hello') should return 1") + assert(result(3) == 0, "instr('hElLo', 'hello') should return 0 (case sensitive)") + } + + test("test instr function - with filter") { + val data = Seq( + ("hello world", "world", 1), + ("hello", "world", 0), + ("hello", "hello", 1), + ("test", "abc", 0) + ) + + val df = spark.createDataFrame(data).toDF("str", "substr", "expected") + val result = df + .filter("instr(str, substr) > 0") + .select("str") + .collect() + .map(_.getString(0)) + + assert(result.length == 2, "Should find 2 matching strings") + assert(result.contains("hello world")) + assert(result.contains("hello")) + } + + test("test instr function - in group by") { + val data = Seq( + ("test1", "test"), + ("test2", "test"), + ("hello", "world"), + ("testing", "test") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df + .groupBy("substr") + .count() + .filter("count > 0") + .orderBy("substr") + .collect() + + assert(result.length >= 1) + } + + test("test instr function - in where clause") { + val data = Seq( + ("hello world", "world"), + ("hello", "world"), + ("testing", "test"), + ("abc", "def") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df + .filter("instr(str, substr) = 1") + .select("str") + .collect() + .map(_.getString(0)) + + assert(result.length >= 1) + } +} diff --git a/native-engine/auron/src/exec.rs b/native-engine/auron/src/exec.rs index fa4fec4af..ee51eba05 100644 --- a/native-engine/auron/src/exec.rs +++ b/native-engine/auron/src/exec.rs @@ -141,9 +141,16 @@ pub extern "system" fn Java_org_apache_auron_jni_JniBridge_finalizeNative( #[allow(non_snake_case)] #[unsafe(no_mangle)] -pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(_: JNIEnv, _: JClass) { +pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(env: JNIEnv, _: JClass) { log::info!("exiting native environment"); if MemManager::initialized() { MemManager::get().dump_status(); } + // Clear Java-side resources to prevent memory leaks + let _ = env.call_static_method( + jni_bridge::JavaClasses::get().cJniBridge.class, + "clearResources", + "()V", + &[] + ); } diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index 1b8a88beb..aa56271f3 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -26,6 +26,7 @@ mod spark_dates; pub mod spark_get_json_object; mod spark_hash; mod spark_initcap; +mod spark_instr; mod spark_isnan; mod spark_make_array; mod spark_make_decimal; @@ -91,6 +92,7 @@ pub fn create_auron_ext_function( Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero) } "Spark_IsNaN" => Arc::new(spark_isnan::spark_isnan), + "Spark_Instr" => Arc::new(spark_instr::spark_instr), _ => df_unimplemented_err!("spark ext function not implemented: {name}")?, }) } diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs new file mode 100644 index 000000000..db1a9b96b --- /dev/null +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, StringArray}; +use datafusion::{ + common::{Result, ScalarValue, cast::as_string_array}, + physical_plan::ColumnarValue, +}; +use datafusion_ext_commons::df_execution_err; + +/// instr(str, substr) - Returns the (1-based) index of the first occurrence of +/// substr in str Compatible with Spark's instr function +/// Returns 0 if substr is not found or if either argument is null +pub fn spark_instr(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + df_execution_err!("instr requires exactly 2 arguments")?; + } + + let string_array = args[0].clone().into_array(1)?; + let substr = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(substr))) if !substr.is_empty() => substr, + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))); + } + _ => df_execution_err!("instr substring only supports non-empty literal string")?, + }; + + let result_array: ArrayRef = Arc::new(Int32Array::from_iter( + as_string_array(&string_array)? + .into_iter() + .map(|s| s.map(|s| s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))), + )); + + Ok(ColumnarValue::Array(result_array)) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::StringArray; + use datafusion::{ + common::{Result, ScalarValue, cast::as_int32_array}, + physical_plan::ColumnarValue, + }; + + use super::spark_instr; + + #[test] + fn test_spark_instr() -> Result<()> { + // Test basic functionality + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("hello world".to_string()), + Some("abc".to_string()), + Some("abcabc".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from("world")), + ])?; + let s = r.into_array(4)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(7), Some(0), Some(0), None,] + ); + + // Test with empty substring should return 0 + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some( + "hello".to_string(), + )]))), + ColumnarValue::Scalar(ScalarValue::from("")), + ]); + assert!(r.is_err()); + + // Test with null substring + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some( + "hello".to_string(), + )]))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ])?; + if !matches!(r, ColumnarValue::Scalar(ScalarValue::Int32(None))) { + return datafusion::common::internal_err!("Expected null Int32 scalar"); + } + Ok(()) + } + + #[test] + fn test_spark_instr_multiple_matches() -> Result<()> { + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("banana".to_string()), + Some("testtesttest".to_string()), + ]))), + ColumnarValue::Scalar(ScalarValue::from("test")), + ])?; + let s = r.into_array(2)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(0), Some(1),] + ); + Ok(()) + } + + #[test] + fn test_spark_instr_case_sensitive() -> Result<()> { + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("Hello".to_string()), + Some("HELLO".to_string()), + ]))), + ColumnarValue::Scalar(ScalarValue::from("hello")), + ])?; + let s = r.into_array(2)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(0), Some(0),] + ); + Ok(()) + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 6d2ba759f..26c620f4a 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -961,6 +961,8 @@ object NativeConverters extends Logging { case e: Levenshtein => buildScalarFunction(pb.ScalarFunction.Levenshtein, e.children, e.dataType) + case e: StringInstr => + buildExtScalarFunction("Spark_Instr", e.children, e.dataType) case e: Hour if datetimeExtractEnabled => buildTimePartExt("Spark_Hour", e.children.head, isPruningExpr, fallback) From 9a207be96e9bd0f75721dc6afa7927786e55c35f Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 16:42:36 +0800 Subject: [PATCH 2/7] [AURON #2067] Implement native function of instr --- native-engine/auron/src/exec.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/native-engine/auron/src/exec.rs b/native-engine/auron/src/exec.rs index ee51eba05..fa4fec4af 100644 --- a/native-engine/auron/src/exec.rs +++ b/native-engine/auron/src/exec.rs @@ -141,16 +141,9 @@ pub extern "system" fn Java_org_apache_auron_jni_JniBridge_finalizeNative( #[allow(non_snake_case)] #[unsafe(no_mangle)] -pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(env: JNIEnv, _: JClass) { +pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(_: JNIEnv, _: JClass) { log::info!("exiting native environment"); if MemManager::initialized() { MemManager::get().dump_status(); } - // Clear Java-side resources to prevent memory leaks - let _ = env.call_static_method( - jni_bridge::JavaClasses::get().cJniBridge.class, - "clearResources", - "()V", - &[] - ); } From 9a42e0d78e4921d14b4f6ca19c2ddc4a9ad8d967 Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 17:23:48 +0800 Subject: [PATCH 3/7] fix tests --- .../src/spark_instr.rs | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs index db1a9b96b..34bf6ae0f 100644 --- a/native-engine/datafusion-ext-functions/src/spark_instr.rs +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -23,8 +23,9 @@ use datafusion::{ use datafusion_ext_commons::df_execution_err; /// instr(str, substr) - Returns the (1-based) index of the first occurrence of -/// substr in str Compatible with Spark's instr function -/// Returns 0 if substr is not found or if either argument is null +/// substr in str. Compatible with Spark's instr function. +/// Returns 0 if substr is not found or if substr is empty. +/// Returns null if str is null. pub fn spark_instr(args: &[ColumnarValue]) -> Result { if args.len() != 2 { df_execution_err!("instr requires exactly 2 arguments")?; @@ -32,18 +33,27 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { let string_array = args[0].clone().into_array(1)?; let substr = match &args[1] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(substr))) if !substr.is_empty() => substr, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(substr))) => substr, ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))); } - _ => df_execution_err!("instr substring only supports non-empty literal string")?, + _ => df_execution_err!("instr substring only supports literal string")?, }; - let result_array: ArrayRef = Arc::new(Int32Array::from_iter( - as_string_array(&string_array)? - .into_iter() - .map(|s| s.map(|s| s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))), - )); + // If substr is empty, return 0 for all non-null strings + let result_array: ArrayRef = if substr.is_empty() { + Arc::new(Int32Array::from_iter( + as_string_array(&string_array)? + .into_iter() + .map(|s| s.map(|_| 0)), + )) + } else { + Arc::new(Int32Array::from_iter( + as_string_array(&string_array)? + .into_iter() + .map(|s| s.map(|s| s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))), + )) + }; Ok(ColumnarValue::Array(result_array)) } @@ -80,12 +90,18 @@ mod test { // Test with empty substring should return 0 let r = spark_instr(&vec![ - ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some( - "hello".to_string(), - )]))), + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("hello".to_string()), + Some("world".to_string()), + None, + ]))), ColumnarValue::Scalar(ScalarValue::from("")), - ]); - assert!(r.is_err()); + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(0), Some(0), None,] + ); // Test with null substring let r = spark_instr(&vec![ From 5eeaf067c69b2085caf22abcf0db2f027934e721 Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 18:01:47 +0800 Subject: [PATCH 4/7] fix tests --- .../src/spark_instr.rs | 110 +++++++++++++----- 1 file changed, 83 insertions(+), 27 deletions(-) diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs index 34bf6ae0f..fbdc4a01f 100644 --- a/native-engine/datafusion-ext-functions/src/spark_instr.rs +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -17,7 +17,10 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int32Array, StringArray}; use datafusion::{ - common::{Result, ScalarValue, cast::as_string_array}, + common::{ + Result, ScalarValue, + cast::{as_int32_array, as_string_array}, + }, physical_plan::ColumnarValue, }; use datafusion_ext_commons::df_execution_err; @@ -25,37 +28,58 @@ use datafusion_ext_commons::df_execution_err; /// instr(str, substr) - Returns the (1-based) index of the first occurrence of /// substr in str. Compatible with Spark's instr function. /// Returns 0 if substr is not found or if substr is empty. -/// Returns null if str is null. +/// Returns null if str is null or substr is null. pub fn spark_instr(args: &[ColumnarValue]) -> Result { if args.len() != 2 { df_execution_err!("instr requires exactly 2 arguments")?; } - let string_array = args[0].clone().into_array(1)?; - let substr = match &args[1] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(substr))) => substr, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))); - } - _ => df_execution_err!("instr substring only supports literal string")?, + // Ensure both arguments are arrays for element-wise comparison + let left: ArrayRef = match &args[0] { + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?, + ColumnarValue::Array(array) => array.clone(), }; - // If substr is empty, return 0 for all non-null strings - let result_array: ArrayRef = if substr.is_empty() { - Arc::new(Int32Array::from_iter( - as_string_array(&string_array)? - .into_iter() - .map(|s| s.map(|_| 0)), - )) - } else { - Arc::new(Int32Array::from_iter( - as_string_array(&string_array)? - .into_iter() - .map(|s| s.map(|s| s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))), - )) + let right: ArrayRef = match &args[1] { + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(left.len())?, + ColumnarValue::Array(array) => array.clone(), }; - Ok(ColumnarValue::Array(result_array)) + let str_array = as_string_array(&left)?; + let substr_array = as_string_array(&right)?; + + // Perform element-wise instr operation + let result_array: ArrayRef = Arc::new(Int32Array::from_iter( + str_array + .into_iter() + .zip(substr_array.into_iter()) + .map(|(s, substr)| { + match (s, substr) { + (Some(_), None) => None, // substr is null + (None, _) => None, // str is null + (Some(s), Some(substr)) => { + // Empty substr returns 0 + if substr.is_empty() { + Some(0) + } else { + Some(s.find(&substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) + } + } + } + }), + )); + + // If both inputs were scalars, return a scalar + if matches!(args[0], ColumnarValue::Scalar(_)) && matches!(args[1], ColumnarValue::Scalar(_)) { + let scalar = as_int32_array(&result_array)?.value(0); + Ok(ColumnarValue::Scalar(if result_array.is_null(0) { + ScalarValue::Int32(None) + } else { + ScalarValue::Int32(Some(scalar)) + })) + } else { + Ok(ColumnarValue::Array(result_array)) + } } #[cfg(test)] @@ -72,7 +96,7 @@ mod test { #[test] fn test_spark_instr() -> Result<()> { - // Test basic functionality + // Test basic functionality with scalar substring let r = spark_instr(&vec![ ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ Some("hello world".to_string()), @@ -110,9 +134,41 @@ mod test { )]))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ])?; - if !matches!(r, ColumnarValue::Scalar(ScalarValue::Int32(None))) { - return datafusion::common::internal_err!("Expected null Int32 scalar"); - } + let s = r.into_array(1)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![None,] + ); + + // Test with array substring (element-wise) + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("hello world".to_string()), + Some("hello".to_string()), + Some("test".to_string()), + ]))), + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("world".to_string()), + Some("test".to_string()), + Some("test".to_string()), + ]))), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(7), Some(0), Some(1),] + ); + + // Test with both scalars + let r = spark_instr(&vec![ + ColumnarValue::Scalar(ScalarValue::from("hello world")), + ColumnarValue::Scalar(ScalarValue::from("world")), + ])?; + assert!(matches!( + r, + ColumnarValue::Scalar(ScalarValue::Int32(Some(7))) + )); + Ok(()) } From a0da01d9faedfd5e5a7eaa3f542b5ce94cbd9c46 Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 19:07:08 +0800 Subject: [PATCH 5/7] fix styles --- .../src/spark_instr.rs | 62 ++++++++++--------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs index fbdc4a01f..69a74b59b 100644 --- a/native-engine/datafusion-ext-functions/src/spark_instr.rs +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -34,43 +34,49 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { df_execution_err!("instr requires exactly 2 arguments")?; } - // Ensure both arguments are arrays for element-wise comparison - let left: ArrayRef = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?, - ColumnarValue::Array(array) => array.clone(), - }; - - let right: ArrayRef = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(left.len())?, - ColumnarValue::Array(array) => array.clone(), - }; - - let str_array = as_string_array(&left)?; - let substr_array = as_string_array(&right)?; + let is_scalar = args + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + let len = args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.len(), + ColumnarValue::Scalar(_) => 1, + }) + .max() + .unwrap_or(0); + + let arrays = args + .iter() + .map(|arg| { + Ok(match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len)?, + }) + }) + .collect::>>()?; + + let str_array = as_string_array(&arrays[0])?; + let substr_array = as_string_array(&arrays[1])?; - // Perform element-wise instr operation let result_array: ArrayRef = Arc::new(Int32Array::from_iter( str_array .into_iter() .zip(substr_array.into_iter()) - .map(|(s, substr)| { - match (s, substr) { - (Some(_), None) => None, // substr is null - (None, _) => None, // str is null - (Some(s), Some(substr)) => { - // Empty substr returns 0 - if substr.is_empty() { - Some(0) - } else { - Some(s.find(&substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) - } + .map(|(s, substr)| match (s, substr) { + (Some(_), None) => None, // substr is null + (None, _) => None, // str is null + (Some(s), Some(substr)) => { + if substr.is_empty() { + Some(0) + } else { + Some(s.find(&substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) } } }), )); - // If both inputs were scalars, return a scalar - if matches!(args[0], ColumnarValue::Scalar(_)) && matches!(args[1], ColumnarValue::Scalar(_)) { + if is_scalar { let scalar = as_int32_array(&result_array)?.value(0); Ok(ColumnarValue::Scalar(if result_array.is_null(0) { ScalarValue::Int32(None) @@ -86,7 +92,7 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { mod test { use std::sync::Arc; - use arrow::array::StringArray; + use arrow::array::{ArrayRef, Int32Array, StringArray}; use datafusion::{ common::{Result, ScalarValue, cast::as_int32_array}, physical_plan::ColumnarValue, From 8fa388cf81dfabb6e797485d396f2395f0b70c90 Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 19:11:31 +0800 Subject: [PATCH 6/7] fix styles --- native-engine/datafusion-ext-functions/src/spark_instr.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs index 69a74b59b..b970cc38f 100644 --- a/native-engine/datafusion-ext-functions/src/spark_instr.rs +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -61,8 +61,8 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { let result_array: ArrayRef = Arc::new(Int32Array::from_iter( str_array - .into_iter() - .zip(substr_array.into_iter()) + .iter() + .zip(substr_array.iter()) .map(|(s, substr)| match (s, substr) { (Some(_), None) => None, // substr is null (None, _) => None, // str is null @@ -70,7 +70,7 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { if substr.is_empty() { Some(0) } else { - Some(s.find(&substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) + Some(s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) } } }), From 094d8244ee323d5983a88a948c2c7a14970b628e Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 8 Apr 2026 21:29:43 +0800 Subject: [PATCH 7/7] Fix --- .../apache/spark/sql/AuronInstrSuite.scala | 19 +-- .../src/spark_instr.rs | 64 +++++++++- .../org/apache/auron/AuronFunctionSuite.scala | 111 ++++++++++++++++++ 3 files changed, 185 insertions(+), 9 deletions(-) diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala index cc62e5b68..78d0e3a98 100644 --- a/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala +++ b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala @@ -29,14 +29,17 @@ class AuronInstrSuite extends QueryTest with SparkQueryTestsBase { ) val df = spark.createDataFrame(data).toDF("str", "substr") - val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) - - assert(result(0) == 7, "instr('hello world', 'world') should return 7") - assert(result(1) == 1, "instr('hello world', 'hello') should return 1") - assert(result(2) == 5, "instr('hello world', 'o') should return 5") - assert(result(3) == 0, "instr('hello world', 'z') should return 0") - assert(result(4) == 0, "instr(null, 'test') should return null") - assert(result(5) == 0, "instr('test', null) should return null") + val rows = df.selectExpr("instr(str, substr)").collect() + + // Check non-null results + assert(rows(0).getInt(0) == 7, "instr('hello world', 'world') should return 7") + assert(rows(1).getInt(0) == 1, "instr('hello world', 'hello') should return 1") + assert(rows(2).getInt(0) == 5, "instr('hello world', 'o') should return 5") + assert(rows(3).getInt(0) == 0, "instr('hello world', 'z') should return 0") + + // Check null results + assert(rows(4).isNullAt(0), "instr(null, 'test') should return null") + assert(rows(5).isNullAt(0), "instr('test', null) should return null") } test("test instr function - multiple occurrences") { diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs index b970cc38f..f0c3d2d3f 100644 --- a/native-engine/datafusion-ext-functions/src/spark_instr.rs +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -70,7 +70,7 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { if substr.is_empty() { Some(0) } else { - Some(s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) + Some(find_char_position(s, substr)) } } }), @@ -88,6 +88,33 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { } } +/// Find the 1-based character position of substr in s +/// Returns 0 if not found +fn find_char_position(s: &str, substr: &str) -> i32 { + if substr.is_empty() { + return 0; + } + + // Use char_indices to get byte offset to char position mapping + let char_positions: Vec = s.char_indices().map(|(byte_pos, _)| byte_pos).collect(); + + // Find byte offset using find + if let Some(byte_pos) = s.find(substr) { + // Find the character position (1-based) + // char_positions contains the byte offset for each character + // We need to find which character index corresponds to this byte offset + for (char_idx, &char_byte_pos) in char_positions.iter().enumerate() { + if char_byte_pos == byte_pos { + return (char_idx + 1) as i32; + } + } + // Fallback: if exact match not found, estimate + char_positions.len() as i32 + 1 + } else { + 0 + } +} + #[cfg(test)] mod test { use std::sync::Arc; @@ -211,4 +238,39 @@ mod test { ); Ok(()) } + + #[test] + fn test_spark_instr_utf8() -> Result<()> { + // Test UTF-8 multi-byte characters + // "你好世界" - "世界" should return 3 (character position), not 6 (byte + // position) + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("你好世界".to_string()), + Some("hello世界".to_string()), + Some("test".to_string()), + ]))), + ColumnarValue::Scalar(ScalarValue::from("世界")), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(3), Some(6), Some(0),] + ); + + // Test with emoji (4-byte UTF-8) + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some( + "hello😀world".to_string(), + )]))), + ColumnarValue::Scalar(ScalarValue::from("😀")), + ])?; + let s = r.into_array(1)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(6),] + ); + + Ok(()) + } } diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala index ef0983a58..5b26c9838 100644 --- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala @@ -868,6 +868,117 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite { |""".stripMargin checkSparkAnswerAndOperator(query) } + + test("instr function - basic functionality") { + withTable("t1") { + sql(""" + CREATE TABLE t1(str STRING, substr STRING) USING parquet + """) + sql(""" + INSERT INTO t1 VALUES + ('hello world', 'world'), + ('hello world', 'hello'), + ('hello world', 'o'), + ('hello world', 'z'), + (null, 'test'), + ('test', null) + """) + + // Test basic instr functionality + checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1") + } + } + + test("instr function - empty substring") { + withTable("t1") { + sql("CREATE TABLE t1(str STRING) USING parquet") + sql("INSERT INTO t1 VALUES ('hello'), ('world'), ('')") + + // Empty substring should return 0 + checkSparkAnswerAndOperator("SELECT instr(str, '') FROM t1") + } + } + + test("instr function - UTF-8 multi-byte characters") { + withTable("t1") { + sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet") + sql(""" + INSERT INTO t1 VALUES + ('你好世界', '世界'), + ('hello世界', '世界'), + ('test', '世界'), + ('hello😀world', '😀'), + ('test😀', '😀') + """) + + // Test UTF-8 character position (not byte position) + checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1") + } + } + + test("instr function - with expressions") { + withTable("t1") { + sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet") + sql("INSERT INTO t1 VALUES ('banana', 'a'), ('testtesttest', 'test'), ('abcabcabc', 'abc')") + + // Test with array column as substring (element-wise) + checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1") + } + } + + test("instr function - case sensitivity") { + withTable("t1") { + sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet") + sql(""" + INSERT INTO t1 VALUES + ('Hello', 'hello'), + ('HELLO', 'hello'), + ('Hello', 'Hello'), + ('hElLo', 'hello') + """) + + // Instr is case-sensitive + checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1") + } + } + + test("instr function - in filter clause") { + withTable("t1") { + sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet") + sql(""" + INSERT INTO t1 VALUES + ('hello world', 'world'), + ('hello', 'world'), + ('testing', 'test'), + ('abc', 'def') + """) + + // Test instr in WHERE clause + checkSparkAnswerAndOperator(""" + SELECT str FROM t1 WHERE instr(str, substr) > 0 + """) + } + } + + test("instr function - with grouping") { + withTable("t1") { + sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet") + sql(""" + INSERT INTO t1 VALUES + ('test1', 'test'), + ('test2', 'test'), + ('hello', 'world'), + ('testing', 'test') + """) + + // Test instr in GROUP BY + checkSparkAnswerAndOperator(""" + SELECT substr, COUNT(*) as cnt + FROM t1 + WHERE instr(str, substr) > 0 + GROUP BY substr + ORDER BY substr + """) } } }