From cb78da6df2935f25225e741d54311b2be9f211a6 Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Sun, 8 Feb 2026 13:17:38 +0800
Subject: [PATCH 1/4] [SPARK-55411][SQL] SPJ may throw ArrayIndexOutOfBoundsException when join keys are less than cluster keys

---
 .../connector/catalog/InMemoryBaseTable.scala | 29 +++++++++----
 .../execution/KeyGroupedPartitionedScan.scala |  8 ++--
 .../KeyGroupedPartitioningSuite.scala         | 42 +++++++++++++++++--
 .../functions/transformFunctions.scala        |  1 +
 4 files changed, 65 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index ebb4eef80f15f..525a29bbc67aa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -236,15 +236,28 @@ abstract class InMemoryBaseTable(
         case (v, t) =>
           throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
       }
+      // keep the logic consistent with BucketFunction defined in transformFunctions.scala
       case BucketTransform(numBuckets, cols, _) =>
-        val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
-        var valueHashCode = 0
-        valueTypePairs.foreach( pair =>
-          if ( pair._1 != null) valueHashCode += pair._1.hashCode()
-        )
-        var dataTypeHashCode = 0
-        valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
-        ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
+        if (cols.length != 1) {
+          throw new IllegalArgumentException(
+            s"Match: bucket transform only supports 1 argument, but got ${cols.length}")
+        }
+        extractor(cols.head.fieldNames, cleanedSchema, row) match {
+          case (value: Byte, _: ByteType) =>
+            (value.toLong % numBuckets).toInt
+          case (value: Short, _: ShortType) =>
+            (value.toLong % numBuckets).toInt
+          case (value: Int, _: IntegerType) =>
+            (value.toLong % numBuckets).toInt
+          case (value: Long, _: LongType) =>
+            (value % numBuckets).toInt
+          case (value: Long, _: TimestampType) =>
+            (value % numBuckets).toInt
+          case (value: Long, _: TimestampNTZType) =>
+            (value % numBuckets).toInt
+          case (v, t) =>
+            throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+        }
       case NamedTransform("truncate", Seq(ref: NamedReference, length: V2Literal[_])) =>
         extractor(ref.fieldNames, cleanedSchema, row) match {
           case (str: UTF8String, StringType) =>
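The hunk above replaces the hash-mixing bucket id in the test table with a plain value-based one for a single bucket column. Tests that shuffle one join side bucket the shuffled rows with the catalog's BucketFunction (also touched in this patch), so the two implementations have to produce the same bucket for the same key. A minimal, self-contained sketch of the before/after shapes, with illustrative names and an arbitrary type-hash value (not code from the patch):

  object BucketTransformSketch {
    // Shape of the removed logic: mixes the value's and the data type's hash codes.
    def oldBucket(value: Any, dataTypeHash: Int, numBuckets: Int): Int = {
      val valueHash = if (value != null) value.hashCode() else 0
      ((valueHash + 31 * dataTypeHash) & Integer.MAX_VALUE) % numBuckets
    }

    // Shape of the added logic for a single integral column: just value % numBuckets,
    // which is also what BucketFunction.produceResult returns for a long key.
    def newBucket(value: Long, numBuckets: Int): Int = (value % numBuckets).toInt

    def main(args: Array[String]): Unit = {
      println(oldBucket(31L, 1234, 32)) // hash-derived bucket, unrelated to the key value
      println(newBucket(31L, 32))       // 31
    }
  }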
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
index 10a6aaa2e1851..cac4a9bc852f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
@@ -34,7 +34,7 @@ trait KeyGroupedPartitionedScan[T] {
   def getOutputKeyGroupedPartitioning(
       basePartitioning: KeyGroupedPartitioning,
       spjParams: StoragePartitionJoinParams): KeyGroupedPartitioning = {
-    val expressions = spjParams.joinKeyPositions match {
+    val projectedExpressions = spjParams.joinKeyPositions match {
       case Some(projectionPositions) =>
         projectionPositions.map(i => basePartitioning.expressions(i))
       case _ => basePartitioning.expressions
@@ -52,16 +52,16 @@
          case Some(projectionPositions) =>
            val internalRowComparableWrapperFactory =
              InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
-                expressions.map(_.dataType))
+                projectedExpressions.map(_.dataType))
            basePartitioning.partitionValues.map { r =>
-              val projectedRow = KeyGroupedPartitioning.project(expressions,
+              val projectedRow = KeyGroupedPartitioning.project(basePartitioning.expressions,
                projectionPositions, r)
              internalRowComparableWrapperFactory(projectedRow)
            }.distinct.map(_.row)
          case _ => basePartitioning.partitionValues
        }
    }
-    basePartitioning.copy(expressions = expressions, numPartitions = newPartValues.length,
+    basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
      partitionValues = newPartValues)
  }
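The rename above is what exposes the fix: joinKeyPositions index into the table's full list of cluster-key expressions, so the row projection must be driven by basePartitioning.expressions rather than the already-projected list. A simplified, self-contained illustration of the failure mode; the `project` helper here is only a stand-in for KeyGroupedPartitioning.project:

  object ProjectionSketch {
    // Pick the entries at the given positions, as the planner does for join keys.
    def project[T](exprs: Seq[T], positions: Seq[Int]): Seq[T] = positions.map(exprs)

    def main(args: Array[String]): Unit = {
      // Cluster keys: identity(customer_name), bucket(4, customer_id); join key: customer_id only.
      val baseExpressions = Seq("identity(customer_name)", "bucket(4, customer_id)")
      val joinKeyPositions = Seq(1)
      println(project(baseExpressions, joinKeyPositions)) // List(bucket(4, customer_id))
      // Indexing the already-projected, single-element list with position 1 is what used
      // to fail with an out-of-bounds error:
      // project(project(baseExpressions, joinKeyPositions), joinKeyPositions)
    }
  }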
Row("bbb", 20, 150.0), + Row("bbb", 20, 250.0), + Row("bbb", 20, 350.0), + Row("ccc", 30, 400.50))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index b82cc2392e1fc..f213f1d27bbf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -84,6 +84,7 @@ object UnboundBucketFunction extends UnboundFunction { override def name(): String = "bucket" } +// keep the logic consistent with BucketTransform defined at InMemoryBaseTable.scala object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) override def resultType(): DataType = IntegerType From beade2bd6f9cddc288bd345675584300bc76418a Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Sun, 8 Feb 2026 16:10:09 +0800 Subject: [PATCH 2/4] fix test related to bucket calc change --- .../connector/catalog/InMemoryBaseTable.scala | 38 +++++++++---------- .../sql/connector/MetadataColumnSuite.scala | 18 ++++----- .../functions/transformFunctions.scala | 2 +- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 525a29bbc67aa..63690b11e1d0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -236,28 +236,26 @@ abstract class InMemoryBaseTable( case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } - // keep the logic consistent with BucketFunctions defined at transformFunctions.scala + // the result should be consistent with BucketFunctions defined at transformFunctions.scala case BucketTransform(numBuckets, cols, _) => - if (cols.length != 1) { - throw new IllegalArgumentException( - s"Match: bucket transform only supports 1 argument, but got ${cols.length}") - } - extractor(cols.head.fieldNames, cleanedSchema, row) match { - case (value: Byte, _: ByteType) => - (value.toLong % numBuckets).toInt - case (value: Short, _: ShortType) => - (value.toLong % numBuckets).toInt - case (value: Int, _: IntegerType) => - (value.toLong % numBuckets).toInt - case (value: Long, _: LongType) => - (value % numBuckets).toInt - case (value: Long, _: TimestampType) => - (value % numBuckets).toInt - case (value: Long, _: TimestampNTZType) => - (value % numBuckets).toInt - case (v, t) => - throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") + val hash: Long = cols.foldLeft(0L) { (acc, col) => + val valueHash = extractor(col.fieldNames, cleanedSchema, row) match { + case (value: Byte, _: ByteType) => value.toLong + case (value: Short, _: ShortType) => value.toLong + case (value: Int, _: IntegerType) => value.toLong + case (value: Long, _: LongType) => value + case (value: Long, _: TimestampType) => value + case (value: Long, _: TimestampNTZType) => value + case (value: UTF8String, _: StringType) => + value.hashCode.toLong + case (value: Array[Byte], BinaryType) => + util.Arrays.hashCode(value).toLong + case (v, t) => + 
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") + } + (acc + valueHash) & 0xFFFFFFFFFFFFL } + (hash % numBuckets).toInt case NamedTransform("truncate", Seq(ref: NamedReference, length: V2Literal[_])) => extractor(ref.fieldNames, cleanedSchema, row) match { case (str: UTF8String, StringType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala index 3bfd57e867c07..fe338175ec888 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala @@ -42,7 +42,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { val dfQuery = spark.table(tbl).select("id", "data", "index", "_partition") Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))) } } } @@ -56,7 +56,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { val dfQuery = spark.table(tbl).select("index", "data", "_partition") Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1"))) + checkAnswer(query, Seq(Row(3, "c", "3/3"), Row(2, "b", "2/2"), Row(1, "a", "1/1"))) } } } @@ -125,7 +125,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { checkAnswer( dfQuery, - Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")) + Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")) ) } } @@ -135,7 +135,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { prepareTable() checkAnswer( spark.table(tbl).select("id", "data").select("index", "_partition"), - Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3")) + Seq(Row(0, "1/1"), Row(0, "2/2"), Row(0, "3/3")) ) } } @@ -160,7 +160,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", "index", "_partition") Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + checkAnswer(query, Seq(Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))) } } } @@ -172,7 +172,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", "index", "_partition") Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))) } } } @@ -186,7 +186,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { .select("id", "data", "index", "_partition") Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))) } } } @@ -201,7 +201,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition") Seq(sqlQuery, dfQuery).foreach { query => - checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))) + checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))) } // Metadata columns are propagated through 
From 55bbd701ba8ba4c060309679dd12a54c26908a1c Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Sun, 8 Feb 2026 16:33:30 +0800
Subject: [PATCH 3/4] ensure positive

---
 .../sql/connector/catalog/functions/transformFunctions.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index f0370c4b96eae..b23d2005fabdf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -92,7 +92,7 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In
   override def canonicalName(): String = name()
   override def toString: String = name()
   override def produceResult(input: InternalRow): Int = {
-    (input.getLong(1) % input.getInt(0)).toInt
+    ((input.getLong(1) & 0xFFFFFFFFFFFFL) % input.getInt(0)).toInt
   }

   override def reducer(
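This one-line change, and the follow-up commit below, deal with the sign of the bucket id: the JVM `%` operator keeps the sign of the dividend, so a negative key would produce a negative bucket. A standalone illustration of the three variants that appear in this series (not Spark code):

  object NegativeBucketSketch {
    def main(args: Array[String]): Unit = {
      val value = -7L
      val numBuckets = 4
      println((value % numBuckets).toInt)                     // -3: plain remainder can be negative
      println(((value & 0xFFFFFFFFFFFFL) % numBuckets).toInt) // non-negative after masking to 48 bits
      println(Math.floorMod(value, numBuckets))               // 1: always in [0, numBuckets)
    }
  }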
argument(s) type - ($v, $t)") } - (acc + valueHash) & 0xFFFFFFFFFFFFL + acc + valueHash } - (hash % numBuckets).toInt + Math.floorMod(hash, numBuckets) case NamedTransform("truncate", Seq(ref: NamedReference, length: V2Literal[_])) => extractor(ref.fieldNames, cleanedSchema, row) match { case (str: UTF8String, StringType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index b23d2005fabdf..ed2f81d7e8d6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -92,7 +92,7 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In override def canonicalName(): String = name() override def toString: String = name() override def produceResult(input: InternalRow): Int = { - ((input.getLong(1) & 0xFFFFFFFFFFFFL) % input.getInt(0)).toInt + Math.floorMod(input.getLong(1), input.getInt(0)) } override def reducer(