Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@
import org.apache.spark.sql.connector.catalog.Table;

/**
* A reference to a partition column in {@link Table#partitioning()}.
* A reference to a partition field in {@link Table#partitioning()}.
* <p>
* {@link #fieldNames()} returns the partition column name (or names) as reported by
* {@link #fieldNames()} returns the partition field name (or names) as reported by
* the table's partition schema.
* {@link #ordinal()} returns the 0-based position in {@link Table#partitioning()}.
*
* @since 4.2.0
*/
@Evolving
public interface PartitionColumnReference extends NamedReference {
public interface PartitionFieldReference extends NamedReference {

/**
* Returns the 0-based ordinal of this partition column in {@link Table#partitioning()}.
* Returns the 0-based ordinal of this partition field in {@link Table#partitioning()}.
*/
int ordinal();
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.PartitionColumnReference;

import static org.apache.spark.sql.connector.expressions.Expression.EMPTY_EXPRESSION;
import org.apache.spark.sql.connector.expressions.PartitionFieldReference;

/**
* Represents a partition predicate that can be evaluated using {@link Table#partitioning()}.
Expand All @@ -47,17 +45,17 @@ protected PartitionPredicate() {
/**
* {@inheritDoc}
* <p>
* For PartitionPredicate, returns {@link PartitionColumnReference} instances that identify
* the partition columns (from {@link Table#partitioning()}) referenced by this predicate.
* Each reference's {@link PartitionColumnReference#fieldNames()} gives the partition column
* name; {@link PartitionColumnReference#ordinal()} gives the 0-based position in
* For PartitionPredicate, returns {@link PartitionFieldReference} instances that identify
* the partition fields (from {@link Table#partitioning()}) referenced by this predicate.
* Each reference's {@link PartitionFieldReference#fieldNames()} gives the partition field
* name; {@link PartitionFieldReference#ordinal()} gives the 0-based position in
* {@link Table#partitioning()}.
* <p>
* <b>Example:</b> Suppose {@code Table.partitioning()} returns three partition
* transforms: {@code [years(ts), months(ts), bucket(32, id)]} with ordinals 0, 1, 2.
* Each {@link PartitionColumnReference} has {@link PartitionColumnReference#fieldNames()}
* Each {@link PartitionFieldReference} has {@link PartitionFieldReference#fieldNames()}
* (the transform display name, e.g. {@code years(ts)}) and
* {@link PartitionColumnReference#ordinal()}:
* {@link PartitionFieldReference#ordinal()}:
* <ul>
* <li>{@code years(ts) = 2026} returns one reference: (fieldNames=[years(ts)], ordinal=0).</li>
* <li>{@code years(ts) = 2026 and months(ts) = 01} returns two references:
Expand All @@ -72,7 +70,7 @@ protected PartitionPredicate() {
* partitioned data) to Spark for post-scan filter, while predicates that reference only
* initially-added partition transforms may be fully pushed.
*
* @return array of partition column references
* @return array of partition field references
*/
@Override
public abstract NamedReference[] references();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.PartitionColumnReference;
import org.apache.spark.sql.connector.expressions.PartitionFieldReference;
import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate;
import org.apache.spark.sql.connector.expressions.filter.Predicate;

Expand Down Expand Up @@ -54,8 +54,8 @@ public interface SupportsPushDownV2Filters extends ScanBuilder {
* {@link #pushedPredicates()} can return predicates from all of them.
* <p>
* For each {@link PartitionPredicate}, the implementation can use
* {@link PartitionPredicate#references()} (each {@link PartitionColumnReference} has
* {@link PartitionColumnReference#ordinal()}) to decide whether to return it for post-scan
* {@link PartitionPredicate#references()} (each {@link PartitionFieldReference} has
* {@link PartitionFieldReference#ordinal()}) to decide whether to return it for post-scan
* filtering. For example, data sources with
* partition spec evolution may return predicates that reference later-added partition
* transforms (incompletely partitioned data) so Spark evaluates them after the scan, while
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@

package org.apache.spark.sql.internal.connector

import org.apache.spark.sql.connector.expressions.PartitionColumnReference
import org.apache.spark.sql.connector.expressions.PartitionFieldReference

/**
* Implementation of [[PartitionColumnReference]] that carries the position ordinal in
* Table.partitioning() and the partition column name(s) for that position.
* Implementation of [[PartitionFieldReference]] that carries the position ordinal in
* Table.partitioning() and the partition field name(s) for that position.
*/
private[connector] case class PartitionColumnReferenceImpl(
private[connector] case class PartitionFieldReferenceImpl(
ordinal: Int,
fieldNames: Array[String])
extends PartitionColumnReference
extends PartitionFieldReference
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.internal.connector

import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.types.StructField

/**
 * Metadata for one partition field, pairing the flattened Catalyst view of the
 * field with the original (possibly nested) column reference.
 *
 * @param structField the Catalyst field describing the partition column (name, type,
 *                    nullability). For nested columns the name is the dotted path
 *                    (e.g. `"s.tz"`).
 * @param identityRef the [[NamedReference]] from the transform target column
 *                    (e.g. `FieldReference(Seq("s", "tz"))`), preserving the
 *                    individual path segments that the dotted name flattens.
 */
case class PartitionPredicateField(structField: StructField, identityRef: NamedReference)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Expression => CatalystExpression, Predicate => CatalystPredicate}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate

Expand All @@ -30,30 +31,33 @@ import org.apache.spark.sql.connector.expressions.filter.PartitionPredicate
*/
class PartitionPredicateImpl private (
private val catalystExpr: CatalystExpression,
private val partitionSchema: Seq[AttributeReference])
private val partitionFields: Seq[PartitionPredicateField])
extends PartitionPredicate with Logging {

@transient private lazy val partitionAttrs: Seq[AttributeReference] =
partitionFields.map(c => DataTypeUtils.toAttribute(c.structField))

/** The wrapped partition filter Catalyst Expression. */
def expression: CatalystExpression = catalystExpr

/** Bound predicate, computed once and reused for all partition rows. */
@transient private lazy val boundPredicate: InternalRow => Boolean = {
val boundExpr = catalystExpr.transform {
case a: AttributeReference =>
val index = partitionSchema.indexWhere(_.name == a.name)
require(index >= 0, s"Column ${a.name} not found in partition schema")
BoundReference(index, partitionSchema(index).dataType, nullable = a.nullable)
val index = partitionAttrs.indexWhere(_.name == a.name)
require(index >= 0, s"Field ${a.name} not found in partition schema")
BoundReference(index, partitionAttrs(index).dataType, nullable = a.nullable)
}
val predicate = CatalystPredicate.createInterpreted(boundExpr)
predicate.eval
}

override def eval(partitionValues: InternalRow): Boolean = {
if (partitionValues.numFields != partitionSchema.length) {
if (partitionValues.numFields != partitionFields.length) {
logWarning(
log"Cannot evaluate partition predicate ${MDC(LogKeys.EXPR, catalystExpr.sql)}: " +
log"partition value field count (${MDC(LogKeys.COUNT, partitionValues.numFields)}) " +
log"does not match schema (${MDC(LogKeys.NUM_PARTITIONS, partitionSchema.length)}). " +
log"does not match schema (${MDC(LogKeys.NUM_PARTITIONS, partitionFields.length)}). " +
log"Including partition in scan result to avoid incorrect filtering.")
return true
}
Expand All @@ -72,41 +76,54 @@ class PartitionPredicateImpl private (

@transient override lazy val references: Array[NamedReference] = {
val refNames = catalystExpr.references.map(_.name).toSet
partitionSchema.zipWithIndex
partitionAttrs.zipWithIndex
.filter { case (attr, _) => refNames.contains(attr.name) }
.map { case (attr, ordinal) => PartitionColumnReferenceImpl(ordinal, Array(attr.name)) }
.map { case (_, ordinal) =>
PartitionFieldReferenceImpl(ordinal,
partitionFields(ordinal).identityRef.fieldNames())
}
.toArray
}

override def equals(obj: Any): Boolean = obj match {
case other: PartitionPredicateImpl =>
catalystExpr.semanticEquals(other.catalystExpr) && partitionSchema == other.partitionSchema
catalystExpr.semanticEquals(other.catalystExpr) &&
partitionFields == other.partitionFields
case _ => false
}

override def hashCode(): Int = {
31 * catalystExpr.semanticHash() + partitionSchema.hashCode()
31 * catalystExpr.semanticHash() + partitionFields.hashCode()
}

override def toString(): String = s"PartitionPredicate(${catalystExpr.sql})"
}

object PartitionPredicateImpl {

def apply(
def apply(catalystExpr: CatalystExpression,
partitionFields: Seq[PartitionPredicateField])
: PartitionPredicateImpl = {
validateAndCreate(catalystExpr, partitionFields)
}

private def validateAndCreate(
catalystExpr: CatalystExpression,
partitionSchema: Seq[AttributeReference]): PartitionPredicateImpl = {
if (partitionSchema.isEmpty) {
partitionFields: Seq[PartitionPredicateField])
: PartitionPredicateImpl = {
if (partitionFields.isEmpty) {
throw SparkException.internalError(
s"Cannot evaluate partition predicate ${catalystExpr.sql}: partition schema is empty")
s"Cannot evaluate partition predicate ${catalystExpr.sql}: partition fields are empty")
}
val partitionNames = partitionSchema.map(_.name).toSet
val partitionNames = partitionFields.map(_.structField.name).toSet
val refNames = catalystExpr.references.map(_.name).toSet
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an anti-pattern to compare qualified names as single strings. I think one side is from the nested cols in partition predicates, and the other side is reported by the v2 table. How does the v2 table report partition cols?

Copy link
Member Author

@szehon-ho szehon-ho Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, the comparison is between

  1. catalyst partition filter
  2. v2 table partition cols

Both are flattened (i.e. turned into "a.b.c"):

  1. using normalizePartitionFilters() which returns AttributeReference with flattened name
  2. using resolveIdentityPartitionField() which returns StructField with a flattened name.

V2Table reports it via Transform which has transform.ref.fieldNames() which is Seq[String]. But I do need to flatten it for comparison, do you have any other thoughts?

Another reason for flattening it is that later I need to pass the partition schema as a StructType to DataSourceUtils.getPartitionFiltersAndDataFilters. That has some valuable logic (e.g., extracting more partition filters) that I did not want to re-implement.

if (!refNames.subsetOf(partitionNames)) {
val refsStr = refNames.mkString(", ")
val fieldsStr = partitionNames.mkString(", ")
throw SparkException.internalError(
s"Cannot evaluate partition predicate ${catalystExpr.sql}: expression references " +
s"${refNames.mkString(", ")} not all in partition columns ${partitionNames.mkString(", ")}")
s"$refsStr not all in partition fields $fieldsStr")
}
new PartitionPredicateImpl(catalystExpr, partitionSchema)
new PartitionPredicateImpl(catalystExpr, partitionFields)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ class InMemoryEnhancedPartitionFilterTable(
override def supportsIterativePushdown(): Boolean = true

override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
val partNames = InMemoryEnhancedPartitionFilterTable.this.partCols.flatMap(_.toSeq).toSet
val partPaths = InMemoryEnhancedPartitionFilterTable.this.partCols.map(_.mkString(".")).toSet
def referencesOnlyPartitionCols(p: Predicate): Boolean =
p.references().forall(ref => partNames.contains(ref.fieldNames().mkString(".")))
p.references().forall(ref => partPaths.contains(ref.fieldNames().mkString(".")))
def referencesOnlyDataCols(p: Predicate): Boolean =
p.references().forall(ref => !partNames.contains(ref.fieldNames().mkString(".")))
p.references().forall(ref => !partPaths.contains(ref.fieldNames().mkString(".")))

val returned = ArrayBuffer.empty[Predicate]

Expand Down Expand Up @@ -120,10 +120,11 @@ class InMemoryEnhancedPartitionFilterTable(
val partNames =
InMemoryEnhancedPartitionFilterTable.this.partCols.map(_.toSeq.quoted)
.toImmutableArraySeq
val partNamesSet = InMemoryEnhancedPartitionFilterTable.this.partCols.flatMap(_.toSeq).toSet
val partPathSet =
InMemoryEnhancedPartitionFilterTable.this.partCols.map(_.mkString(".")).toSet
// Only partition predicates can be used for partition key filtering (filtersToKeys).
val firstPassPartitionPredicates = firstPassPushedPredicates.filter { p =>
p.references().forall(ref => partNamesSet.contains(ref.fieldNames().mkString(".")))
p.references().forall(ref => partPathSet.contains(ref.fieldNames().mkString(".")))
}
val allKeys = allPartitions.map(_.asInstanceOf[BufferedRows].key)
val matchingKeys = InMemoryTableWithV2Filter.filtersToKeys(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ package org.apache.spark.sql.internal.connector
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.connector.expressions.PartitionColumnReference
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.expressions.{FieldReference, PartitionFieldReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField}
import org.apache.spark.unsafe.types.UTF8String

class PartitionPredicateImplSuite extends SparkFunSuite {

Expand All @@ -38,12 +40,25 @@ class PartitionPredicateImplSuite extends SparkFunSuite {
checkPartitionPredicateImplAfterSerialization(serializer)
}

test("Kryo: nested partition path in references survives round-trip") {
  // Exercise the round-trip check with a Kryo-backed serializer instance.
  val kryoInstance = new KryoSerializer(new SparkConf()).newInstance()
  checkNestedPartitionPathReferencesAfterSerialization(kryoInstance)
}

test("Java serialization: nested partition path in references survives round-trip") {
  // Exercise the round-trip check with the default Java serializer.
  val javaInstance = new JavaSerializer(new SparkConf()).newInstance()
  checkNestedPartitionPathReferencesAfterSerialization(javaInstance)
}

private def checkPartitionPredicateImplAfterSerialization(
serializer: SerializerInstance): Unit = {
val partitionSchema = Seq(AttributeReference("p", IntegerType)())
val ref = AttributeReference("p", IntegerType)()
val field = StructField("p", IntegerType, nullable = true)
val ref = DataTypeUtils.toAttribute(field)
val expr = GreaterThan(ref, Literal(5))
val predicate = PartitionPredicateImpl(expr, partitionSchema)
val fields = Seq(PartitionPredicateField(field, FieldReference(Seq("p"))))
val predicate = PartitionPredicateImpl(expr, fields)

val deserialized = serializer.deserialize[PartitionPredicateImpl](
serializer.serialize(predicate))
Expand All @@ -59,10 +74,38 @@ class PartitionPredicateImplSuite extends SparkFunSuite {
assert(deserialized.equals(predicate))
}

/**
 * Verifies that a predicate over a nested partition path ("ts.timezone") still
 * evaluates correctly and reports its ordinal plus multi-part field names after
 * a serialize/deserialize round-trip through the given serializer.
 */
private def checkNestedPartitionPathReferencesAfterSerialization(
    serializer: SerializerInstance): Unit = {
  val structField = StructField("ts.timezone", StringType, nullable = false)
  val partitionFields =
    Seq(PartitionPredicateField(structField, FieldReference(Seq("ts", "timezone"))))
  val attr = DataTypeUtils.toAttribute(structField)
  val original = PartitionPredicateImpl(GreaterThan(attr, Literal("x")), partitionFields)

  val roundTripped =
    serializer.deserialize[PartitionPredicateImpl](serializer.serialize(original))

  // Evaluation must be preserved: "z" > "x" passes, "a" > "x" fails.
  assert(roundTripped.eval(InternalRow(UTF8String.fromString("z"))) === true)
  assert(roundTripped.eval(InternalRow(UTF8String.fromString("a"))) === false)

  // Both copies must expose ordinal 0 with the nested path segments intact.
  val expected = Seq((0, Seq("ts", "timezone")))
  assert(partitionRefDetails(original.references.toSeq) === expected)
  assert(partitionRefDetails(roundTripped.references.toSeq) === expected)

  assert(roundTripped.equals(original))
}

/**
 * Maps each reference to an (ordinal, field-name path) pair, failing the test
 * if any element is not a [[PartitionFieldReference]].
 */
private def partitionRefDetails(refs: Seq[AnyRef]): Seq[(Int, Seq[String])] =
  refs.map { ref =>
    ref match {
      case partitionRef: PartitionFieldReference =>
        (partitionRef.ordinal(), partitionRef.fieldNames().toIndexedSeq)
      case other =>
        fail(s"Expected PartitionFieldReference, got ${other.getClass.getName}: $other")
    }
  }

private def refsWithOrdinals(refs: Seq[AnyRef]): Seq[(String, Int)] = refs.map {
case r: PartitionColumnReference =>
case r: PartitionFieldReference =>
(r.fieldNames().mkString("."), r.ordinal())
case other =>
fail(s"Expected PartitionColumnReference, got ${other.getClass.getName}: $other")
fail(s"Expected PartitionFieldReference, got ${other.getClass.getName}: $other")
}
}
Loading