diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index ebb4eef80f15f..407d592f82199 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -236,15 +236,26 @@ abstract class InMemoryBaseTable(
           case (v, t) =>
             throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
         }
+      // The result should be consistent with BucketFunction defined in transformFunctions.scala.
       case BucketTransform(numBuckets, cols, _) =>
-        val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
-        var valueHashCode = 0
-        valueTypePairs.foreach( pair =>
-          if ( pair._1 != null) valueHashCode += pair._1.hashCode()
-        )
-        var dataTypeHashCode = 0
-        valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
-        ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
+        val hash: Long = cols.foldLeft(0L) { (acc, col) =>
+          val valueHash = extractor(col.fieldNames, cleanedSchema, row) match {
+            case (value: Byte, _: ByteType) => value.toLong
+            case (value: Short, _: ShortType) => value.toLong
+            case (value: Int, _: IntegerType) => value.toLong
+            case (value: Long, _: LongType) => value
+            case (value: Long, _: TimestampType) => value
+            case (value: Long, _: TimestampNTZType) => value
+            case (value: UTF8String, _: StringType) =>
+              value.hashCode.toLong
+            case (value: Array[Byte], BinaryType) =>
+              util.Arrays.hashCode(value).toLong
+            case (v, t) =>
+              throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+          }
+          acc + valueHash
+        }
+        Math.floorMod(hash, numBuckets)
       case NamedTransform("truncate", Seq(ref: NamedReference, length: V2Literal[_])) =>
         extractor(ref.fieldNames, cleanedSchema, row) match {
           case (str: UTF8String, StringType) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
index 10a6aaa2e1851..cac4a9bc852f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
@@ -34,7 +34,7 @@ trait KeyGroupedPartitionedScan[T] {
   def getOutputKeyGroupedPartitioning(
       basePartitioning: KeyGroupedPartitioning,
       spjParams: StoragePartitionJoinParams): KeyGroupedPartitioning = {
-    val expressions = spjParams.joinKeyPositions match {
+    val projectedExpressions = spjParams.joinKeyPositions match {
       case Some(projectionPositions) =>
         projectionPositions.map(i => basePartitioning.expressions(i))
       case _ => basePartitioning.expressions
@@ -52,16 +52,16 @@ trait KeyGroupedPartitionedScan[T] {
           case Some(projectionPositions) =>
             val internalRowComparableWrapperFactory =
               InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
-                expressions.map(_.dataType))
+                projectedExpressions.map(_.dataType))
             basePartitioning.partitionValues.map { r =>
-              val projectedRow = KeyGroupedPartitioning.project(expressions,
+              val projectedRow = KeyGroupedPartitioning.project(basePartitioning.expressions,
                 projectionPositions, r)
               internalRowComparableWrapperFactory(projectedRow)
             }.distinct.map(_.row)
           case _ => basePartitioning.partitionValues
         }
     }
-    basePartitioning.copy(expressions = expressions, numPartitions = newPartValues.length,
+    basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
       partitionValues = newPartValues)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 7c07d08d80af8..8cd55304d71c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -124,7 +124,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       Seq(TransformExpression(BucketFunction, Seq(attr("ts")), Some(32))))
 
     // Has exactly one partition.
-    val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
+    val partitionValues = Seq(0).map(v => InternalRow.fromSeq(Seq(v)))
     checkQueryPlan(df, distribution, physical.KeyGroupedPartitioning(
       distribution.clustering, 1, partitionValues, partitionValues))
   }
@@ -2798,8 +2798,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
   }
 
   test("SPARK-54439: KeyGroupedPartitioning with transform and join key size mismatch") {
-    // Do not use `bucket()` in "one side partition" tests as its implementation in
-    // `InMemoryBaseTable` conflicts with `BucketFunction`
     val items_partitions = Array(years("arrive_time"))
     createTable(items, itemsColumns, items_partitions)
 
@@ -2841,4 +2839,42 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     }
     assert(metrics("number of rows read") == "3")
   }
+
+  test("SPARK-55411: Fix ArrayIndexOutOfBoundsException when join keys " +
+    "are less than cluster keys") {
+    withSQLConf(
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+      SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+
+      val customers_partitions = Array(identity("customer_name"), bucket(4, "customer_id"))
+      createTable(customers, customersColumns, customers_partitions)
+      sql(s"INSERT INTO testcat.ns.$customers VALUES " +
+        s"('aaa', 10, 1), ('bbb', 20, 2), ('ccc', 30, 3)")
+
+      createTable(orders, ordersColumns, Array.empty)
+      sql(s"INSERT INTO testcat.ns.$orders VALUES " +
+        s"(100.0, 1), (200.0, 1), (150.0, 2), (250.0, 2), (350.0, 2), (400.50, 3)")
+
+      val df = sql(
+        s"""${selectWithMergeJoinHint("c", "o")}
+           |customer_name, customer_age, order_amount
+           |FROM testcat.ns.$customers c JOIN testcat.ns.$orders o
+           |ON c.customer_id = o.customer_id ORDER BY c.customer_id, order_amount
+           |""".stripMargin)
+
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.length == 1)
+
+      checkAnswer(df, Seq(
+        Row("aaa", 10, 100.0),
+        Row("aaa", 10, 200.0),
+        Row("bbb", 20, 150.0),
+        Row("bbb", 20, 250.0),
+        Row("bbb", 20, 350.0),
+        Row("ccc", 30, 400.50)))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
index 3bfd57e867c07..fe338175ec888 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala
@@ -42,7 +42,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).select("id", "data", "index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -56,7 +56,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).select("index", "data", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(3, "c", "1/3"), Row(2, "b", "0/2"), Row(1, "a", "3/1")))
+        checkAnswer(query, Seq(Row(3, "c", "3/3"), Row(2, "b", "2/2"), Row(1, "a", "1/1")))
       }
     }
   }
@@ -125,7 +125,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
 
       checkAnswer(
         dfQuery,
-        Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
+        Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))
       )
     }
   }
@@ -135,7 +135,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       prepareTable()
      checkAnswer(
        spark.table(tbl).select("id", "data").select("index", "_partition"),
-        Seq(Row(0, "3/1"), Row(0, "0/2"), Row(0, "1/3"))
+        Seq(Row(0, "1/1"), Row(0, "2/2"), Row(0, "3/3"))
      )
    }
  }
@@ -160,7 +160,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).where("id > 1").select("id", "data", "index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -172,7 +172,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
       val dfQuery = spark.table(tbl).orderBy("id").select("id", "data", "index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -186,7 +186,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
         .select("id", "data", "index", "_partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
       }
     }
   }
@@ -201,7 +201,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
         s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
 
       Seq(sqlQuery, dfQuery).foreach { query =>
-        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+        checkAnswer(query, Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3")))
       }
 
       // Metadata columns are propagated through SubqueryAlias even if child is not a leaf node.
@@ -394,7 +394,7 @@ class MetadataColumnSuite extends DatasourceV2SQLBase {
     withTable(tbl) {
       sql(s"CREATE TABLE $tbl (id bigint, data char(1)) PARTITIONED BY (bucket(4, id), id)")
       sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")
-      val expected = Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3"))
+      val expected = Seq(Row(1, "a", 0, "1/1"), Row(2, "b", 0, "2/2"), Row(3, "c", 0, "3/3"))
 
       // Unqualified column access
       checkAnswer(sql(s"SELECT id, data, index, _partition FROM $tbl"), expected)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index b82cc2392e1fc..ed2f81d7e8d6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -84,6 +84,7 @@ object UnboundBucketFunction extends UnboundFunction {
   override def name(): String = "bucket"
 }
 
+// The result should be consistent with BucketTransform defined in InMemoryBaseTable.scala.
 object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
   override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
   override def resultType(): DataType = IntegerType
@@ -91,7 +92,7 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In
   override def canonicalName(): String = name()
   override def toString: String = name()
   override def produceResult(input: InternalRow): Int = {
-    (input.getLong(1) % input.getInt(0)).toInt
+    Math.floorMod(input.getLong(1), input.getInt(0))
   }
   override def reducer(
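Below is a minimal standalone sketch, not part of the patch (the object name BucketRuleSketch is hypothetical), of the shared bucketing rule the patch establishes: both BucketTransform in InMemoryBaseTable.scala and BucketFunction in transformFunctions.scala now reduce the key with Math.floorMod. That is consistent with the updated MetadataColumnSuite expectations ("1/1", "2/2", "3/3" for ids 1, 2, 3 under bucket(4, id)), and unlike the old (value % numBuckets).toInt it never produces a negative bucket.

object BucketRuleSketch {
  // Mirrors the reduction now used by both code paths in the patch above.
  def bucket(key: Long, numBuckets: Int): Int = Math.floorMod(key, numBuckets)

  def main(args: Array[String]): Unit = {
    // ids 1, 2, 3 with 4 buckets land in buckets 1, 2, 3, i.e. "_partition" values "1/1", "2/2", "3/3".
    assert(Seq(1L, 2L, 3L).map(bucket(_, 4)) == Seq(1, 2, 3))
    // floorMod stays non-negative where the old `%`-based reduction went negative.
    assert(bucket(-3L, 4) == 1)
    assert((-3L % 4).toInt == -3)
  }
}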