diff --git a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs index 17c9416d54..f0e6639243 100644 --- a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs +++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs @@ -19,7 +19,7 @@ use std::vec; use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::expr::ScalarFunction; -use datafusion::logical_expr::{Expr, Like, Operator}; +use datafusion::logical_expr::{BinaryExpr, Expr, Like, Operator}; use datafusion::scalar::ScalarValue; use iceberg::expr::{BinaryExpression, Predicate, PredicateOperator, Reference, UnaryExpression}; use iceberg::spec::{Datum, PrimitiveLiteral}; @@ -225,17 +225,130 @@ fn to_iceberg_operation(op: Operator) -> OpTransformedResult { /// identified by name at runtime, so we need to handle them here. fn scalar_function_to_iceberg_predicate(func_name: &str, args: &[Expr]) -> TransformedResult { match func_name { - // TODO: support complex expression arguments to scalar functions - "isnan" if args.len() == 1 => { - let operand = to_iceberg_predicate(&args[0]); - match operand { - TransformedResult::Column(r) => TransformedResult::Predicate(Predicate::Unary( - UnaryExpression::new(PredicateOperator::IsNan, r), - )), - _ => TransformedResult::NotTransformed, + "isnan" if args.len() == 1 => match resolve_nan_preserving_reference(&args[0]) { + Some(r) => TransformedResult::Predicate(r.is_nan()), + None => TransformedResult::NotTransformed, + }, + _ => TransformedResult::NotTransformed, + } +} + +/// Attempts to resolve a numeric expression argument down to a single column +/// [`Reference`] such that `isnan(arg)` is logically equivalent to +/// `isnan(reference)`. +/// +/// Filter pushdown is reported as `Inexact` (see +/// [`IcebergTableProvider::supports_filters_pushdown`]), so DataFusion +/// re-applies the original predicate after scanning. 
We therefore only need the +/// pushed-down predicate to be implied by the original filter (it may match +/// extra rows, but must never drop a matching one). Every transformation handled +/// here preserves NaN-ness *exactly* — the result is NaN if and only if the +/// wrapped column is NaN — so both `isnan(arg)` and `NOT isnan(arg)` are sound: +/// +/// * negation: `-x` is NaN iff `x` is NaN +/// * `abs(x)`: `abs(x)` is NaN iff `x` is NaN +/// * casts between numeric types preserve NaN +/// * `x + c`, `c + x`, `x - c`, `c - x` for a finite literal `c` +/// * `x * c`, `c * x`, `x / c` for a finite, non-zero literal `c` +/// +/// Multiplication/division by zero and `c / x` are intentionally rejected: e.g. +/// `x * 0` is NaN when `x` is `±inf`, so it does not imply `x` is NaN. +/// +/// [`IcebergTableProvider::supports_filters_pushdown`]: crate::table::IcebergTableProvider +fn resolve_nan_preserving_reference(expr: &Expr) -> Option<Reference> { + match expr { + Expr::Column(column) => Some(Reference::new(column.name())), + Expr::Negative(inner) => resolve_nan_preserving_reference(inner), + Expr::Cast(cast) => { + // Casts to date truncate the value and are not numeric, so they + // cannot be treated as NaN-preserving. + if cast.data_type == DataType::Date32 || cast.data_type == DataType::Date64 { + return None; } + resolve_nan_preserving_reference(&cast.expr) } - _ => TransformedResult::NotTransformed, + Expr::ScalarFunction(ScalarFunction { func, args }) + if func.name() == "abs" && args.len() == 1 => + { + resolve_nan_preserving_reference(&args[0]) + } + Expr::BinaryExpr(binary) => resolve_nan_preserving_binary(binary), + _ => None, + } +} + +/// Resolves the column reference from an arithmetic expression that combines a +/// single column with a finite literal while preserving NaN-ness. See +/// [`resolve_nan_preserving_reference`] for the soundness argument. +/// +/// Expressions with column references on both sides (e.g. `(x + 1) * (x - 2)`) +/// are not supported. 
Handling them safely would require both operands to +/// resolve to the *same* column (`x + y` cannot be expressed as a single +/// `col IS NAN`) and the operator combination itself to be NaN-preserving: +/// `(x + 1) * (x - 2)` is NaN iff `x` is NaN, but `(x + 1) - (x - 2)` is NaN +/// for `x = inf` (`inf - inf`) even though `x` is not. +/// +/// TODO: support NaN-preserving expressions with column references on both +/// sides, see . +fn resolve_nan_preserving_binary(binary: &BinaryExpr) -> Option<Reference> { + let (left, right) = (&binary.left, &binary.right); + match binary.op { + // `x + c`, `c + x`, `x - c` and `c - x` are NaN iff `x` is NaN, for any + // finite literal `c`. The column may be on either side. + Operator::Plus | Operator::Minus => { + if finite_literal(right).is_some() { + resolve_nan_preserving_reference(left) + } else if finite_literal(left).is_some() { + resolve_nan_preserving_reference(right) + } else { + None + } + } + + // `x * c` and `c * x` are NaN iff `x` is NaN, but only when `c` is + // non-zero. Per IEEE-754: + // - inf is not NaN + // - inf * 0 is NaN + // so multiplying by zero is rejected. The column may be on either side. + Operator::Multiply => { + if matches!(finite_literal(right), Some(c) if c != 0.0) { + resolve_nan_preserving_reference(left) + } else if matches!(finite_literal(left), Some(c) if c != 0.0) { + resolve_nan_preserving_reference(right) + } else { + None + } + } + + // `x / c` is NaN iff `x` is NaN, for a finite non-zero literal `c`. + // `c / x` is rejected and the column must be the dividend (left side). + // Per IEEE-754: + // - 0 is not NaN + // - 0 / 0 is NaN + // so `c / x` is not NaN-preserving. + Operator::Divide => { + if matches!(finite_literal(right), Some(c) if c != 0.0) { + resolve_nan_preserving_reference(left) + } else { + None + } + } + + _ => None, + } +} + +/// Returns the value of `expr` as an `f64` if it is a finite numeric literal +/// (i.e. not a non-literal, non-numeric, or infinite/NaN value). 
The numeric +/// conversion is delegated to DataFusion's [`ScalarValue::cast_to`]; the value +/// is only used to inspect finiteness and whether it is zero (precision loss is irrelevant). +fn finite_literal(expr: &Expr) -> Option<f64> { + let Expr::Literal(value, _) = expr else { + return None; + }; + match value.cast_to(&DataType::Float64).ok()? { + ScalarValue::Float64(Some(v)) if v.is_finite() => Some(v), + _ => None, } } @@ -732,11 +845,80 @@ mod tests { assert_eq!(predicate, expected_predicate); } + #[test] + fn test_predicate_conversion_with_isnan_negation() { + // -x is NaN iff x is NaN + let predicate = convert_to_iceberg_predicate("isnan(-qux)").unwrap(); + assert_eq!(predicate, Reference::new("qux").is_nan()); + + let predicate = convert_to_iceberg_predicate("NOT isnan(-qux)").unwrap(); + assert_eq!(predicate, !Reference::new("qux").is_nan()); + } + + #[test] + fn test_predicate_conversion_with_isnan_abs() { + // abs(x) is NaN iff x is NaN + let predicate = convert_to_iceberg_predicate("isnan(abs(qux))").unwrap(); + assert_eq!(predicate, Reference::new("qux").is_nan()); + } + + #[test] + fn test_predicate_conversion_with_isnan_additive() { + // x + c, c + x, x - c, c - x are NaN iff x is NaN (for finite c) + for sql in [ + "isnan(qux + 1)", + "isnan(1 + qux)", + "isnan(qux - 1)", + "isnan(1 - qux)", + "isnan(qux + 1.5)", + ] { + let predicate = convert_to_iceberg_predicate(sql).unwrap(); + assert_eq!(predicate, Reference::new("qux").is_nan(), "sql: {sql}"); + } + } + + #[test] + fn test_predicate_conversion_with_isnan_multiplicative() { + // x * c, c * x, x / c are NaN iff x is NaN (for finite non-zero c) + for sql in ["isnan(qux * 2)", "isnan(2 * qux)", "isnan(qux / 2)"] { + let predicate = convert_to_iceberg_predicate(sql).unwrap(); + assert_eq!(predicate, Reference::new("qux").is_nan(), "sql: {sql}"); + } + } + + #[test] + fn test_predicate_conversion_with_isnan_nested_expr() { + // Nested NaN-preserving transformations resolve to the inner column + let predicate = 
convert_to_iceberg_predicate("isnan(-(abs(qux) + 1) * 3)").unwrap(); + assert_eq!(predicate, Reference::new("qux").is_nan()); + } + + #[test] + fn test_predicate_conversion_with_isnan_and_other_complex_condition() { + let sql = "isnan(qux + 1) AND foo > 1"; + let predicate = convert_to_iceberg_predicate(sql).unwrap(); + let expected_predicate = Predicate::and( + Reference::new("qux").is_nan(), + Reference::new("foo").greater_than(Datum::long(1)), + ); + assert_eq!(predicate, expected_predicate); + } + #[test] fn test_predicate_conversion_with_isnan_unsupported_arg() { - // isnan on a complex expression (not a bare column) cannot be pushed down - let sql = "isnan(qux + 1)"; - let predicate = convert_to_iceberg_predicate(sql); - assert_eq!(predicate, None); + // Multiplying/dividing by zero does not preserve NaN-ness: `x * 0` is NaN + // when `x` is ±inf, so it cannot be pushed down. + assert_eq!(convert_to_iceberg_predicate("isnan(qux * 0)"), None); + assert_eq!(convert_to_iceberg_predicate("isnan(qux / 0)"), None); + + // `c / x` is not NaN-preserving (e.g. `0 / 0` is NaN while `0` is not). + assert_eq!(convert_to_iceberg_predicate("isnan(1 / qux)"), None); + + // Expressions referencing more than one column cannot be reduced to a + // single column reference. + assert_eq!(convert_to_iceberg_predicate("isnan(qux + foo)"), None); + + // Unknown scalar functions are not pushed down. + assert_eq!(convert_to_iceberg_predicate("isnan(sqrt(qux))"), None); } }