@@ -20,17 +20,22 @@ import scala.collection.JavaConverters._
2020import scala .util .control .NonFatal
2121
2222import org .apache .iceberg .{FileFormat , FileScanTask , MetadataColumns }
23- import org .apache .iceberg .expressions .Expressions
23+ import org .apache .iceberg .expressions .{ And => IcebergAnd , BoundPredicate , Expression => IcebergExpression , Not => IcebergNot , Or => IcebergOr , UnboundPredicate }
2424import org .apache .spark .internal .Logging
2525import org .apache .spark .sql .auron .NativeConverters
26+ import org .apache .spark .sql .catalyst .expressions .{And => SparkAnd , AttributeReference , EqualTo , Expression => SparkExpression , GreaterThan , GreaterThanOrEqual , In , IsNaN , IsNotNull , IsNull , LessThan , LessThanOrEqual , Literal , Not => SparkNot , Or => SparkOr }
2627import org .apache .spark .sql .connector .read .InputPartition
2728import org .apache .spark .sql .execution .datasources .v2 .BatchScanExec
28- import org .apache .spark .sql .types .StructType
29+ import org .apache .spark .sql .internal .SQLConf
30+ import org .apache .spark .sql .types .{BinaryType , DataType , DecimalType , StringType , StructField , StructType }
31+
32+ import org .apache .auron .{protobuf => pb }
2933
/**
 * Result of translating an Iceberg scan into a native-executable plan.
 *
 * @param fileTasks Iceberg file scan tasks to read (empty for an empty-table scan)
 * @param fileFormat the single file format shared by all tasks (mixed formats fall back)
 * @param readSchema Spark schema projected from the data files
 * @param pruningPredicates filter expressions converted to native protobuf nodes,
 *                          used for scan-level pruning (best effort; may be empty)
 */
final case class IcebergScanPlan(
    fileTasks: Seq[FileScanTask],
    fileFormat: FileFormat,
    readSchema: StructType,
    pruningPredicates: Seq[pb.PhysicalExprNode])
