Skip to content

Commit 094d824

Browse files
committed
Fix
1 parent 8fa388c commit 094d824

3 files changed

Lines changed: 185 additions & 9 deletions

File tree

auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,17 @@ class AuronInstrSuite extends QueryTest with SparkQueryTestsBase {
2929
)
3030

3131
val df = spark.createDataFrame(data).toDF("str", "substr")
32-
val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0))
33-
34-
assert(result(0) == 7, "instr('hello world', 'world') should return 7")
35-
assert(result(1) == 1, "instr('hello world', 'hello') should return 1")
36-
assert(result(2) == 5, "instr('hello world', 'o') should return 5")
37-
assert(result(3) == 0, "instr('hello world', 'z') should return 0")
38-
assert(result(4) == 0, "instr(null, 'test') should return null")
39-
assert(result(5) == 0, "instr('test', null) should return null")
32+
val rows = df.selectExpr("instr(str, substr)").collect()
33+
34+
// Check non-null results
35+
assert(rows(0).getInt(0) == 7, "instr('hello world', 'world') should return 7")
36+
assert(rows(1).getInt(0) == 1, "instr('hello world', 'hello') should return 1")
37+
assert(rows(2).getInt(0) == 5, "instr('hello world', 'o') should return 5")
38+
assert(rows(3).getInt(0) == 0, "instr('hello world', 'z') should return 0")
39+
40+
// Check null results
41+
assert(rows(4).isNullAt(0), "instr(null, 'test') should return null")
42+
assert(rows(5).isNullAt(0), "instr('test', null) should return null")
4043
}
4144

4245
test("test instr function - multiple occurrences") {

native-engine/datafusion-ext-functions/src/spark_instr.rs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
7070
if substr.is_empty() {
7171
Some(0)
7272
} else {
73-
Some(s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))
73+
Some(find_char_position(s, substr))
7474
}
7575
}
7676
}),
@@ -88,6 +88,33 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
8888
}
8989
}
9090

91+
/// Return the 1-based *character* position of the first occurrence of
/// `substr` in `s`, or 0 when `substr` does not occur.
///
/// Spark's `instr` counts characters, not bytes, so the byte offset
/// returned by `str::find` must be converted: every char in the prefix
/// before the match counts as one position, regardless of its UTF-8
/// byte width.
///
/// An empty `substr` yields 0 (the caller also short-circuits this case).
fn find_char_position(s: &str, substr: &str) -> i32 {
    if substr.is_empty() {
        return 0;
    }
    match s.find(substr) {
        // `find` always returns an offset lying on a char boundary, so
        // slicing the prefix cannot panic; counting its chars gives the
        // 0-based character index of the match, hence `+ 1` for 1-based.
        Some(byte_pos) => s[..byte_pos].chars().count() as i32 + 1,
        None => 0,
    }
}
117+
91118
#[cfg(test)]
92119
mod test {
93120
use std::sync::Arc;
@@ -211,4 +238,39 @@ mod test {
211238
);
212239
Ok(())
213240
}
241+
242+
#[test]
fn test_spark_instr_utf8() -> Result<()> {
    // Positions are counted in characters, not bytes:
    // "你好世界".find("世界") is byte offset 6, yet instr must report 3.
    let haystacks = ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
        Some("你好世界".to_string()),
        Some("hello世界".to_string()),
        Some("test".to_string()),
    ])));
    let needle = ColumnarValue::Scalar(ScalarValue::from("世界"));
    let result = spark_instr(&vec![haystacks, needle])?.into_array(3)?;
    assert_eq!(
        as_int32_array(&result)?.into_iter().collect::<Vec<_>>(),
        vec![Some(3), Some(6), Some(0)]
    );

    // A 4-byte emoji still counts as exactly one character position.
    let haystacks = ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some(
        "hello😀world".to_string(),
    )])));
    let needle = ColumnarValue::Scalar(ScalarValue::from("😀"));
    let result = spark_instr(&vec![haystacks, needle])?.into_array(1)?;
    assert_eq!(
        as_int32_array(&result)?.into_iter().collect::<Vec<_>>(),
        vec![Some(6)]
    );

    Ok(())
}
214276
}

spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,117 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite {
868868
|""".stripMargin
869869
checkSparkAnswerAndOperator(query)
870870
}
871+
872+
test("instr function - basic functionality") {
  withTable("t1") {
    val createTable = """
CREATE TABLE t1(str STRING, substr STRING) USING parquet
"""
    val insertRows = """
INSERT INTO t1 VALUES
('hello world', 'world'),
('hello world', 'hello'),
('hello world', 'o'),
('hello world', 'z'),
(null, 'test'),
('test', null)
"""
    sql(createTable)
    sql(insertRows)

    // instr results (including null propagation for null operands) must
    // match vanilla Spark and run on the native operator.
    checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1")
  }
}
891+
892+
test("instr function - empty substring") {
  withTable("t1") {
    val query = "SELECT instr(str, '') FROM t1"
    sql("CREATE TABLE t1(str STRING) USING parquet")
    sql("INSERT INTO t1 VALUES ('hello'), ('world'), ('')")
    // An empty search string is expected to yield 0; the native result
    // must agree with vanilla Spark's answer.
    checkSparkAnswerAndOperator(query)
  }
}
901+
902+
test("instr function - UTF-8 multi-byte characters") {
  withTable("t1") {
    sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
    val insertRows = """
INSERT INTO t1 VALUES
('你好世界', '世界'),
('hello世界', '世界'),
('test', '世界'),
('hello😀world', '😀'),
('test😀', '😀')
"""
    sql(insertRows)

    // Positions are counted in characters, not bytes, so multi-byte
    // UTF-8 sequences (CJK, emoji) must each count as a single position.
    checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1")
  }
}
918+
919+
test("instr function - with expressions") {
  withTable("t1") {
    sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
    sql("INSERT INTO t1 VALUES ('banana', 'a'), ('testtesttest', 'test'), ('abcabcabc', 'abc')")

    // Both operands are STRING columns and each needle occurs several
    // times; instr must report only the first occurrence per row.
    checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1")
  }
}
928+
929+
test("instr function - case sensitivity") {
  withTable("t1") {
    sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
    val insertRows = """
INSERT INTO t1 VALUES
('Hello', 'hello'),
('HELLO', 'hello'),
('Hello', 'Hello'),
('hElLo', 'hello')
"""
    sql(insertRows)

    // instr performs a case-sensitive match; only the exact-case row hits.
    checkSparkAnswerAndOperator("SELECT instr(str, substr) FROM t1")
  }
}
944+
945+
test("instr function - in filter clause") {
  withTable("t1") {
    sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
    val insertRows = """
INSERT INTO t1 VALUES
('hello world', 'world'),
('hello', 'world'),
('testing', 'test'),
('abc', 'def')
"""
    sql(insertRows)

    // instr used as a predicate (WHERE) rather than a projection.
    checkSparkAnswerAndOperator("""
SELECT str FROM t1 WHERE instr(str, substr) > 0
""")
  }
}
962+
963+
test("instr function - with grouping") {
  withTable("t1") {
    sql("CREATE TABLE t1(str STRING, substr STRING) USING parquet")
    val insertRows = """
INSERT INTO t1 VALUES
('test1', 'test'),
('test2', 'test'),
('hello', 'world'),
('testing', 'test')
"""
    sql(insertRows)

    // instr feeding both the WHERE predicate and a grouped aggregation.
    checkSparkAnswerAndOperator("""
SELECT substr, COUNT(*) as cnt
FROM t1
WHERE instr(str, substr) > 0
GROUP BY substr
ORDER BY substr
""")
  }
}
873984
}

0 commit comments

Comments
 (0)