diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 2c2ff6d48aecc..77135d73373f6 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1389,6 +1389,7 @@ mod tests { #[test] fn test_update_matching_exprs() -> Result<()> { + let udf = Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())); let exprs: Vec> = vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 3)), @@ -1403,7 +1404,7 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), + Arc::clone(&udf), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1468,7 +1469,7 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), + Arc::clone(&udf), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1522,6 +1523,7 @@ mod tests { #[test] fn test_update_projected_exprs() -> Result<()> { + let udf = Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())); let exprs: Vec> = vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 3)), @@ -1536,7 +1538,7 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), + Arc::clone(&udf), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1601,7 +1603,7 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), + Arc::clone(&udf), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b_new", 1)), diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 83d35c3d25b16..9f4c58a7fa31f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2503,6 +2503,51 @@ mod test { assert_eq!(udf.signature().volatility, Volatility::Volatile); } + #[test] + fn test_scalar_udf_eq_pointer() { + #[derive(Debug)] + struct DummyUDF { + signature: Signature, + } + + impl DummyUDF { + fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for DummyUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("DummyUDF::invoke") + } + } + + let udf1 = ScalarUDF::new_from_impl(DummyUDF::new()); + let udf1_clone = udf1.clone(); + let udf2 = ScalarUDF::new_from_impl(DummyUDF::new()); + + assert!(udf1.eq(&udf1_clone)); + assert!(!udf1.eq(&udf2)); + } + use super::*; #[test] diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 1a5d50477b1c8..d1ca4d601cd05 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -61,7 +61,7 @@ pub struct ScalarUDF { impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.inner.equals(other.inner.as_ref()) + Arc::ptr_eq(&self.inner, &other.inner) } } @@ -678,9 +678,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// - symmetric: `a.equals(b)` implies `b.equals(a)`; /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. /// - /// By default, compares [`Self::name`] and [`Self::signature`]. + /// By default, checks for pointer equality. fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - self.name() == other.name() && self.signature() == other.signature() + let self_ptr = self as *const _ as *const (); + let other_ptr = other as *const _ as *const (); + std::ptr::eq(self_ptr, other_ptr) } /// Returns a hash value for this scalar UDF.