3540object IcebergScanSupport extends Logging {
3641
@@ -61,7 +66,7 @@ object IcebergScanSupport extends Logging {
6166 // Empty scan (e.g. empty table) should still build a plan to return no rows.
6267 if (partitions.isEmpty) {
6368 logWarning(s " Native Iceberg scan planned with empty partitions for $scanClassName. " )
64- return Some (IcebergScanPlan (Seq .empty, FileFormat .PARQUET , readSchema))
69+ return Some (IcebergScanPlan (Seq .empty, FileFormat .PARQUET , readSchema, Seq .empty ))
6570 }
6671
6772 val icebergPartitions = partitions.flatMap(icebergPartition)
@@ -77,11 +82,6 @@ object IcebergScanSupport extends Logging {
7782 return None
7883 }
7984
80- // Residual filters require row-level evaluation, not supported in native scan.
81- if (! fileTasks.forall(task => Expressions .alwaysTrue().equals(task.residual()))) {
82- return None
83- }
84-
8585 // Native scan handles a single file format; mixed formats must fallback.
8686 val formats = fileTasks.map(_.file().format()).distinct
8787 if (formats.size > 1 ) {
@@ -93,7 +93,9 @@ object IcebergScanSupport extends Logging {
9393 return None
9494 }
9595
96- Some (IcebergScanPlan (fileTasks, format, readSchema))
96+ val pruningPredicates = collectPruningPredicates(scan.asInstanceOf [AnyRef ], readSchema)
97+
98+ Some (IcebergScanPlan (fileTasks, format, readSchema, pruningPredicates))
9799 }
98100
99101 private def hasMetadataColumns (schema : StructType ): Boolean =
@@ -188,4 +190,240 @@ object IcebergScanSupport extends Logging {
188190 None
189191 }
190192 }
193+
194+ private def collectPruningPredicates (
195+ scan : AnyRef ,
196+ readSchema : StructType ): Seq [pb.PhysicalExprNode ] = {
197+ scanFilterExpressions(scan).flatMap { expr =>
198+ convertIcebergFilterExpression(expr, readSchema) match {
199+ case Some (converted) =>
200+ Some (NativeConverters .convertScanPruningExpr(converted))
201+ case None =>
202+ logDebug(s " Skip unsupported Iceberg pruning expression: $expr" )
203+ None
204+ }
205+ }
206+ }
207+
208+ private def scanFilterExpressions (scan : AnyRef ): Seq [IcebergExpression ] = {
209+ invokeDeclaredMethod(scan, " filterExpressions" ) match {
210+ case Some (values : java.util.Collection [_]) =>
211+ values.asScala.collect { case expr : IcebergExpression => expr }.toSeq
212+ case Some (values : Seq [_]) =>
213+ values.collect { case expr : IcebergExpression => expr }
214+ case _ =>
215+ Seq .empty
216+ }
217+ }
218+
219+ private def invokeDeclaredMethod (target : AnyRef , methodName : String ): Option [Any ] = {
220+ try {
221+ var cls : Class [_] = target.getClass
222+ while (cls != null ) {
223+ cls.getDeclaredMethods.find(_.getName == methodName) match {
224+ case Some (method) =>
225+ method.setAccessible(true )
226+ return Some (method.invoke(target))
227+ case None =>
228+ cls = cls.getSuperclass
229+ }
230+ }
231+ None
232+ } catch {
233+ case NonFatal (t) =>
234+ logDebug(s " Failed to invoke $methodName on ${target.getClass.getName}. " , t)
235+ None
236+ }
237+ }
238+
239+ private def convertIcebergFilterExpression (
240+ expr : IcebergExpression ,
241+ readSchema : StructType ): Option [SparkExpression ] = {
242+ expr match {
243+ case and : IcebergAnd =>
244+ for {
245+ left <- convertIcebergFilterExpression(and.left(), readSchema)
246+ right <- convertIcebergFilterExpression(and.right(), readSchema)
247+ } yield SparkAnd (left, right)
248+ case or : IcebergOr =>
249+ for {
250+ left <- convertIcebergFilterExpression(or.left(), readSchema)
251+ right <- convertIcebergFilterExpression(or.right(), readSchema)
252+ } yield SparkOr (left, right)
253+ case not : IcebergNot =>
254+ convertIcebergFilterExpression(not.child(), readSchema).map(SparkNot )
255+ case predicate : UnboundPredicate [_] =>
256+ convertUnboundPredicate(predicate, readSchema)
257+ case predicate : BoundPredicate [_] =>
258+ convertBoundPredicate(predicate, readSchema)
259+ case _ =>
260+ expr.op() match {
261+ case org.apache.iceberg.expressions.Expression .Operation .TRUE =>
262+ Some (Literal (true ))
263+ case org.apache.iceberg.expressions.Expression .Operation .FALSE =>
264+ Some (Literal (false ))
265+ case _ =>
266+ None
267+ }
268+ }
269+ }
270+
271+ private def convertUnboundPredicate (
272+ predicate : UnboundPredicate [_],
273+ readSchema : StructType ): Option [SparkExpression ] = {
274+ findField(predicate.ref().name(), readSchema).flatMap { field =>
275+ val attr = toAttribute(field)
276+ val op = predicate.op()
277+
278+ op match {
279+ case org.apache.iceberg.expressions.Expression .Operation .IS_NULL =>
280+ Some (IsNull (attr))
281+ case org.apache.iceberg.expressions.Expression .Operation .NOT_NULL =>
282+ Some (IsNotNull (attr))
283+ case org.apache.iceberg.expressions.Expression .Operation .IS_NAN =>
284+ Some (IsNaN (attr))
285+ case org.apache.iceberg.expressions.Expression .Operation .NOT_NAN =>
286+ Some (SparkNot (IsNaN (attr)))
287+ case org.apache.iceberg.expressions.Expression .Operation .IN =>
288+ convertInPredicate(
289+ attr,
290+ field.dataType,
291+ predicate.literals().asScala.map(_.value()).toSeq)
292+ case org.apache.iceberg.expressions.Expression .Operation .NOT_IN =>
293+ convertInPredicate(
294+ attr,
295+ field.dataType,
296+ predicate.literals().asScala.map(_.value()).toSeq).map(SparkNot )
297+ case _ =>
298+ convertBinaryPredicate(attr, field.dataType, op, predicate.literal().value())
299+ }
300+ }
301+ }
302+
303+ private def convertBoundPredicate (
304+ predicate : BoundPredicate [_],
305+ readSchema : StructType ): Option [SparkExpression ] = {
306+ findField(predicate.ref().name(), readSchema).flatMap { field =>
307+ val attr = toAttribute(field)
308+ val op = predicate.op()
309+
310+ if (predicate.isUnaryPredicate()) {
311+ op match {
312+ case org.apache.iceberg.expressions.Expression .Operation .IS_NULL =>
313+ Some (IsNull (attr))
314+ case org.apache.iceberg.expressions.Expression .Operation .NOT_NULL =>
315+ Some (IsNotNull (attr))
316+ case org.apache.iceberg.expressions.Expression .Operation .IS_NAN =>
317+ Some (IsNaN (attr))
318+ case org.apache.iceberg.expressions.Expression .Operation .NOT_NAN =>
319+ Some (SparkNot (IsNaN (attr)))
320+ case _ =>
321+ None
322+ }
323+ } else if (predicate.isLiteralPredicate()) {
324+ val literalValue = predicate.asLiteralPredicate().literal().value()
325+ op match {
326+ case _ =>
327+ convertBinaryPredicate(attr, field.dataType, op, literalValue)
328+ }
329+ } else if (predicate.isSetPredicate()) {
330+ val values = predicate.asSetPredicate().literalSet().asScala.toSeq
331+ op match {
332+ case org.apache.iceberg.expressions.Expression .Operation .IN =>
333+ convertInPredicate(attr, field.dataType, values)
334+ case org.apache.iceberg.expressions.Expression .Operation .NOT_IN =>
335+ convertInPredicate(attr, field.dataType, values).map(SparkNot )
336+ case _ =>
337+ None
338+ }
339+ } else {
340+ None
341+ }
342+ }
343+ }
344+
345+ private def convertBinaryPredicate (
346+ attr : AttributeReference ,
347+ dataType : DataType ,
348+ op : org.apache.iceberg.expressions.Expression .Operation ,
349+ literalValue : Any ): Option [SparkExpression ] = {
350+ if (! supportsScanPruningLiteralType(dataType)) {
351+ return None
352+ }
353+ toLiteral(literalValue, dataType).flatMap { literal =>
354+ op match {
355+ case org.apache.iceberg.expressions.Expression .Operation .EQ =>
356+ Some (EqualTo (attr, literal))
357+ case org.apache.iceberg.expressions.Expression .Operation .NOT_EQ =>
358+ Some (SparkNot (EqualTo (attr, literal)))
359+ case org.apache.iceberg.expressions.Expression .Operation .LT =>
360+ Some (LessThan (attr, literal))
361+ case org.apache.iceberg.expressions.Expression .Operation .LT_EQ =>
362+ Some (LessThanOrEqual (attr, literal))
363+ case org.apache.iceberg.expressions.Expression .Operation .GT =>
364+ Some (GreaterThan (attr, literal))
365+ case org.apache.iceberg.expressions.Expression .Operation .GT_EQ =>
366+ Some (GreaterThanOrEqual (attr, literal))
367+ case _ =>
368+ None
369+ }
370+ }
371+ }
372+
373+ private def convertInPredicate (
374+ attr : AttributeReference ,
375+ dataType : DataType ,
376+ values : Seq [Any ]): Option [SparkExpression ] = {
377+ if (! supportsScanPruningLiteralType(dataType)) {
378+ return None
379+ }
380+ val literals = values.map(toLiteral(_, dataType))
381+ if (literals.forall(_.nonEmpty)) {
382+ Some (In (attr, literals.flatten))
383+ } else {
384+ None
385+ }
386+ }
387+
388+ private def supportsScanPruningLiteralType (dataType : DataType ): Boolean = {
389+ dataType match {
390+ case StringType | BinaryType => false
391+ case _ : DecimalType => false
392+ case _ => true
393+ }
394+ }
395+
396+ private def toLiteral (value : Any , dataType : DataType ): Option [Literal ] = {
397+ if (value == null ) {
398+ return Some (Literal .create(null , dataType))
399+ }
400+ dataType match {
401+ case _ : DecimalType =>
402+ None
403+ case BinaryType =>
404+ value match {
405+ case bytes : Array [Byte ] =>
406+ Some (Literal (bytes, BinaryType ))
407+ case byteBuffer : java.nio.ByteBuffer =>
408+ val duplicated = byteBuffer.duplicate()
409+ val bytes = new Array [Byte ](duplicated.remaining())
410+ duplicated.get(bytes)
411+ Some (Literal (bytes, BinaryType ))
412+ case _ =>
413+ None
414+ }
415+ case StringType =>
416+ Some (Literal .create(value.toString, StringType ))
417+ case _ =>
418+ Some (Literal .create(value, dataType))
419+ }
420+ }
421+
  // Catalyst attribute for a read-schema field, used to rebuild pruning predicates.
  // NOTE(review): nullable is forced to true rather than taken from field.nullable —
  // presumably conservative for pruning evaluation; confirm this is intentional.
  private def toAttribute(field: StructField): AttributeReference =
    AttributeReference(field.name, field.dataType, nullable = true)()
424+
425+ private def findField (name : String , readSchema : StructType ): Option [StructField ] = {
426+ val resolver = SQLConf .get.resolver
427+ readSchema.fields.find(field => resolver(field.name, name))
428+ }
191429}
0 commit comments