diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/PartitionColumnReference.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/PartitionFieldReference.java similarity index 78% rename from sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/PartitionColumnReference.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/PartitionFieldReference.java index ef516545812b9..0a51f0dfd8540 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/PartitionColumnReference.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/PartitionFieldReference.java @@ -21,19 +21,19 @@ import org.apache.spark.sql.connector.catalog.Table; /** - * A reference to a partition column in {@link Table#partitioning()}. + * A reference to a partition field in {@link Table#partitioning()}. *
- * {@link #fieldNames()} returns the partition column name (or names) as reported by + * {@link #fieldNames()} returns the partition field name (or names) as reported by * the table's partition schema. * {@link #ordinal()} returns the 0-based position in {@link Table#partitioning()}. * * @since 4.2.0 */ @Evolving -public interface PartitionColumnReference extends NamedReference { +public interface PartitionFieldReference extends NamedReference { /** - * Returns the 0-based ordinal of this partition column in {@link Table#partitioning()}. + * Returns the 0-based ordinal of this partition field in {@link Table#partitioning()}. */ int ordinal(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/PartitionPredicate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/PartitionPredicate.java index dbc31aa1a4583..1e579a39f27c6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/PartitionPredicate.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/PartitionPredicate.java @@ -21,9 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.expressions.NamedReference; -import org.apache.spark.sql.connector.expressions.PartitionColumnReference; - -import static org.apache.spark.sql.connector.expressions.Expression.EMPTY_EXPRESSION; +import org.apache.spark.sql.connector.expressions.PartitionFieldReference; /** * Represents a partition predicate that can be evaluated using {@link Table#partitioning()}. @@ -47,17 +45,17 @@ protected PartitionPredicate() { /** * {@inheritDoc} *
- * For PartitionPredicate, returns {@link PartitionColumnReference} instances that identify - * the partition columns (from {@link Table#partitioning()}) referenced by this predicate. - * Each reference's {@link PartitionColumnReference#fieldNames()} gives the partition column - * name; {@link PartitionColumnReference#ordinal()} gives the 0-based position in + * For PartitionPredicate, returns {@link PartitionFieldReference} instances that identify + * the partition fields (from {@link Table#partitioning()}) referenced by this predicate. + * Each reference's {@link PartitionFieldReference#fieldNames()} gives the partition field + * name; {@link PartitionFieldReference#ordinal()} gives the 0-based position in * {@link Table#partitioning()}. *
* Example: Suppose {@code Table.partitioning()} returns three partition * transforms: {@code [years(ts), months(ts), bucket(32, id)]} with ordinals 0, 1, 2. - * Each {@link PartitionColumnReference} has {@link PartitionColumnReference#fieldNames()} + * Each {@link PartitionFieldReference} has {@link PartitionFieldReference#fieldNames()} * (the transform display name, e.g. {@code years(ts)}) and - * {@link PartitionColumnReference#ordinal()}: + * {@link PartitionFieldReference#ordinal()}: *
* For each {@link PartitionPredicate}, the implementation can use
- * {@link PartitionPredicate#references()} (each {@link PartitionColumnReference} has
- * {@link PartitionColumnReference#ordinal()}) to decide whether to return it for post-scan
+ * {@link PartitionPredicate#references()} (each {@link PartitionFieldReference} has
+ * {@link PartitionFieldReference#ordinal()}) to decide whether to return it for post-scan
* filtering. For example, data sources with
* partition spec evolution may return predicates that reference later-added partition
* transforms (incompletely partitioned data) so Spark evaluates them after the scan, while
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionColumnReferenceImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionFieldReferenceImpl.scala
similarity index 73%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionColumnReferenceImpl.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionFieldReferenceImpl.scala
index af069201483e2..6ca8c74379294 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionColumnReferenceImpl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionFieldReferenceImpl.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql.internal.connector
-import org.apache.spark.sql.connector.expressions.PartitionColumnReference
+import org.apache.spark.sql.connector.expressions.PartitionFieldReference
/**
- * Implementation of [[PartitionColumnReference]] that carries the position ordinal in
- * Table.partitioning() and the partition column name(s) for that position.
+ * Implementation of [[PartitionFieldReference]] that carries the position ordinal in
+ * Table.partitioning() and the partition field name(s) for that position.
*/
-private[connector] case class PartitionColumnReferenceImpl(
+private[connector] case class PartitionFieldReferenceImpl(
ordinal: Int,
fieldNames: Array[String])
- extends PartitionColumnReference
+ extends PartitionFieldReference
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionPredicateField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionPredicateField.scala
new file mode 100644
index 0000000000000..f9cdbb97a5992
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionPredicateField.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.connector
+
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.types.StructField
+
+/**
+ * Metadata for one partition field.
+ *
+ * @param structField the Catalyst field describing the partition column (name, type,
+ * nullability). For nested columns the name is the dotted path
+ * (e.g. `"s.tz"`).
+ * @param identityRef the [[NamedReference]] from the transform target column
+ * (e.g. `FieldReference(Seq("s", "tz"))`).
+ */
+case class PartitionPredicateField(structField: StructField, identityRef: NamedReference)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImpl.scala
index 6c887ab7d8076..eccaee49daf12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImpl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImpl.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Expression => CatalystExpression, Predicate => CatalystPredicate}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate
@@ -30,9 +31,12 @@ import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate
*/
class PartitionPredicateImpl private (
private val catalystExpr: CatalystExpression,
- private val partitionSchema: Seq[AttributeReference])
+ private val partitionFields: Seq[PartitionPredicateField])
extends PartitionPredicate with Logging {
+ @transient private lazy val partitionAttrs: Seq[AttributeReference] =
+ partitionFields.map(c => DataTypeUtils.toAttribute(c.structField))
+
/** The wrapped partition filter Catalyst Expression. */
def expression: CatalystExpression = catalystExpr
@@ -40,20 +44,20 @@ class PartitionPredicateImpl private (
@transient private lazy val boundPredicate: InternalRow => Boolean = {
val boundExpr = catalystExpr.transform {
case a: AttributeReference =>
- val index = partitionSchema.indexWhere(_.name == a.name)
- require(index >= 0, s"Column ${a.name} not found in partition schema")
- BoundReference(index, partitionSchema(index).dataType, nullable = a.nullable)
+ val index = partitionAttrs.indexWhere(_.name == a.name)
+ require(index >= 0, s"Field ${a.name} not found in partition schema")
+ BoundReference(index, partitionAttrs(index).dataType, nullable = a.nullable)
}
val predicate = CatalystPredicate.createInterpreted(boundExpr)
predicate.eval
}
override def eval(partitionValues: InternalRow): Boolean = {
- if (partitionValues.numFields != partitionSchema.length) {
+ if (partitionValues.numFields != partitionFields.length) {
logWarning(
log"Cannot evaluate partition predicate ${MDC(LogKeys.EXPR, catalystExpr.sql)}: " +
log"partition value field count (${MDC(LogKeys.COUNT, partitionValues.numFields)}) " +
- log"does not match schema (${MDC(LogKeys.NUM_PARTITIONS, partitionSchema.length)}). " +
+ log"does not match schema (${MDC(LogKeys.NUM_PARTITIONS, partitionFields.length)}). " +
log"Including partition in scan result to avoid incorrect filtering.")
return true
}
@@ -72,20 +76,24 @@ class PartitionPredicateImpl private (
@transient override lazy val references: Array[NamedReference] = {
val refNames = catalystExpr.references.map(_.name).toSet
- partitionSchema.zipWithIndex
+ partitionAttrs.zipWithIndex
.filter { case (attr, _) => refNames.contains(attr.name) }
- .map { case (attr, ordinal) => PartitionColumnReferenceImpl(ordinal, Array(attr.name)) }
+ .map { case (_, ordinal) =>
+ PartitionFieldReferenceImpl(ordinal,
+ partitionFields(ordinal).identityRef.fieldNames())
+ }
.toArray
}
override def equals(obj: Any): Boolean = obj match {
case other: PartitionPredicateImpl =>
- catalystExpr.semanticEquals(other.catalystExpr) && partitionSchema == other.partitionSchema
+ catalystExpr.semanticEquals(other.catalystExpr) &&
+ partitionFields == other.partitionFields
case _ => false
}
override def hashCode(): Int = {
- 31 * catalystExpr.semanticHash() + partitionSchema.hashCode()
+ 31 * catalystExpr.semanticHash() + partitionFields.hashCode()
}
override def toString(): String = s"PartitionPredicate(${catalystExpr.sql})"
@@ -93,20 +101,29 @@ class PartitionPredicateImpl private (
object PartitionPredicateImpl {
- def apply(
+ def apply(catalystExpr: CatalystExpression,
+ partitionFields: Seq[PartitionPredicateField])
+ : PartitionPredicateImpl = {
+ validateAndCreate(catalystExpr, partitionFields)
+ }
+
+ private def validateAndCreate(
catalystExpr: CatalystExpression,
- partitionSchema: Seq[AttributeReference]): PartitionPredicateImpl = {
- if (partitionSchema.isEmpty) {
+ partitionFields: Seq[PartitionPredicateField])
+ : PartitionPredicateImpl = {
+ if (partitionFields.isEmpty) {
throw SparkException.internalError(
- s"Cannot evaluate partition predicate ${catalystExpr.sql}: partition schema is empty")
+ s"Cannot evaluate partition predicate ${catalystExpr.sql}: partition fields are empty")
}
- val partitionNames = partitionSchema.map(_.name).toSet
+ val partitionNames = partitionFields.map(_.structField.name).toSet
val refNames = catalystExpr.references.map(_.name).toSet
if (!refNames.subsetOf(partitionNames)) {
+ val refsStr = refNames.mkString(", ")
+ val fieldsStr = partitionNames.mkString(", ")
throw SparkException.internalError(
s"Cannot evaluate partition predicate ${catalystExpr.sql}: expression references " +
- s"${refNames.mkString(", ")} not all in partition columns ${partitionNames.mkString(", ")}")
+ s"$refsStr not all in partition fields $fieldsStr")
}
- new PartitionPredicateImpl(catalystExpr, partitionSchema)
+ new PartitionPredicateImpl(catalystExpr, partitionFields)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedPartitionFilterTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedPartitionFilterTable.scala
index 507439b9d06c7..4cfca8a62f579 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedPartitionFilterTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryEnhancedPartitionFilterTable.scala
@@ -76,11 +76,11 @@ class InMemoryEnhancedPartitionFilterTable(
override def supportsIterativePushdown(): Boolean = true
override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
- val partNames = InMemoryEnhancedPartitionFilterTable.this.partCols.flatMap(_.toSeq).toSet
+ val partPaths = InMemoryEnhancedPartitionFilterTable.this.partCols.map(_.mkString(".")).toSet
def referencesOnlyPartitionCols(p: Predicate): Boolean =
- p.references().forall(ref => partNames.contains(ref.fieldNames().mkString(".")))
+ p.references().forall(ref => partPaths.contains(ref.fieldNames().mkString(".")))
def referencesOnlyDataCols(p: Predicate): Boolean =
- p.references().forall(ref => !partNames.contains(ref.fieldNames().mkString(".")))
+ p.references().forall(ref => !partPaths.contains(ref.fieldNames().mkString(".")))
val returned = ArrayBuffer.empty[Predicate]
@@ -120,10 +120,11 @@ class InMemoryEnhancedPartitionFilterTable(
val partNames =
InMemoryEnhancedPartitionFilterTable.this.partCols.map(_.toSeq.quoted)
.toImmutableArraySeq
- val partNamesSet = InMemoryEnhancedPartitionFilterTable.this.partCols.flatMap(_.toSeq).toSet
+ val partPathSet =
+ InMemoryEnhancedPartitionFilterTable.this.partCols.map(_.mkString(".")).toSet
// Only partition predicates can be used for partition key filtering (filtersToKeys).
val firstPassPartitionPredicates = firstPassPushedPredicates.filter { p =>
- p.references().forall(ref => partNamesSet.contains(ref.fieldNames().mkString(".")))
+ p.references().forall(ref => partPathSet.contains(ref.fieldNames().mkString(".")))
}
val allKeys = allPartitions.map(_.asInstanceOf[BufferedRows].key)
val matchingKeys = InMemoryTableWithV2Filter.filtersToKeys(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImplSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImplSuite.scala
index 432d0df7d8d4b..593c9eef04c41 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImplSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/internal/connector/PartitionPredicateImplSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.internal.connector
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
-import org.apache.spark.sql.connector.expressions.PartitionColumnReference
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.connector.expressions.{FieldReference, PartitionFieldReference}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField}
+import org.apache.spark.unsafe.types.UTF8String
class PartitionPredicateImplSuite extends SparkFunSuite {
@@ -38,12 +40,25 @@ class PartitionPredicateImplSuite extends SparkFunSuite {
checkPartitionPredicateImplAfterSerialization(serializer)
}
+ test("Kryo: nested partition path in references survives round-trip") {
+ val conf = new SparkConf()
+ val serializer = new KryoSerializer(conf).newInstance()
+ checkNestedPartitionPathReferencesAfterSerialization(serializer)
+ }
+
+ test("Java serialization: nested partition path in references survives round-trip") {
+ val conf = new SparkConf()
+ val serializer = new JavaSerializer(conf).newInstance()
+ checkNestedPartitionPathReferencesAfterSerialization(serializer)
+ }
+
private def checkPartitionPredicateImplAfterSerialization(
serializer: SerializerInstance): Unit = {
- val partitionSchema = Seq(AttributeReference("p", IntegerType)())
- val ref = AttributeReference("p", IntegerType)()
+ val field = StructField("p", IntegerType, nullable = true)
+ val ref = DataTypeUtils.toAttribute(field)
val expr = GreaterThan(ref, Literal(5))
- val predicate = PartitionPredicateImpl(expr, partitionSchema)
+ val fields = Seq(PartitionPredicateField(field, FieldReference(Seq("p"))))
+ val predicate = PartitionPredicateImpl(expr, fields)
val deserialized = serializer.deserialize[PartitionPredicateImpl](
serializer.serialize(predicate))
@@ -59,10 +74,38 @@ class PartitionPredicateImplSuite extends SparkFunSuite {
assert(deserialized.equals(predicate))
}
+ private def checkNestedPartitionPathReferencesAfterSerialization(
+ serializer: SerializerInstance): Unit = {
+ val field = StructField("ts.timezone", StringType, nullable = false)
+ val ref = DataTypeUtils.toAttribute(field)
+ val expr = GreaterThan(ref, Literal("x"))
+ val fields = Seq(PartitionPredicateField(field, FieldReference(Seq("ts", "timezone"))))
+ val predicate = PartitionPredicateImpl(expr, fields)
+
+ val deserialized = serializer.deserialize[PartitionPredicateImpl](
+ serializer.serialize(predicate))
+
+ assert(deserialized.eval(InternalRow(UTF8String.fromString("z"))) === true)
+ assert(deserialized.eval(InternalRow(UTF8String.fromString("a"))) === false)
+
+ val expectedRefs = Seq((0, Seq("ts", "timezone")))
+ assert(partitionRefDetails(predicate.references.toSeq) === expectedRefs)
+ assert(partitionRefDetails(deserialized.references.toSeq) === expectedRefs)
+
+ assert(deserialized.equals(predicate))
+ }
+
+ private def partitionRefDetails(refs: Seq[AnyRef]): Seq[(Int, Seq[String])] = refs.map {
+ case r: PartitionFieldReference =>
+ (r.ordinal(), r.fieldNames().toIndexedSeq)
+ case other =>
+ fail(s"Expected PartitionFieldReference, got ${other.getClass.getName}: $other")
+ }
+
private def refsWithOrdinals(refs: Seq[AnyRef]): Seq[(String, Int)] = refs.map {
- case r: PartitionColumnReference =>
+ case r: PartitionFieldReference =>
(r.fieldNames().mkString("."), r.ordinal())
case other =>
- fail(s"Expected PartitionColumnReference, got ${other.getClass.getName}: $other")
+ fail(s"Expected PartitionFieldReference, got ${other.getClass.getName}: $other")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
index 8843fe105ec30..0cb7967cb7369 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupBasedRowLevelOperationScanPlanning.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.internal.{LogKeys}
+import org.apache.spark.internal.LogKeys
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.planning.{GroupBasedRowLevelOperation, PhysicalOperation}
@@ -27,8 +27,8 @@ import org.apache.spark.sql.connector.expressions.filter.{Predicate => V2Filter}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.internal.connector.PartitionPredicateField
import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
/**
* A rule that builds scans for group-based row-level operations.
@@ -48,10 +48,10 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr
val table = relation.table.asRowLevelOperationTable
val scanBuilder = table.newScanBuilder(relation.options)
- val partitionSchema = PushDownUtils.getPartitionPredicateSchema(relation)
+ val partitionPredicateFields = PushDownUtils.getPartitionPredicateSchema(relation)
val (pushedFilters, evaluatedFilters, postScanFilters) =
- pushFilters(cond, relation.output, scanBuilder, partitionSchema)
+ pushFilters(cond, relation.output, scanBuilder, partitionPredicateFields)
val pushedFiltersStr = if (pushedFilters.isLeft) {
pushedFilters.swap
@@ -100,13 +100,13 @@ object GroupBasedRowLevelOperationScanPlanning extends Rule[LogicalPlan] with Pr
cond: Expression,
tableAttrs: Seq[AttributeReference],
scanBuilder: ScanBuilder,
- partitionSchema: Option[StructType])
+ partitionPredicateFields: Option[Seq[PartitionPredicateField]])
: (Either[Seq[Filter], Seq[V2Filter]], Seq[Expression], Seq[Expression]) = {
val (filtersWithSubquery, filtersWithoutSubquery) = findTableFilters(cond, tableAttrs)
val (pushedFilters, postScanFiltersWithoutSubquery) =
- PushDownUtils.pushFilters(scanBuilder, filtersWithoutSubquery, partitionSchema)
+ PushDownUtils.pushFilters(scanBuilder, filtersWithoutSubquery, partitionPredicateFields)
val postScanFilterSetWithoutSubquery = ExpressionSet(postScanFiltersWithoutSubquery)
val evaluatedFilters = filtersWithoutSubquery.filterNot { filter =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 4a87a50c6576e..f682be51a8925 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -19,36 +19,38 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, NamedExpression, PythonUDF, SchemaPruning, SubqueryExpression}
+import org.apache.spark.internal.{Logging, LogKeys}
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, GetStructField, NamedExpression, PythonUDF, SchemaPruning, SubqueryExpression}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.expressions.{IdentityTransform, SortOrder, Transform}
+import org.apache.spark.sql.connector.expressions.{IdentityTransform, SortOrder}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownOffset, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.connector.{PartitionPredicateImpl, SupportsPushDownCatalystFilters}
+import org.apache.spark.sql.internal.connector.{PartitionPredicateField, PartitionPredicateImpl, SupportsPushDownCatalystFilters}
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.ArrayImplicits.SparkArrayOps
import org.apache.spark.util.collection.Utils
-object PushDownUtils {
+object PushDownUtils extends Logging {
/**
* Pushes down filters to the data source reader.
*
* @param scanBuilder The scan builder to push filters to.
* @param filters Catalyst filter expressions to push down.
- * @param partitionSchema The schema of [[Table#partitioning() partitioning]].
- * When set and the scan supports V2 filters,
- * [[PartitionPredicate]] can be pushed for a second pass.
+ * @param partitionFields Optional metadata for [[Table#partitioning()]].
+ * When set and the scan supports V2 filters,
+ * [[PartitionPredicate]] can be pushed for a second pass.
* @return pushed filter and post-scan filters.
*/
def pushFilters(
scanBuilder: ScanBuilder,
filters: Seq[Expression],
- partitionSchema: Option[StructType])
+ partitionFields: Option[Seq[PartitionPredicateField]])
: (Either[Seq[sources.Filter], Seq[Predicate]], Seq[Expression]) = {
scanBuilder match {
case r: SupportsPushDownFilters =>
@@ -109,11 +111,12 @@ object PushDownUtils {
}
val remainingFilters = (rejectedFilters ++ untranslatableExprs).toSeq
- val postScanFilters = if (partitionSchema.isEmpty || !r.supportsIterativePushdown) {
- remainingFilters
- } else {
- pushPartitionPredicates(r, partitionSchema.get, remainingFilters)
- }
+ val postScanFilters =
+ if (!partitionFields.exists(_.nonEmpty) || !r.supportsIterativePushdown) {
+ remainingFilters
+ } else {
+ pushPartitionPredicates(r, partitionFields.get, remainingFilters)
+ }
val orderedPostScanFilters = prioritizeFilters(postScanFilters,
ExpressionSet(untranslatableExprs))
@@ -126,77 +129,164 @@ object PushDownUtils {
}
/**
- * Normally translated filters (postScanFilters) are simple filters that can be
- * evaluated faster, while the untranslated filters are complicated filters
- * that take more time to evaluate, so we want to evaluate the translatable
- * filters first.
+ * Returns a Seq of [[PartitionPredicateField]] describing the table's partition transforms
+ * when the schema is supported for [[PartitionPredicate]] push down, or None if not supported.
*/
- private def prioritizeFilters(
- filters: Seq[Expression],
- untranslatableFilterSet: ExpressionSet): Seq[Expression] = {
- val (translatable, untranslatable) = filters.partition(!untranslatableFilterSet.contains(_))
- translatable ++ untranslatable
+ def getPartitionPredicateSchema(relation: DataSourceV2Relation)
+ : Option[Seq[PartitionPredicateField]] = {
+ val transforms = relation.table.partitioning
+ if (transforms.isEmpty) {
+ None
+ } else {
+ val rootStruct = StructType(relation.output.map { a =>
+ StructField(a.name, a.dataType, a.nullable)})
+ val fields = transforms.flatMap {
+ case t: IdentityTransform =>
+ resolveIdentityPartitionField(t, rootStruct).map(PartitionPredicateField(_, t.ref))
+ case _ => None
+ }
+ if (fields.length == transforms.length) {
+ Some(fields.toSeq)
+ } else {
+ None
+ }
+ }
+ }
+
+ /**
+ * Returns a [[StructField]] for the given identity partition
+ * transform if it can be resolved, or `None` if it cannot be resolved.
+ */
+ private def resolveIdentityPartitionField(
+ transform: IdentityTransform,
+ rootStruct: StructType) = {
+ val names = transform.ref.fieldNames().toSeq
+ try {
+ rootStruct.findNestedField(names, resolver = SQLConf.get.resolver).map {
+ case (_, leaf) => StructField(names.mkString("."), leaf.dataType, leaf.nullable)
+ }
+ } catch {
+ case _: AnalysisException =>
+ logWarning(log"Invalid partition reference: " +
+ log"${MDC(LogKeys.FIELD_NAME, names.mkString("."))}," +
+ log" skipping creation of PartitionPredicate.")
+ None
+ }
}
/**
- * If the scan supports iterative filtering, convert partition filters to
- * PartitionPredicates (see SPARK-55596) and push them down in another pass.
+ * If the scan supports iterative filtering, infer additional partition filters,
+ * convert these and unused partition filters to PartitionPredicates,
+ * and push them down in another pass. (See SPARK-55596)
*/
private def pushPartitionPredicates(
scanBuilder: SupportsPushDownV2Filters,
- partitionSchema: StructType,
+ partitionFields: Seq[PartitionPredicateField],
remainingFilters: Seq[Expression]): Seq[Expression] = {
- val (partitionFilters, nonPartitionFilters) =
- DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, remainingFilters)
- val (pushable, nonPushable) = partitionFilters.partition(isPushablePartitionFilter)
- val partitionAttrs = toAttributes(partitionSchema)
- val partitionPredicates = pushable.map(expr => PartitionPredicateImpl(expr, partitionAttrs))
+ val normalizedToOriginal = normalizeNestedPartitionFilters(remainingFilters, partitionFields)
+ val normalized = normalizedToOriginal.keys.toSeq
+ val partitionSchema = StructType(partitionFields.map(_.structField))
+ // may infer additional partition filters
+ val (partFilters, nonPartitionFilters) =
+ DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, normalized)
+ val (pushable, nonPushable) = partFilters.partition(isPushablePartitionFilter)
+ val partitionPredicates = pushable.map(PartitionPredicateImpl(_, partitionFields))
val rejectedPartitionFilters = scanBuilder.pushPredicates(partitionPredicates.toArray).map {
- predicate => predicate.asInstanceOf[PartitionPredicateImpl].expression
- }
- nonPartitionFilters ++ nonPushable ++ rejectedPartitionFilters
+ p => p.asInstanceOf[PartitionPredicateImpl].expression
+ }.toSeq
+ (nonPartitionFilters ++ nonPushable ++ rejectedPartitionFilters)
+ .filter(normalizedToOriginal.contains)
+ .map(normalizedToOriginal)
}
+ private def isPushablePartitionFilter(f: Expression) =
+ f.deterministic &&
+ !SubqueryExpression.hasSubquery(f) &&
+ !f.exists(_.isInstanceOf[PythonUDF])
+
/**
- * Returns a table's partitioning expression schema as a StructType, if creation of a
- * PartitionPredicate is supported for the schema.
- * Currently only supported if all partitioning expressions are identity transforms on simple
- * (single-name, non-nested) field references.
+ * Normalizes filter expressions so that nested struct accesses on
+ * partition fields are replaced with flat [[AttributeReference]]s
+ * whose names match the partition schema.
+ *
+ * For example, given a table partitioned by `s.tz` (identity
+ * transform on a nested field), the analyzer produces
+ * `GetStructField(attr("s"), "tz")`. This method replaces that
+ * chain with `attr("s.tz")`.
*
- * @return Some(StructType) representing partition transform expression types, if schema
- * is supported for PartitionPredicate. None if not supported.
+ * Returns a map from normalized expression to original.
*/
- def getPartitionPredicateSchema(relation: DataSourceV2Relation): Option[StructType] = {
- val transforms = relation.table.partitioning
- val fields = transforms.flatMap(toSupportedPartitionField(_, relation))
- Option.when(transforms.nonEmpty && fields.length == transforms.length)(StructType(fields))
+ private def normalizeNestedPartitionFilters(
+ filters: Seq[Expression],
+ partitionFields: Seq[PartitionPredicateField])
+ : Map[Expression, Expression] = {
+ val attrs = toAttributes(StructType(partitionFields.map(_.structField)))
+ val pathToAttr = partitionFields.map(_.identityRef).zip(attrs).map {
+ case (r, a) => r.fieldNames().toSeq -> a
+ }.toMap
+ filters.map { f =>
+ doNormalizePartitionFilters(f, pathToAttr, SQLConf.get.resolver) -> f
+ }.toMap
+ }
+
+ private def doNormalizePartitionFilters(
+ expr: Expression,
+ pathToAttr: Map[Seq[String], AttributeReference],
+ resolver: (String, String) => Boolean): Expression = {
+ expr.mapChildren(
+ doNormalizePartitionFilters(_, pathToAttr, resolver)
+ ) match {
+ case g: GetStructField =>
+ flattenStructFieldChain(g) match {
+ case Some((root, suffix)) =>
+ val fullPath = root.name +: suffix
+ pathToAttr.collectFirst {
+ case (path, attr) if pathsMatch(fullPath, path, resolver) =>
+ attr.withNullability(g.nullable)
+ }.getOrElse(g) // Not a partition field
+ case None => g // Single-level struct access, not nested
+ }
+ case other => other
+ }
}
/**
- * Returns a StructField for the given partition transform if it is
- * supported for iterative partition predicate push down.
+ * Flattens a nested [[GetStructField]] chain into the root
+ * [[AttributeReference]] and its field-name path. Returns `None`
+ * when the expression is not a nested struct access.
+ *
+ * Example: `GetStructField(GetStructField(a#1, "x"), "y")`
+ * returns `Some((a#1, Seq("x", "y")))`.
*/
- private def toSupportedPartitionField(
- transform: Transform,
- relation: DataSourceV2Relation): Option[StructField] = {
- transform match {
- case t: IdentityTransform if t.ref.fieldNames.length == 1 =>
- val colName = t.ref.fieldNames.head
- relation.output
- .find(_.name == colName)
- .map(attr => StructField(colName, attr.dataType, attr.nullable))
- case _ =>
- None
+ private def flattenStructFieldChain(
+ expr: Expression): Option[(AttributeReference, Seq[String])] = {
+ def flatten(e: Expression): Option[(AttributeReference, Seq[String])] = e match {
+ case ar: AttributeReference => Some((ar, Seq.empty))
+ case g: GetStructField =>
+ flatten(g.child).map { case (ar, tail) => (ar, tail :+ g.extractFieldName) }
+ case _ => None
}
+ flatten(expr).filter(_._2.nonEmpty)
}
+ private def pathsMatch(
+ left: Seq[String],
+ right: Seq[String],
+ resolver: (String, String) => Boolean) =
+ left.length == right.length && left.zip(right).forall { case (a, b) => resolver(a, b) }
+
/**
- * Returns true if the given filter expression is safe to push as a partition predicate
- * when using iterative pushdown: it must be deterministic, contain
- * no subquery, and no PythonUDF.
+ * Normally translated filters (postScanFilters) are simple filters that can be
+ * evaluated faster, while the untranslated filters are complicated filters
+ * that take more time to evaluate, so we want to evaluate the translatable
+ * filters first.
*/
- private def isPushablePartitionFilter(f: Expression): Boolean =
- f.deterministic && !SubqueryExpression.hasSubquery(f) && !f.exists(_.isInstanceOf[PythonUDF])
+ private def prioritizeFilters(
+ filters: Seq[Expression],
+ untranslatableFilterSet: ExpressionSet): Seq[Expression] = {
+ val (translatable, untranslatable) = filters.partition(!untranslatableFilterSet.contains(_))
+ translatable ++ untranslatable
+ }
/**
* Pushes down TableSample to the data source Scan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 9a25752ccadac..4a4ccab47cad0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -80,9 +80,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// `pushedFilters` will be pushed down and evaluated in the underlying data sources.
// `postScanFilters` need to be evaluated after the scan.
// `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
- val partitionSchema = PushDownUtils.getPartitionPredicateSchema(sHolder.relation)
+ val partitionPredicateFields = PushDownUtils.getPartitionPredicateSchema(sHolder.relation)
val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters(
- sHolder.builder, normalizedFiltersWithoutSubquery, partitionSchema)
+ sHolder.builder, normalizedFiltersWithoutSubquery, partitionPredicateFields)
val pushedFiltersStr = if (pushedFilters.isLeft) {
pushedFilters.swap
.getOrElse(throw new NoSuchElementException("The left node doesn't have pushedFilters"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala
index be6e78fcebcc6..4b70f96fc3eff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, In, PredicateHelpe
import org.apache.spark.sql.connector.catalog.BufferedRows
import org.apache.spark.sql.connector.catalog.InMemoryEnhancedPartitionFilterTable
import org.apache.spark.sql.connector.catalog.InMemoryTableEnhancedPartitionFilterCatalog
-import org.apache.spark.sql.connector.expressions.PartitionColumnReference
+import org.apache.spark.sql.connector.expressions.PartitionFieldReference
import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate
import org.apache.spark.sql.execution.ExplainUtils.stripAQEPlan
import org.apache.spark.sql.execution.FilterExec
@@ -132,7 +132,7 @@ class DataSourceV2EnhancedPartitionFilterSuite
checkAnswer(df, Seq(Row("a", "x"), Row("b", "y")))
assertPushedPartitionPredicates(df, 1)
assertScanReturnsPartitionKeys(df, Set("a", "b"))
- assertReferencedPartitionColumnOrdinals(df, Array(0), Array("part_col"))
+ assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col"))
}
}
@@ -197,11 +197,11 @@ class DataSourceV2EnhancedPartitionFilterSuite
checkAnswer(df, Seq(Row("b", "y"), Row("bc", "z")))
assertPushedPartitionPredicates(df, 1)
assertScanReturnsPartitionKeys(df, Set("b", "bc"))
- assertReferencedPartitionColumnOrdinals(df, Array(0), Array("part_col"))
+ assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col"))
}
}
- test("case 8: Second-pass PartitionPredicate filter works for UDF filter on partition column") {
+ test("case 8: Second-pass PartitionPredicate filter works for UDF filter on partition field") {
withTable(partFilterTableName) {
sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " +
"PARTITIONED BY (part_col)")
@@ -216,11 +216,65 @@ class DataSourceV2EnhancedPartitionFilterSuite
checkAnswer(df, Seq(Row("a", "x"), Row("A", "y")))
assertPushedPartitionPredicates(df, 1)
assertScanReturnsPartitionKeys(df, Set("a", "A"))
- assertReferencedPartitionColumnOrdinals(df, Array(0), Array("part_col"))
+ assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col"))
}
}
- test("referenced partition column ordinals: partition predicate same column twice " +
+ test("nested identity partition: second-pass PartitionPredicate with UDF on nested key") {
+ withTable(partFilterTableName) {
+ sql(s"CREATE TABLE $partFilterTableName " +
+ s"(s struct