diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs
index 8e6bf9205b2f..33387f5d7234 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -1419,8 +1419,9 @@ mod tests {
 
     use arrow::{
         array::{BooleanArray, Int64Array, ListArray, PrimitiveArray, StringArray},
+        compute,
         compute::SortOptions,
-        datatypes::Schema,
+        datatypes::{Field, Schema},
     };
     use datafusion_physical_expr::{PhysicalSortExpr, expressions::col};
 
@@ -1879,6 +1880,59 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_first_value_ordered_update_batch_picks_min_ordering() -> Result<()> {
+        // Verify ordered first_value chooses the row with the smallest ordering key.
+        let mut accumulator = new_first_value_accumulator()?;
+
+        accumulator.update_batch(&[
+            Arc::new(StringArray::from(vec![Some("c"), Some("a"), Some("b")])),
+            Arc::new(Int64Array::from(vec![3, 1, 2])),
+        ])?;
+
+        assert_eq!(accumulator.evaluate()?, ScalarValue::from("a"));
+        Ok(())
+    }
+
+    #[test]
+    fn test_last_value_ordered_update_batch_picks_max_ordering() -> Result<()> {
+        // Verify ordered last_value chooses the row with the largest ordering key.
+        let mut accumulator = new_last_value_accumulator()?;
+
+        accumulator.update_batch(&[
+            Arc::new(StringArray::from(vec![Some("c"), Some("a"), Some("b")])),
+            Arc::new(Int64Array::from(vec![3, 1, 2])),
+        ])?;
+
+        assert_eq!(accumulator.evaluate()?, ScalarValue::from("c"));
+        Ok(())
+    }
+
+    #[test]
+    fn test_first_value_ordered_merge_batch_prefers_earlier_state() -> Result<()> {
+        // Verify merge_batch keeps the earliest value across partial states.
+        let mut lhs = new_first_value_accumulator()?;
+        lhs.update_batch(&[
+            Arc::new(StringArray::from(vec![Some("later")])),
+            Arc::new(Int64Array::from(vec![10])),
+        ])?;
+        let lhs_state = lhs.state()?;
+
+        let mut rhs = new_first_value_accumulator()?;
+        rhs.update_batch(&[
+            Arc::new(StringArray::from(vec![Some("earlier")])),
+            Arc::new(Int64Array::from(vec![1])),
+        ])?;
+        let rhs_state = rhs.state()?;
+
+        let mut merged = new_first_value_accumulator()?;
+        let states = concat_states(&lhs_state, &rhs_state)?;
+        merged.merge_batch(&states)?;
+
+        assert_eq!(merged.evaluate()?, ScalarValue::from("earlier"));
+        Ok(())
+    }
+
     #[test]
     fn test_last_value_merge_with_is_set_nulls() -> Result<()> {
         // Test data with corrupted is_set flag
@@ -1925,4 +1979,78 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_last_value_ordered_merge_batch_prefers_later_state() -> Result<()> {
+        // Verify merge_batch keeps the latest value across partial states.
+        let mut lhs = new_last_value_accumulator()?;
+        lhs.update_batch(&[
+            Arc::new(StringArray::from(vec![Some("earlier")])),
+            Arc::new(Int64Array::from(vec![1])),
+        ])?;
+        let lhs_state = lhs.state()?;
+
+        let mut rhs = new_last_value_accumulator()?;
+        rhs.update_batch(&[
+            Arc::new(StringArray::from(vec![Some("later")])),
+            Arc::new(Int64Array::from(vec![10])),
+        ])?;
+        let rhs_state = rhs.state()?;
+
+        let mut merged = new_last_value_accumulator()?;
+        let states = concat_states(&lhs_state, &rhs_state)?;
+        merged.merge_batch(&states)?;
+
+        assert_eq!(merged.evaluate()?, ScalarValue::from("later"));
+        Ok(())
+    }
+
+    /// Builds a `FirstValueAccumulator` for `Utf8` values with one `Int64` ordering
+    /// column (see `single_ordering`); both trailing `bool` arguments are `false`.
+    fn new_first_value_accumulator() -> Result<FirstValueAccumulator> {
+        FirstValueAccumulator::try_new(
+            &DataType::Utf8,
+            &[DataType::Int64],
+            single_ordering()?,
+            false,
+            false,
+        )
+    }
+
+    /// Builds a `LastValueAccumulator` for `Utf8` values with one `Int64` ordering
+    /// column (see `single_ordering`); both trailing `bool` arguments are `false`.
+    fn new_last_value_accumulator() -> Result<LastValueAccumulator> {
+        LastValueAccumulator::try_new(
+            &DataType::Utf8,
+            &[DataType::Int64],
+            single_ordering()?,
+            false,
+            false,
+        )
+    }
+
+    /// Returns a `LexOrdering` with a single sort expression on the `Int64` column
+    /// named "ordering", using `SortOptions::default()`.
+    fn single_ordering() -> Result<LexOrdering> {
+        let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]);
+        let ordering_expr = col("ordering", &schema)?;
+        Ok(LexOrdering::new(vec![PhysicalSortExpr {
+            expr: ordering_expr,
+            options: SortOptions::default(),
+        }])
+        .unwrap())
+    }
+
+    /// Concatenates two accumulator states field by field, producing one array per
+    /// state field in the form passed to `merge_batch`.
+    /// Panics if the two states do not have the same number of fields.
+    fn concat_states(lhs: &[ScalarValue], rhs: &[ScalarValue]) -> Result<Vec<ArrayRef>> {
+        assert_eq!(lhs.len(), rhs.len());
+        (0..lhs.len())
+            .map(|idx| {
+                compute::concat(&[&lhs[idx].to_array()?, &rhs[idx].to_array()?])
+                    .map_err(|e| arrow_datafusion_err!(e))
+            })
+            .collect()
+    }
 }