InMemoryBaseTable.scala
@@ -236,15 +236,26 @@ abstract class InMemoryBaseTable(
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
// the result should be consistent with BucketFunction defined in transformFunctions.scala
case BucketTransform(numBuckets, cols, _) =>
val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
var valueHashCode = 0
valueTypePairs.foreach( pair =>
if ( pair._1 != null) valueHashCode += pair._1.hashCode()
)
var dataTypeHashCode = 0
valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
val hash: Long = cols.foldLeft(0L) { (acc, col) =>
val valueHash = extractor(col.fieldNames, cleanedSchema, row) match {
case (value: Byte, _: ByteType) => value.toLong
case (value: Short, _: ShortType) => value.toLong
case (value: Int, _: IntegerType) => value.toLong
case (value: Long, _: LongType) => value
case (value: Long, _: TimestampType) => value
case (value: Long, _: TimestampNTZType) => value
case (value: UTF8String, _: StringType) =>
value.hashCode.toLong
case (value: Array[Byte], BinaryType) =>
util.Arrays.hashCode(value).toLong
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
acc + valueHash
}
Math.floorMod(hash, numBuckets)
case NamedTransform("truncate", Seq(ref: NamedReference, length: V2Literal[_])) =>
extractor(ref.fieldNames, cleanedSchema, row) match {
case (str: UTF8String, StringType) =>
KeyGroupedPartitionedScan.scala
@@ -34,7 +34,7 @@ trait KeyGroupedPartitionedScan[T] {
def getOutputKeyGroupedPartitioning(
basePartitioning: KeyGroupedPartitioning,
spjParams: StoragePartitionJoinParams): KeyGroupedPartitioning = {
val expressions = spjParams.joinKeyPositions match {
val projectedExpressions = spjParams.joinKeyPositions match {
case Some(projectionPositions) =>
projectionPositions.map(i => basePartitioning.expressions(i))
case _ => basePartitioning.expressions
@@ -52,16 +52,16 @@
case Some(projectionPositions) =>
val internalRowComparableWrapperFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
expressions.map(_.dataType))
projectedExpressions.map(_.dataType))
basePartitioning.partitionValues.map { r =>
val projectedRow = KeyGroupedPartitioning.project(expressions,
val projectedRow = KeyGroupedPartitioning.project(basePartitioning.expressions,
peter-toth (Contributor) commented on Feb 7, 2026:

Actually, the wrong projected expression is the root cause of the ArrayIndexOutOfBoundsException you hit, and passing in basePartitioning.expressions looks good.

But the test you added is unlikely to pass, as there is an issue with the test framework.
I left a note here:

// Do not use `bucket()` in "one side partition" tests as its implementation in
// `InMemoryBaseTable` conflicts with `BucketFunction`
, but I forgot to open a fix for the problem with using bucket() in these one-side shuffle tests.

The problem is that the bucket() implementation here:

override def produceResult(input: InternalRow): Int = {
  (input.getLong(1) % input.getInt(0)).toInt
}

and the one in InMemoryBaseTable:
val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
var valueHashCode = 0
valueTypePairs.foreach( pair =>
if ( pair._1 != null) valueHashCode += pair._1.hashCode()
)
var dataTypeHashCode = 0
valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
do not match.
So, technically, the partition keys that the datasource reports and the calculated keys of the partitions where the partitioner puts the shuffled records don't match.

@pan3793, could you please keep your fix in KeyGroupedPartitionedScan.scala and fix the BucketTransform key calculation in InMemoryBaseTable?
You don't need the other changes. originalPartitionValues seems unrelated, as it is used only when partially clustered distribution is enabled.

BTW, I'm working on refactoring SPJ based on this idea: #53859 (comment) and it looks promising so far, but I need some more days to wrap it up.
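[Editor's note] To make the mismatch concrete, here is a minimal, self-contained sketch (plain Scala, no Spark imports; dataTypeHash is a hypothetical stand-in for LongType.hashCode(), whose concrete value is runtime-dependent):

object BucketMismatchSketch extends App {
  val numBuckets = 32
  val value = 31L

  // BucketFunction.produceResult: value % numBuckets
  val functionBucket = (value % numBuckets).toInt // 31

  // Old InMemoryBaseTable BucketTransform: hash of (value, dataType) pairs
  val dataTypeHash = 12345 // hypothetical stand-in for LongType.hashCode()
  val transformBucket =
    ((value.hashCode() + 31 * dataTypeHash) & Integer.MAX_VALUE) % numBuckets // 6

  // The two bucket ids generally differ, so the partition values the datasource
  // reports do not line up with where the shuffle partitioner places rows.
  println(s"BucketFunction -> $functionBucket, old BucketTransform -> $transformBucket")
}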

pan3793 (Member, Author) replied:

// Do not use bucket() in "one side partition" tests as its implementation in
// InMemoryBaseTable conflicts with BucketFunction

Oh, god, @peter-toth, thanks a lot for pointing this out. I wasn't aware of it and spent a few hours trying to figure out why the SMJ partition key values mismatched and produced wrong results after fixing the ArrayIndexOutOfBoundsException ...

Actually, the current code changes are just a draft; the test cases have not yet passed. I will try to fix it following your guidance. Thank you again, @peter-toth!

projectionPositions, r)
internalRowComparableWrapperFactory(projectedRow)
}.distinct.map(_.row)
case _ => basePartitioning.partitionValues
}
}
basePartitioning.copy(expressions = expressions, numPartitions = newPartValues.length,
basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
partitionValues = newPartValues)
}
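[Editor's note] The indexing bug fixed above can be seen in a toy sketch (hypothetical names and values, plain Scala): joinKeyPositions index into the base expression list, so resolving them against the already-projected list can read out of bounds, which is the reported ArrayIndexOutOfBoundsException.

object ProjectionBugSketch extends App {
  val baseExpressions = Seq("identity(customer_name)", "bucket(4, customer_id)")
  val joinKeyPositions = Seq(1) // the join key is the second cluster key

  // Fixed behavior: positions resolve against the base expressions.
  val projectedExpressions = joinKeyPositions.map(baseExpressions)
  println(projectedExpressions) // List(bucket(4, customer_id))

  // Buggy behavior: resolving the same positions against the projected list
  // (length 1) fails for position 1; uncommenting the next line throws:
  // joinKeyPositions.map(projectedExpressions)
}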

KeyGroupedPartitioningSuite.scala
@@ -124,7 +124,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))

// Has exactly one partition.
val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
val partitionValues = Seq(0).map(v => InternalRow.fromSeq(Seq(v)))
checkQueryPlan(df, distribution,
physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues))
}
@@ -2798,8 +2798,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}

test("SPARK-54439: KeyGroupedPartitioning with transform and join key size mismatch") {
// Do not use `bucket()` in "one side partition" tests as its implementation in
// `InMemoryBaseTable` conflicts with `BucketFunction`
val items_partitions = Array(years("arrive_time"))
createTable(items, itemsColumns, items_partitions)

@@ -2841,4 +2839,42 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
assert(metrics("number of rows read") == "3")
}

test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
"are less than cluster keys") {
withSQLConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {

val customers_partitions = Array(identity("customer_name"), bucket(4, "customer_id"))
createTable(customers, customersColumns, customers_partitions)
sql(s"INSERT INTO testcat.ns.$customers VALUES " +
s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)")

createTable(orders, ordersColumns, Array.empty)
sql(s"INSERT INTO testcat.ns.$orders VALUES " +
s"(100.0, 1), (200.0, 1), (150.0, 2), (250.0, 2), (350.0, 2), (400.50, 3)")

val df = sql(
s"""${selectWithMergeJoinHint("c", "o")}
|customer_name, customer_age, order_amount
|FROM testcat.ns.$customers c JOIN testcat.ns.$orders o
|ON c.customer_id = o.customer_id ORDER BY c.customer_id, order_amount
|""".stripMargin)

val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.length == 1)

checkAnswer(df, Seq(
Row("aaa", 10, 100.0),
Row("aaa", 10, 200.0),
Row("bbb", 20, 150.0),
Row("bbb", 20, 250.0),
Row("bbb", 20, 350.0),
Row("ccc", 30, 400.50)))
}
}
}
MetadataColumnSuite.scala
@@ -42,7 +42,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
val dfQuery = spark.table(tbl).select("id", "data", "index", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}
}
}
@@ -56,7 +56,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
val dfQuery = spark.table(tbl).select("index", "data", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
checkAnswer(query, Seq(Row(3, "c", "3/3"), Row(2, "b", "2/2"), Row(1, "a", "1/1")))
}
}
}
@@ -125,7 +125,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {

checkAnswer(
dfQuery,
Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))
)
}
}
@@ -135,7 +135,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
prepareTable()
checkAnswer(
spark.table(tbl).select("id", "data").select("index", "_partition"),
Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3"))
Seq(Row(0, "1/1"), Row(0, "2/2"), Row(0, "3/3"))
)
}
}
@@ -160,7 +160,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", "index", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
checkAnswer(query, Seq(Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}
}
}
@@ -172,7 +172,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", "index", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}
}
}
@@ -186,7 +186,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
.select("id", "data", "index", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}
}
}
@@ -201,7 +201,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
}

// Metadata columns are propagated through SubqueryAlias even if child is not a leaf node.
@@ -394,7 +394,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
withTable(tbl) {
sql(s"CREATE TABLE $tbl (id bigint, data char(1)) PARTITIONED BY (bucket(4, id), id)")
sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
val expected = Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
val expected = Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))

// Unqualified column access
checkAnswer(sql(s"SELECT id, data, index, _partition FROM $tbl"), expected)
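[Editor's note] The new expected _partition strings follow directly from the fixed bucket computation: for a single bigint bucket column, the new BucketTransform hashes the value to itself and takes floorMod, and the "bucket/id" string format comes from the table's PARTITIONED BY (bucket(4, id), id) spec above. A minimal sketch (plain Scala):

Seq(1L, 2L, 3L).foreach { id =>
  val bucket = Math.floorMod(id, 4L)
  println(s"$bucket/$id") // prints 1/1, 2/2, 3/3 -- the new expected values
}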
transformFunctions.scala
@@ -84,14 +84,15 @@ object UnboundBucketFunction extends UnboundFunction {
override def name(): String = "bucket"
}

// the result should be consistent with BucketTransform defined in InMemoryBaseTable.scala
object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
override def resultType(): DataType = IntegerType
override def name(): String = "bucket"
override def canonicalName(): String = name()
override def toString: String = name()
override def produceResult(input: InternalRow): Int = {
(input.getLong(1) % input.getInt(0)).toInt
Math.floorMod(input.getLong(1), input.getInt(0))
}

override def reducer(
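[Editor's note] Why floorMod instead of %: in Scala/Java, % keeps the sign of the dividend, so a negative hash maps to a negative (invalid) bucket id, while Math.floorMod always returns a value in [0, numBuckets) for a positive modulus. A minimal sketch (plain Scala, hypothetical values; the (Long, Int) overload of floorMod requires Java 9+):

object FloorModSketch extends App {
  val numBuckets = 4
  val negativeHash = -7L

  val remBucket = (negativeHash % numBuckets).toInt            // -3: not a valid bucket
  val floorModBucket = Math.floorMod(negativeHash, numBuckets) // 1: always in [0, 4)

  println(s"% -> $remBucket, floorMod -> $floorModBucket")
}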