Skip to content

Commit e23797a

Browse files
committed
[SPARK-56190][SQL] Support nested partition columns for DSV2 PartitionPredicate
1 parent 8e88f5a commit e23797a

12 files changed

Lines changed: 295 additions & 89 deletions

File tree

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/PartitionColumnReference.java renamed to sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/PartitionFieldReference.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@
2121
import org.apache.spark.sql.connector.catalog.Table;
2222

2323
/**
24-
* A reference to a partition column in {@link Table#partitioning()}.
24+
* A reference to a partition field in {@link Table#partitioning()}.
2525
* <p>
26-
* {@link #fieldNames()} returns the partition column name (or names) as reported by
26+
* {@link #fieldNames()} returns the partition field name (or names) as reported by
2727
* the table's partition schema.
2828
* {@link #ordinal()} returns the 0-based position in {@link Table#partitioning()}.
2929
*
3030
* @since 4.2.0
3131
*/
3232
@Evolving
33-
public interface PartitionColumnReference extends NamedReference {
33+
public interface PartitionFieldReference extends NamedReference {
3434

3535
/**
3636
* Returns the 0-based ordinal of this partition column in {@link Table#partitioning()}.

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/PartitionPredicate.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
import org.apache.spark.sql.catalyst.InternalRow;
2222
import org.apache.spark.sql.connector.catalog.Table;
2323
import org.apache.spark.sql.connector.expressions.NamedReference;
24-
import org.apache.spark.sql.connector.expressions.PartitionColumnReference;
25-
26-
import static org.apache.spark.sql.connector.expressions.Expression.EMPTY_EXPRESSION;
24+
import org.apache.spark.sql.connector.expressions.PartitionFieldReference;
2725

2826
/**
2927
* Represents a partition predicate that can be evaluated using {@link Table#partitioning()}.
@@ -47,17 +45,17 @@ protected PartitionPredicate() {
4745
/**
4846
* {@inheritDoc}
4947
* <p>
50-
* For PartitionPredicate, returns {@link PartitionColumnReference} instances that identify
48+
* For PartitionPredicate, returns {@link PartitionFieldReference} instances that identify
5149
* the partition columns (from {@link Table#partitioning()}) referenced by this predicate.
52-
* Each reference's {@link PartitionColumnReference#fieldNames()} gives the partition column
53-
* name; {@link PartitionColumnReference#ordinal()} gives the 0-based position in
50+
* Each reference's {@link PartitionFieldReference#fieldNames()} gives the partition column
51+
* name; {@link PartitionFieldReference#ordinal()} gives the 0-based position in
5452
* {@link Table#partitioning()}.
5553
* <p>
5654
* <b>Example:</b> Suppose {@code Table.partitioning()} returns three partition
5755
* transforms: {@code [years(ts), months(ts), bucket(32, id)]} with ordinals 0, 1, 2.
58-
* Each {@link PartitionColumnReference} has {@link PartitionColumnReference#fieldNames()}
56+
* Each {@link PartitionFieldReference} has {@link PartitionFieldReference#fieldNames()}
5957
* (the transform display name, e.g. {@code years(ts)}) and
60-
* {@link PartitionColumnReference#ordinal()}:
58+
* {@link PartitionFieldReference#ordinal()}:
6159
* <ul>
6260
* <li>{@code years(ts) = 2026} returns one reference: (fieldNames=[years(ts)], ordinal=0).</li>
6361
* <li>{@code years(ts) = 2026 and months(ts) = 01} returns two references:

sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.connector.read;
1919

2020
import org.apache.spark.annotation.Evolving;
21-
import org.apache.spark.sql.connector.expressions.PartitionColumnReference;
21+
import org.apache.spark.sql.connector.expressions.PartitionFieldReference;
2222
import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate;
2323
import org.apache.spark.sql.connector.expressions.filter.Predicate;
2424

@@ -54,8 +54,8 @@ public interface SupportsPushDownV2Filters extends ScanBuilder {
5454
* {@link #pushedPredicates()} can return predicates from all of them.
5555
* <p>
5656
* For each {@link PartitionPredicate}, the implementation can use
57-
* {@link PartitionPredicate#references()} (each {@link PartitionColumnReference} has
58-
* {@link PartitionColumnReference#ordinal()}) to decide whether to return it for post-scan
57+
* {@link PartitionPredicate#references()} (each {@link PartitionFieldReference} has
58+
* {@link PartitionFieldReference#ordinal()}) to decide whether to return it for post-scan
5959
* filtering. For example, data sources with
6060
* partition spec evolution may return predicates that reference later-added partition
6161
* transforms (incompletely partitioned data) so Spark evaluates them after the scan, while

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionColumnReferenceImpl.scala renamed to sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionFieldReferenceImpl.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717

1818
package org.apache.spark.sql.internal.connector
1919

20-
import org.apache.spark.sql.connector.expressions.PartitionColumnReference
20+
import org.apache.spark.sql.connector.expressions.PartitionFieldReference
2121

2222
/**
23-
* Implementation of [[PartitionColumnReference]] that carries the position ordinal in
24-
* Table.partitioning() and the partition column name(s) for that position.
23+
* Implementation of [[PartitionFieldReference]] that carries the position ordinal in
24+
* Table.partitioning() and the partition field name(s) for that position.
2525
*/
26-
private[connector] case class PartitionColumnReferenceImpl(
26+
private[connector] case class PartitionFieldReferenceImpl(
2727
ordinal: Int,
2828
fieldNames: Array[String])
29-
extends PartitionColumnReference
29+
extends PartitionFieldReference
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.internal.connector
19+
20+
import org.apache.spark.sql.connector.expressions.NamedReference
21+
import org.apache.spark.sql.types.StructField
22+
23+
/**
24+
* Metadata for one identity partition field.
25+
*
26+
* @param structField the Catalyst field describing the partition column (name, type,
27+
* nullability). For nested columns the name is the dotted path
28+
* (e.g. `"s.tz"`).
29+
* @param identityRef the [[NamedReference]] from the transform target column
30+
* (e.g. `FieldReference(Seq("s", "tz"))`).
31+
*/
32+
case class PartitionPredicateField(structField: StructField, identityRef: NamedReference)

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImpl.scala

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.SparkException
2121
import org.apache.spark.internal.{Logging, LogKeys}
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Expression => CatalystExpression, Predicate => CatalystPredicate}
24+
import org.apache.spark.sql.catalyst.types.DataTypeUtils
2425
import org.apache.spark.sql.connector.expressions.NamedReference
2526
import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate
2627

@@ -30,30 +31,33 @@ import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate
3031
*/
3132
class PartitionPredicateImpl private (
3233
private val catalystExpr: CatalystExpression,
33-
private val partitionSchema: Seq[AttributeReference])
34+
private val partitionFields: Seq[PartitionPredicateField])
3435
extends PartitionPredicate with Logging {
3536

37+
@transient private lazy val partitionAttrs: Seq[AttributeReference] =
38+
partitionFields.map(c => DataTypeUtils.toAttribute(c.structField))
39+
3640
/** The wrapped partition filter Catalyst Expression. */
3741
def expression: CatalystExpression = catalystExpr
3842

3943
/** Bound predicate, computed once and reused for all partition rows. */
4044
@transient private lazy val boundPredicate: InternalRow => Boolean = {
4145
val boundExpr = catalystExpr.transform {
4246
case a: AttributeReference =>
43-
val index = partitionSchema.indexWhere(_.name == a.name)
47+
val index = partitionAttrs.indexWhere(_.name == a.name)
4448
require(index >= 0, s"Column ${a.name} not found in partition schema")
45-
BoundReference(index, partitionSchema(index).dataType, nullable = a.nullable)
49+
BoundReference(index, partitionAttrs(index).dataType, nullable = a.nullable)
4650
}
4751
val predicate = CatalystPredicate.createInterpreted(boundExpr)
4852
predicate.eval
4953
}
5054

5155
override def eval(partitionValues: InternalRow): Boolean = {
52-
if (partitionValues.numFields != partitionSchema.length) {
56+
if (partitionValues.numFields != partitionFields.length) {
5357
logWarning(
5458
log"Cannot evaluate partition predicate ${MDC(LogKeys.EXPR, catalystExpr.sql)}: " +
5559
log"partition value field count (${MDC(LogKeys.COUNT, partitionValues.numFields)}) " +
56-
log"does not match schema (${MDC(LogKeys.NUM_PARTITIONS, partitionSchema.length)}). " +
60+
log"does not match schema (${MDC(LogKeys.NUM_PARTITIONS, partitionFields.length)}). " +
5761
log"Including partition in scan result to avoid incorrect filtering.")
5862
return true
5963
}
@@ -72,20 +76,24 @@ class PartitionPredicateImpl private (
7276

7377
@transient override lazy val references: Array[NamedReference] = {
7478
val refNames = catalystExpr.references.map(_.name).toSet
75-
partitionSchema.zipWithIndex
79+
partitionAttrs.zipWithIndex
7680
.filter { case (attr, _) => refNames.contains(attr.name) }
77-
.map { case (attr, ordinal) => PartitionColumnReferenceImpl(ordinal, Array(attr.name)) }
81+
.map { case (_, ordinal) =>
82+
PartitionFieldReferenceImpl(
83+
ordinal, partitionFields(ordinal).identityRef.fieldNames())
84+
}
7885
.toArray
7986
}
8087

8188
override def equals(obj: Any): Boolean = obj match {
8289
case other: PartitionPredicateImpl =>
83-
catalystExpr.semanticEquals(other.catalystExpr) && partitionSchema == other.partitionSchema
90+
catalystExpr.semanticEquals(other.catalystExpr) &&
91+
partitionFields == other.partitionFields
8492
case _ => false
8593
}
8694

8795
override def hashCode(): Int = {
88-
31 * catalystExpr.semanticHash() + partitionSchema.hashCode()
96+
31 * catalystExpr.semanticHash() + partitionFields.hashCode()
8997
}
9098

9199
override def toString(): String = s"PartitionPredicate(${catalystExpr.sql})"
@@ -95,18 +103,26 @@ object PartitionPredicateImpl {
95103

96104
def apply(
97105
catalystExpr: CatalystExpression,
98-
partitionSchema: Seq[AttributeReference]): PartitionPredicateImpl = {
99-
if (partitionSchema.isEmpty) {
106+
partitionFields: Seq[PartitionPredicateField]): PartitionPredicateImpl = {
107+
validateAndCreate(catalystExpr, partitionFields)
108+
}
109+
110+
private def validateAndCreate(
111+
catalystExpr: CatalystExpression,
112+
partitionFields: Seq[PartitionPredicateField]): PartitionPredicateImpl = {
113+
if (partitionFields.isEmpty) {
100114
throw SparkException.internalError(
101-
s"Cannot evaluate partition predicate ${catalystExpr.sql}: partition schema is empty")
115+
s"Cannot evaluate partition predicate ${catalystExpr.sql}: partition fields are empty")
102116
}
103-
val partitionNames = partitionSchema.map(_.name).toSet
117+
val partitionNames = partitionFields.map(_.structField.name).toSet
104118
val refNames = catalystExpr.references.map(_.name).toSet
105119
if (!refNames.subsetOf(partitionNames)) {
120+
val refsStr = refNames.mkString(", ")
121+
val fieldsStr = partitionNames.mkString(", ")
106122
throw SparkException.internalError(
107123
s"Cannot evaluate partition predicate ${catalystExpr.sql}: expression references " +
108-
s"${refNames.mkString(", ")} not all in partition columns ${partitionNames.mkString(", ")}")
124+
s"$refsStr not all in partition fields $fieldsStr")
109125
}
110-
new PartitionPredicateImpl(catalystExpr, partitionSchema)
126+
new PartitionPredicateImpl(catalystExpr, partitionFields)
111127
}
112128
}

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedPartitionFilterTable.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ class InMemoryEnhancedPartitionFilterTable(
7676
override def supportsIterativePushdown(): Boolean = true
7777

7878
override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
79-
val partNames = InMemoryEnhancedPartitionFilterTable.this.partCols.flatMap(_.toSeq).toSet
79+
val partPaths = InMemoryEnhancedPartitionFilterTable.this.partCols.map(_.mkString(".")).toSet
8080
def referencesOnlyPartitionCols(p: Predicate): Boolean =
81-
p.references().forall(ref => partNames.contains(ref.fieldNames().mkString(".")))
81+
p.references().forall(ref => partPaths.contains(ref.fieldNames().mkString(".")))
8282
def referencesOnlyDataCols(p: Predicate): Boolean =
83-
p.references().forall(ref => !partNames.contains(ref.fieldNames().mkString(".")))
83+
p.references().forall(ref => !partPaths.contains(ref.fieldNames().mkString(".")))
8484

8585
val returned = ArrayBuffer.empty[Predicate]
8686

@@ -120,10 +120,11 @@ class InMemoryEnhancedPartitionFilterTable(
120120
val partNames =
121121
InMemoryEnhancedPartitionFilterTable.this.partCols.map(_.toSeq.quoted)
122122
.toImmutableArraySeq
123-
val partNamesSet = InMemoryEnhancedPartitionFilterTable.this.partCols.flatMap(_.toSeq).toSet
123+
val partPathSet =
124+
InMemoryEnhancedPartitionFilterTable.this.partCols.map(_.mkString(".")).toSet
124125
// Only partition predicates can be used for partition key filtering (filtersToKeys).
125126
val firstPassPartitionPredicates = firstPassPushedPredicates.filter { p =>
126-
p.references().forall(ref => partNamesSet.contains(ref.fieldNames().mkString(".")))
127+
p.references().forall(ref => partPathSet.contains(ref.fieldNames().mkString(".")))
127128
}
128129
val allKeys = allPartitions.map(_.asInstanceOf[BufferedRows].key)
129130
val matchingKeys = InMemoryTableWithV2Filter.filtersToKeys(

sql/catalyst/src/test/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImplSuite.scala

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ package org.apache.spark.sql.internal.connector
2020
import org.apache.spark.{SparkConf, SparkFunSuite}
2121
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance}
2222
import org.apache.spark.sql.catalyst.InternalRow
23-
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
24-
import org.apache.spark.sql.connector.expressions.PartitionColumnReference
25-
import org.apache.spark.sql.types.IntegerType
23+
import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal}
24+
import org.apache.spark.sql.catalyst.types.DataTypeUtils
25+
import org.apache.spark.sql.connector.expressions.{FieldReference, PartitionFieldReference}
26+
import org.apache.spark.sql.types.{IntegerType, StringType, StructField}
27+
import org.apache.spark.unsafe.types.UTF8String
2628

2729
class PartitionPredicateImplSuite extends SparkFunSuite {
2830

@@ -38,12 +40,25 @@ class PartitionPredicateImplSuite extends SparkFunSuite {
3840
checkPartitionPredicateImplAfterSerialization(serializer)
3941
}
4042

43+
test("Kryo: nested partition path in references survives round-trip") {
44+
val conf = new SparkConf()
45+
val serializer = new KryoSerializer(conf).newInstance()
46+
checkNestedPartitionPathReferencesAfterSerialization(serializer)
47+
}
48+
49+
test("Java serialization: nested partition path in references survives round-trip") {
50+
val conf = new SparkConf()
51+
val serializer = new JavaSerializer(conf).newInstance()
52+
checkNestedPartitionPathReferencesAfterSerialization(serializer)
53+
}
54+
4155
private def checkPartitionPredicateImplAfterSerialization(
4256
serializer: SerializerInstance): Unit = {
43-
val partitionSchema = Seq(AttributeReference("p", IntegerType)())
44-
val ref = AttributeReference("p", IntegerType)()
57+
val field = StructField("p", IntegerType, nullable = true)
58+
val ref = DataTypeUtils.toAttribute(field)
4559
val expr = GreaterThan(ref, Literal(5))
46-
val predicate = PartitionPredicateImpl(expr, partitionSchema)
60+
val fields = Seq(PartitionPredicateField(field, FieldReference(Seq("p"))))
61+
val predicate = PartitionPredicateImpl(expr, fields)
4762

4863
val deserialized = serializer.deserialize[PartitionPredicateImpl](
4964
serializer.serialize(predicate))
@@ -59,8 +74,36 @@ class PartitionPredicateImplSuite extends SparkFunSuite {
5974
assert(deserialized.equals(predicate))
6075
}
6176

77+
private def checkNestedPartitionPathReferencesAfterSerialization(
78+
serializer: SerializerInstance): Unit = {
79+
val field = StructField("ts.timezone", StringType, nullable = false)
80+
val ref = DataTypeUtils.toAttribute(field)
81+
val expr = GreaterThan(ref, Literal("x"))
82+
val fields = Seq(PartitionPredicateField(field, FieldReference(Seq("ts", "timezone"))))
83+
val predicate = PartitionPredicateImpl(expr, fields)
84+
85+
val deserialized = serializer.deserialize[PartitionPredicateImpl](
86+
serializer.serialize(predicate))
87+
88+
assert(deserialized.eval(InternalRow(UTF8String.fromString("z"))) === true)
89+
assert(deserialized.eval(InternalRow(UTF8String.fromString("a"))) === false)
90+
91+
val expectedRefs = Seq((0, Seq("ts", "timezone")))
92+
assert(partitionRefDetails(predicate.references.toSeq) === expectedRefs)
93+
assert(partitionRefDetails(deserialized.references.toSeq) === expectedRefs)
94+
95+
assert(deserialized.equals(predicate))
96+
}
97+
98+
private def partitionRefDetails(refs: Seq[AnyRef]): Seq[(Int, Seq[String])] = refs.map {
99+
case r: PartitionFieldReference =>
100+
(r.ordinal(), r.fieldNames().toIndexedSeq)
101+
case other =>
102+
fail(s"Expected PartitionFieldReference, got ${other.getClass.getName}: $other")
103+
}
104+
62105
private def refsWithOrdinals(refs: Seq[AnyRef]): Seq[(String, Int)] = refs.map {
63-
case r: PartitionColumnReference =>
106+
case r: PartitionFieldReference =>
64107
(r.fieldNames().mkString("."), r.ordinal())
65108
case other =>
66109
fail(s"Expected PartitionFieldReference, got ${other.getClass.getName}: $other")

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.execution.datasources.v2
1919

20-
import org.apache.spark.internal.{LogKeys}
20+
import org.apache.spark.internal.LogKeys
2121
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
2222
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
2323
import org.apache.spark.sql.catalyst.planning.{GroupBasedRowLevelOperation, PhysicalOperation}
@@ -27,8 +27,8 @@ import org.apache.spark.sql.connector.expressions.filter.{Predicate => V2Filter}
2727
import org.apache.spark.sql.connector.read.ScanBuilder
2828
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
2929
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
30+
import org.apache.spark.sql.internal.connector.PartitionPredicateField
3031
import org.apache.spark.sql.sources.Filter
31-
import org.apache.spark.sql.types.StructType
3232

3333
/**
3434
* A rule that builds scans for group-based row-level operations.
@@ -48,10 +48,10 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr
4848

4949
val table = relation.table.asRowLevelOperationTable
5050
val scanBuilder = table.newScanBuilder(relation.options)
51-
val partitionSchema = PushDownUtils.getPartitionPredicateSchema(relation)
51+
val partitionPredicateFields = PushDownUtils.getPartitionSchemaInfo(relation)
5252

5353
val (pushedFilters, evaluatedFilters, postScanFilters) =
54-
pushFilters(cond, relation.output, scanBuilder, partitionSchema)
54+
pushFilters(cond, relation.output, scanBuilder, partitionPredicateFields)
5555

5656
val pushedFiltersStr = if (pushedFilters.isLeft) {
5757
pushedFilters.swap
@@ -100,13 +100,13 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr
100100
cond: Expression,
101101
tableAttrs: Seq[AttributeReference],
102102
scanBuilder: ScanBuilder,
103-
partitionSchema: Option[StructType])
103+
partitionPredicateFields: Option[Seq[PartitionPredicateField]])
104104
: (Either[Seq[Filter], Seq[V2Filter]], Seq[Expression], Seq[Expression]) = {
105105

106106
val (filtersWithSubquery, filtersWithoutSubquery) = findTableFilters(cond, tableAttrs)
107107

108108
val (pushedFilters, postScanFiltersWithoutSubquery) =
109-
PushDownUtils.pushFilters(scanBuilder, filtersWithoutSubquery, partitionSchema)
109+
PushDownUtils.pushFilters(scanBuilder, filtersWithoutSubquery, partitionPredicateFields)
110110

111111
val postScanFilterSetWithoutSubquery = ExpressionSet(postScanFiltersWithoutSubquery)
112112
val evaluatedFilters = filtersWithoutSubquery.filterNot { filter =>

0 commit comments

Comments
 (0)