diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 3a255ae05f76..f055c8298bc1 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -25,7 +25,7 @@ use std::hash::Hash; use std::sync::Arc; use arrow::array::RecordBatch; -use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; +use arrow::datatypes::{DataType, FieldRef, SchemaRef}; use datafusion_common::{ DataFusionError, Result, ScalarValue, exec_err, metadata::FieldMetadata, @@ -34,11 +34,10 @@ use datafusion_common::{ }; use datafusion_functions::core::getfield::GetFieldFunc; use datafusion_physical_expr::PhysicalExprSimplifier; -use datafusion_physical_expr::expressions::CastColumnExpr; use datafusion_physical_expr::projection::{ProjectionExprs, Projector}; use datafusion_physical_expr::{ ScalarFunctionExpr, - expressions::{self, Column}, + expressions::{self, CastExpr, Column}, }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use itertools::Itertools; @@ -423,13 +422,12 @@ impl DefaultPhysicalExprAdapterRewriter { ))); }; - if resolved_column.index() == column.index() - && logical_field == physical_field.as_ref() - { - return Ok(Transformed::no(expr)); - } + let fields_match = logical_field == physical_field.as_ref(); + if fields_match { + if resolved_column.index() == column.index() { + return Ok(Transformed::no(expr)); + } - if logical_field == physical_field.as_ref() { // If the fields match (including metadata/nullability), we can use the column as is return Ok(Transformed::yes(Arc::new(resolved_column))); } @@ -439,7 +437,25 @@ impl DefaultPhysicalExprAdapterRewriter { // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123` // since that's much cheaper to evalaute. // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928 - self.create_cast_column_expr(resolved_column, physical_field, logical_field) + validate_data_type_compatibility( + resolved_column.name(), + physical_field.data_type(), + logical_field.data_type(), + ) + .map_err(|e| { + DataFusionError::Execution(format!( + "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type): {e}", + resolved_column.name(), + physical_field.data_type(), + logical_field.data_type() + )) + })?; + + Ok(Transformed::yes(Arc::new(CastExpr::new_with_target_field( + Arc::new(resolved_column), + Arc::new(logical_field.clone()), + None, + )))) } /// Resolves a logical column to the corresponding physical column and field. @@ -465,48 +481,13 @@ impl DefaultPhysicalExprAdapterRewriter { Column::new_with_schema(column.name(), self.physical_file_schema.as_ref())? }; - Ok(Some(( - column, - Arc::new( - self.physical_file_schema - .field(physical_column_index) - .clone(), - ), - ))) - } - - /// Validates type compatibility and creates a CastColumnExpr if needed. - /// - /// Checks whether the physical field can be cast to the logical field type, - /// handling both struct and scalar types. Returns a CastColumnExpr with the - /// appropriate configuration. - fn create_cast_column_expr( - &self, - column: Column, - physical_field: FieldRef, - logical_field: &Field, - ) -> Result>> { - validate_data_type_compatibility( - column.name(), - physical_field.data_type(), - logical_field.data_type(), - ) - .map_err(|e| - DataFusionError::Execution(format!( - "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type): {e}", - column.name(), - physical_field.data_type(), - logical_field.data_type() - )))?; - - let cast_expr = Arc::new(CastColumnExpr::new( - Arc::new(column), - physical_field, - Arc::new(logical_field.clone()), - None, - )); + let physical_field = Arc::new( + self.physical_file_schema + .field(physical_column_index) + .clone(), + ); - Ok(Transformed::yes(cast_expr)) + Ok(Some((column, physical_field))) } } @@ -652,10 +633,40 @@ mod tests { Array, BooleanArray, GenericListArray, Int32Array, Int64Array, RecordBatch, RecordBatchOptions, StringArray, StringViewArray, StructArray, }; - use arrow::datatypes::{Fields, Schema}; + use arrow::datatypes::{Field, Fields, Schema}; use datafusion_common::{assert_contains, record_batch}; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{Column, Literal, col, lit}; + use datafusion_physical_expr::expressions::{Column, Literal, col}; + + fn assert_cast_expr(expr: &Arc) -> &CastExpr { + expr.as_any() + .downcast_ref::() + .expect("Expected CastExpr") + } + + fn assert_cast_column(cast_expr: &CastExpr, name: &str, index: usize) { + let inner_col = cast_expr + .expr() + .as_any() + .downcast_ref::() + .expect("Expected inner Column"); + assert_eq!(inner_col.name(), name); + assert_eq!(inner_col.index(), index); + } + + fn stale_index_cast_schemas() -> (SchemaRef, SchemaRef) { + let physical_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Binary, true), + Field::new("a", DataType::Int32, false), + ])); + + let logical_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Binary, true), + ])); + + (logical_schema, physical_schema) + } fn create_test_schema() -> (Schema, Schema) { let physical_schema = Schema::new(vec![ @@ -685,7 +696,7 @@ mod tests { let result = adapter.rewrite(column_expr).unwrap(); // Should be wrapped in a cast expression - assert!(result.as_any().downcast_ref::().is_some()); + assert!(result.as_any().downcast_ref::().is_some()); } #[test] @@ -702,24 +713,19 @@ mod tests { .unwrap(); let result = adapter.rewrite(Arc::new(Column::new("a", 0)))?; - let cast = result - .as_any() - .downcast_ref::() - .expect("Expected CastColumnExpr"); - assert_eq!(cast.target_field().data_type(), &DataType::Int64); - assert!(!cast.target_field().is_nullable()); + // Ensure the expression preserves the logical field nullability/metadata. + let return_field = result.return_field(physical_schema.as_ref())?; + assert_eq!(return_field.data_type(), &DataType::Int64); + assert!(!return_field.is_nullable()); assert_eq!( - cast.target_field() + return_field .metadata() .get("logical_meta") .map(String::as_str), Some("1") ); - // Ensure the expression reports the logical nullability regardless of input schema - assert!(!result.nullable(physical_schema.as_ref())?); - Ok(()) } @@ -750,33 +756,35 @@ mod tests { ); let result = adapter.rewrite(Arc::new(expr)).unwrap(); - println!("Rewritten expression: {result}"); - - let expected = expressions::BinaryExpr::new( - Arc::new(CastColumnExpr::new( - Arc::new(Column::new("a", 0)), - Arc::new(Field::new("a", DataType::Int32, false)), - Arc::new(Field::new("a", DataType::Int64, false)), - None, - )), - Operator::Plus, - Arc::new(Literal::new(ScalarValue::Int64(Some(5)))), - ); - let expected = Arc::new(expressions::BinaryExpr::new( - Arc::new(expected), - Operator::Or, - Arc::new(expressions::BinaryExpr::new( - lit(ScalarValue::Float64(None)), // c is missing, so it becomes null - Operator::Gt, - Arc::new(Literal::new(ScalarValue::Float64(Some(0.0)))), - )), - )) as Arc; + let outer = result + .as_any() + .downcast_ref::() + .expect("Expected outer BinaryExpr"); + assert_eq!(*outer.op(), Operator::Or); - assert_eq!( - result.to_string(), - expected.to_string(), - "The rewritten expression did not match the expected output" - ); + let left = outer + .left() + .as_any() + .downcast_ref::() + .expect("Expected left BinaryExpr"); + assert_eq!(*left.op(), Operator::Plus); + + let left_cast = assert_cast_expr(left.left()); + assert_eq!(left_cast.target_field().data_type(), &DataType::Int64); + assert_cast_column(left_cast, "a", 0); + + let right = outer + .right() + .as_any() + .downcast_ref::() + .expect("Expected right BinaryExpr"); + assert_eq!(*right.op(), Operator::Gt); + let null_literal = right + .left() + .as_any() + .downcast_ref::() + .expect("Expected null literal"); + assert_eq!(*null_literal.value(), ScalarValue::Float64(None)); } #[test] @@ -841,17 +849,6 @@ mod tests { let result = adapter.rewrite(column_expr).unwrap(); - let physical_struct_fields: Fields = vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ] - .into(); - let physical_field = Arc::new(Field::new( - "data", - DataType::Struct(physical_struct_fields), - false, - )); - let logical_struct_fields: Fields = vec![ Field::new("id", DataType::Int64, false), Field::new("name", DataType::Utf8View, true), @@ -863,9 +860,8 @@ mod tests { false, )); - let expected = Arc::new(CastColumnExpr::new( + let expected = Arc::new(CastExpr::new_with_target_field( Arc::new(Column::new("data", 0)), - physical_field, logical_field, None, )) as Arc; @@ -1663,8 +1659,7 @@ mod tests { Field::new("b", DataType::Utf8, true), ]); - let factory = DefaultPhysicalExprAdapterFactory; - let adapter = factory + let adapter = DefaultPhysicalExprAdapterFactory .create(Arc::new(logical_schema), Arc::new(physical_schema)) .unwrap(); @@ -1673,20 +1668,11 @@ mod tests { let result = adapter.rewrite(column_expr).unwrap(); - // Should be a CastColumnExpr - let cast_expr = result - .as_any() - .downcast_ref::() - .expect("Expected CastColumnExpr"); + // Should be a CastExpr + let cast_expr = assert_cast_expr(&result); // Verify the inner column points to the correct physical index (1) - let inner_col = cast_expr - .expr() - .as_any() - .downcast_ref::() - .expect("Expected inner Column"); - assert_eq!(inner_col.name(), "a"); - assert_eq!(inner_col.index(), 1); // Physical index is 1 + assert_cast_column(cast_expr, "a", 1); // Verify cast types assert_eq!( @@ -1696,41 +1682,17 @@ mod tests { } #[test] - fn test_create_cast_column_expr_uses_name_lookup_not_column_index() { - // Physical schema has column `a` at index 1; index 0 is an incompatible type. - let physical_schema = Arc::new(Schema::new(vec![ - Field::new("b", DataType::Binary, true), - Field::new("a", DataType::Int32, false), - ])); - - let logical_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, false), - Field::new("b", DataType::Binary, true), - ])); - - let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema: Arc::clone(&logical_schema), - physical_file_schema: Arc::clone(&physical_schema), - }; + fn test_rewrite_resolves_physical_column_by_name_before_casting() { + let (logical_schema, physical_schema) = stale_index_cast_schemas(); + let adapter = DefaultPhysicalExprAdapterFactory + .create(logical_schema, physical_schema) + .unwrap(); // Deliberately provide the wrong index for column `a`. // Regression: this must still resolve against physical field `a` by name. - let transformed = rewriter - .create_cast_column_expr( - Column::new("a", 0), - Arc::new(physical_schema.field_with_name("a").unwrap().clone()), - logical_schema.field_with_name("a").unwrap(), - ) - .unwrap(); - - let cast_expr = transformed - .data - .as_any() - .downcast_ref::() - .expect("Expected CastColumnExpr"); - - assert_eq!(cast_expr.input_field().name(), "a"); - assert_eq!(cast_expr.input_field().data_type(), &DataType::Int32); + let rewritten = adapter.rewrite(Arc::new(Column::new("a", 0))).unwrap(); + let cast_expr = assert_cast_expr(&rewritten); + assert_cast_column(cast_expr, "a", 1); assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64); } }