
[SPARK-55411][SQL] SPJ may throw ArrayIndexOutOfBoundsException when join keys are less than cluster keys #54182

Open

pan3793 wants to merge 4 commits into apache:master from pan3793:spj-subset-joinkey-bug

Conversation

@pan3793 (Member) commented Feb 6, 2026

What changes were proposed in this pull request?

Fix a java.lang.ArrayIndexOutOfBoundsException that occurs when spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled=true, by correcting the expressions passed to KeyGroupedPartitioning#project: the full partition expressions must be passed instead of the already-projected ones.
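
For illustration, a minimal standalone Scala sketch of the failure mode (hypothetical names, not Spark's actual code): project resolves keys by their positions in the expression list it receives, so those positions must index the full partition expressions rather than the already-projected subset.

object ProjectionSketch {
  def main(args: Array[String]): Unit = {
    val fullPartitionExprs = Seq("id", "dept") // two cluster keys
    val joinKeyPositions = Seq(1)              // join key is only "dept", position 1

    // Correct: positions index into the full expression list.
    val projected = joinKeyPositions.map(fullPartitionExprs(_))
    println(projected) // List(dept)

    // Buggy: positions index into the already-projected one-element list,
    // so position 1 is out of bounds -- the same class of error as the
    // ArrayIndexOutOfBoundsException reported here.
    try joinKeyPositions.map(projected(_))
    catch { case e: IndexOutOfBoundsException => println(s"boom: $e") }
  }
}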

Also, fix a test code issue: change the bucket calculation of the BucketTransform defined in InMemoryBaseTable.scala to match the BucketFunction defined in transformFunctions.scala (thanks @peter-toth for pointing this out!).

Why are the changes needed?

It's a bug fix.

Does this PR introduce any user-facing change?

Some queries that previously failed when spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled=true now run normally.

How was this patch tested?

A new UT is added; it previously failed with ArrayIndexOutOfBoundsException and now passes.

$ build/sbt "sql/testOnly *KeyGroupedPartitioningSuite -- -z SPARK-55411"
...
[info] - bug *** FAILED *** (1 second, 884 milliseconds)
[info]   java.lang.ArrayIndexOutOfBoundsException: Index 1 out of bounds for length 1
[info]   at scala.collection.immutable.ArraySeq$ofRef.apply(ArraySeq.scala:331)
[info]   at org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.$anonfun$project$1(partitioning.scala:471)
[info]   at org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.$anonfun$project$1$adapted(partitioning.scala:471)
[info]   at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
[info]   at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
[info]   at org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning$.project(partitioning.scala:471)
[info]   at org.apache.spark.sql.execution.KeyGroupedPartitionedScan.$anonfun$getOutputKeyGroupedPartitioning$5(KeyGroupedPartitionedScan.scala:58)
...

UTs affected by the bucket() calculation logic change are adjusted accordingly.

Was this patch authored or co-authored using generative AI tooling?

No.

github-actions bot commented Feb 6, 2026

JIRA Issue Information

=== Bug SPARK-55411 ===
Summary: SPJ may throw ArrayIndexOutOfBoundsException when join keys are less than cluster keys
Assignee: None
Status: Open
Affected: ["4.0.2","4.1.1"]


This comment was automatically generated by GitHub Actions

@github-actions github-actions bot added the SQL label Feb 6, 2026
@szehon-ho (Member)

Thanks for the repro, I'll try to take a look.

    partitioning.numPartitions,
-   partitioning.partitionValues)
+   partitioning.partitionValues,
+   partitioning.originalPartitionValues)
@pan3793 (Member Author)

I found that originalPartitionValues is not always populated. Is that intentional?

@pan3793 pan3793 changed the title from "[SPARK-XXXXX][SQL] Internel error when SPJ ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS enabled" to "[SPARK-55411][SQL] SPJ may throw ArrayIndexOutOfBoundsException when join keys are less than cluster keys" Feb 7, 2026
      projectedExpressions.map(_.dataType))
    basePartitioning.partitionValues.map { r =>
-     val projectedRow = KeyGroupedPartitioning.project(expressions,
+     val projectedRow = KeyGroupedPartitioning.project(basePartitioning.expressions,
@peter-toth (Contributor) commented Feb 7, 2026

Actually, the wrong projected expression is the root cause of the ArrayIndexOutOfBoundsException you hit, and passing in basePartitioning.expressions looks good.

But the test you added is unlikely to pass, as there is an issue with the test framework.
I left a note here:

// Do not use `bucket()` in "one side partition" tests as its implementation in
// `InMemoryBaseTable` conflicts with `BucketFunction`

but forgot to open a fix for the problem with using bucket() in these one side shuffle tests.

The problem is that this bucket() implementation:

override def produceResult(input: InternalRow): Int = {
  (input.getLong(1) % input.getInt(0)).toInt
}

and the one in InMemoryBaseTable:

val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
var valueHashCode = 0
valueTypePairs.foreach(pair =>
  if (pair._1 != null) valueHashCode += pair._1.hashCode()
)
var dataTypeHashCode = 0
valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets

don't match.
So the partition keys that the data source reports and the keys that the shuffle partitioner computes when placing the shuffled records don't match.
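
To make the mismatch concrete, here is a hypothetical sketch (simplified analogues, not the actual Spark implementations): if the function the data source uses to report partition values disagrees with the function the shuffle partitioner applies, the same key lands in different buckets on the two sides and the join matches the wrong groups.

object BucketMismatchSketch {
  // Simplified analogue of a modulo-based V2 BucketFunction.
  def partitionerBucket(value: Long, numBuckets: Int): Int =
    (value % numBuckets).toInt

  // Simplified analogue of a hash-based BucketTransform.
  def reportedBucket(value: Long, numBuckets: Int): Int =
    ((value.hashCode * 31) & Integer.MAX_VALUE) % numBuckets

  def main(args: Array[String]): Unit = {
    // The two sides disagree on where key 7 lives.
    println(partitionerBucket(7L, 4)) // 3
    println(reportedBucket(7L, 4))    // 1
  }
}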

@pan3793, could you please keep your fix in KeyGroupedPartitionedScan.scala and fix the BucketTransform key calculation in InMemoryBaseTable?
You don't need the other changes. originalPartitionValues seems unrelated, as it is used only when partially clustered distribution is enabled.

BTW, I'm working on refactoring SPJ based on this idea: #53859 (comment) and it looks promising so far, but I need a few more days to wrap it up.

@pan3793 (Member Author)

// Do not use bucket() in "one side partition" tests as its implementation in
// InMemoryBaseTable conflicts with BucketFunction

Oh god, @peter-toth, thanks a lot for pointing this out. I wasn't aware of it and spent a few hours trying to figure out why the SMJ partition key values mismatched and produced wrong results after fixing the ArrayIndexOutOfBoundsException ...

Actually, the current code changes are just a draft; the test cases have not yet passed. I will try to fix it following your guidance. Thank you again, @peter-toth!

@pan3793 pan3793 force-pushed the spj-subset-joinkey-bug branch from bbf8c3b to cb78da6 on February 8, 2026 05:17
@pan3793 pan3793 marked this pull request as ready for review February 8, 2026 08:27
    case (v, t) =>
      throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
  }
  (acc + valueHash) & 0xFFFFFFFFFFFFL
@pan3793 (Member Author)

scala> Long.MaxValue + 1L
res0: Long = -9223372036854775808

scala> (Long.MaxValue + 1L) & 0xFFFFFFFFFFFFL
res1: Long = 0

scala> (Long.MaxValue + 2L) & 0xFFFFFFFFFFFFL
res2: Long = 1

@peter-toth (Contributor) commented Feb 8, 2026

Ah, this is needed because % N can return negative results, isn't it? That seems like a problem in both places, as bucket(N) should return at most N different values.

Contributor

Should we use Math.floorMod()?

@pan3793 (Member Author)

The bucket num should be >= 1 (though we don't seem to have such a check), so (non_negative_long % positive_int) should always be non-negative?

@peter-toth (Contributor) commented Feb 8, 2026

Yeah, that's correct, but

override def produceResult(input: InternalRow): Int = {
  (input.getLong(1) % input.getInt(0)).toInt
}

also seems wrong, as it can return values between -N+1 and N-1, so we should probably fix both places. If we used Math.floorMod() then we wouldn't need that & 0xFFFFFFFFFFFFL non-negative conversion.
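
As a quick sanity check (standard JDK/Scala behavior, not code from this PR): % keeps the sign of the dividend, while Math.floorMod returns a non-negative result for a positive divisor.

scala> 7L % 4
res0: Long = 3

scala> -7L % 4
res1: Long = -3

scala> Math.floorMod(-7L, 4L)
res2: Long = 1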

@peter-toth (Contributor)

Looks good to me, let's wait for CI.
