diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index e62b26da99cc..211d3c4d7552 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -18,12 +18,15 @@ //! [`StringAgg`] accumulator for the `string_agg` function use std::hash::Hash; -use std::mem::size_of_val; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; use crate::array_agg::ArrayAgg; -use arrow::array::{ArrayRef, AsArray, BooleanArray, LargeStringArray}; +use arrow::array::{ + Array, ArrayAccessor, ArrayRef, AsArray, BooleanArray, GenericStringArray, + LargeStringArray, StringArrayType, StringViewArray, +}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{ @@ -323,47 +326,297 @@ fn filter_index(values: &[T], index: usize) -> Vec { struct StringAggGroupsAccumulator { /// The delimiter placed between concatenated values. delimiter: String, - /// Accumulated string per group. `None` means no values have been seen - /// (the group's output will be NULL). - /// A potential improvement is to avoid this String allocation - /// See + /// Materialized string state for groups that use the eager fast path. values: Vec>, - /// Running total of string data bytes across all groups. + /// Running total of bytes stored in `values`. total_data_bytes: usize, + /// Deferred bookkeeping is allocated lazily after promotion. + deferred: Option, +} + +#[derive(Debug, Default)] +struct DeferredRows { + /// Source arrays retained from input batches or merged state batches. + batches: Vec, + /// Per-batch `(group_idx, row_idx)` pairs for non-null rows. + batch_entries: Vec>, +} + +enum StringInputArray<'a> { + Utf8(&'a GenericStringArray), + LargeUtf8(&'a GenericStringArray), + Utf8View(&'a StringViewArray), +} + +macro_rules! dispatch_string_input_array { + ($self:expr, $array:ident => $expr:expr) => { + match $self { + Self::Utf8($array) => $expr, + Self::LargeUtf8($array) => $expr, + Self::Utf8View($array) => $expr, + } + }; +} + +impl<'a> StringInputArray<'a> { + fn try_new(array: &'a ArrayRef) -> Result { + match array.data_type() { + DataType::Utf8 => Ok(Self::Utf8(array.as_string::())), + DataType::LargeUtf8 => Ok(Self::LargeUtf8(array.as_string::())), + DataType::Utf8View => Ok(Self::Utf8View(array.as_string_view())), + other => internal_err!("string_agg unexpected data type: {other}"), + } + } + + fn append_rows(&self, group_indices: &[usize]) -> Vec<(u32, u32)> { + dispatch_string_input_array!(self, array => { + StringAggGroupsAccumulator::append_rows_typed(array, group_indices) + }) + } + + fn append_batch_values( + &self, + values: &mut [Option], + entries: &[(u32, u32)], + delimiter: &str, + emit_groups: usize, + ) { + dispatch_string_input_array!(self, array => { + StringAggGroupsAccumulator::append_batch_values_typed( + values, + entries, + array, + delimiter, + emit_groups, + ) + }) + } } impl StringAggGroupsAccumulator { + const DEFER_GROUP_THRESHOLD: usize = 32; + const DEFER_PAYLOAD_LEN_THRESHOLD: usize = 32; + fn new(delimiter: String) -> Self { Self { delimiter, values: Vec::new(), total_data_bytes: 0, + deferred: None, + } + } + + fn clear_state(&mut self) { + // `size()` measures Vec capacity rather than len, so allocate new + // buffers instead of using `clear()`. + self.values = Vec::new(); + self.total_data_bytes = 0; + self.deferred = None; + } + + fn retain_after_emit(deferred: &mut DeferredRows, emit_groups: usize) { + let emit_groups = emit_groups as u32; + let mut retained_batches = Vec::with_capacity(deferred.batches.len()); + let mut retained_entries = Vec::with_capacity(deferred.batch_entries.len()); + + for (batch, entries) in deferred + .batches + .drain(..) + .zip(deferred.batch_entries.drain(..)) + { + let entries: Vec<_> = entries + .into_iter() + .filter_map(|(group_idx, row_idx)| { + if group_idx >= emit_groups { + Some((group_idx - emit_groups, row_idx)) + } else { + None + } + }) + .collect(); + if entries.is_empty() { + continue; + } + + // Keep the original arrays for this prototype and only renumber + // retained groups. + // todo: compact mixed batches so + // partially emitted batches no longer pin their full inputs. + retained_batches.push(batch); + retained_entries.push(entries); + } + + deferred.batches = retained_batches; + deferred.batch_entries = retained_entries; + } + + fn append_rows_typed<'a, A>(array: &A, group_indices: &[usize]) -> Vec<(u32, u32)> + where + A: StringArrayType<'a>, + { + array + .iter() + .zip(group_indices.iter()) + .enumerate() + .filter_map(|(row_idx, (opt_value, &group_idx))| { + opt_value.map(|_| (group_idx as u32, row_idx as u32)) + }) + .collect() + } + + fn append_group_value( + values: &mut [Option], + group_idx: usize, + value: &str, + delimiter: &str, + ) -> usize { + match &mut values[group_idx] { + Some(existing) => { + let added = delimiter.len() + value.len(); + existing.reserve(added); + existing.push_str(delimiter); + existing.push_str(value); + added + } + slot @ None => { + *slot = Some(value.to_string()); + value.len() + } } } - fn append_batch<'a>( + fn append_batch_typed<'a, I>( + values: &mut [Option], + iter: I, + group_indices: &[usize], + delimiter: &str, + ) -> usize + where + I: Iterator>, + { + iter.zip(group_indices.iter()) + .filter_map(|(opt_value, &group_idx)| { + opt_value.map(|value| { + Self::append_group_value(values, group_idx, value, delimiter) + }) + }) + .sum() + } + + fn append_eager_batch( &mut self, - iter: impl Iterator>, + array: &ArrayRef, group_indices: &[usize], - ) { - for (opt_value, &group_idx) in iter.zip(group_indices.iter()) { - if let Some(value) = opt_value { - match &mut self.values[group_idx] { - Some(existing) => { - let added = self.delimiter.len() + value.len(); - existing.reserve(added); - existing.push_str(&self.delimiter); - existing.push_str(value); - self.total_data_bytes += added; - } - slot @ None => { - *slot = Some(value.to_string()); - self.total_data_bytes += value.len(); - } - } + ) -> Result<()> { + let added = match array.data_type() { + DataType::Utf8 => Self::append_batch_typed( + &mut self.values, + array.as_string::().iter(), + group_indices, + &self.delimiter, + ), + DataType::LargeUtf8 => Self::append_batch_typed( + &mut self.values, + array.as_string::().iter(), + group_indices, + &self.delimiter, + ), + DataType::Utf8View => Self::append_batch_typed( + &mut self.values, + array.as_string_view().iter(), + group_indices, + &self.delimiter, + ), + other => return internal_err!("string_agg unexpected data type: {other}"), + }; + self.total_data_bytes += added; + Ok(()) + } + + fn append_batch_values_typed<'a, A>( + values: &mut [Option], + entries: &[(u32, u32)], + array: &A, + delimiter: &str, + emit_groups: usize, + ) where + A: ArrayAccessor, + { + for &(group_idx, row_idx) in entries { + let group_idx = group_idx as usize; + if group_idx >= emit_groups { + continue; } + + let row_idx = row_idx as usize; + debug_assert!(!array.is_null(row_idx)); + let _ = Self::append_group_value( + values, + group_idx, + array.value(row_idx), + delimiter, + ); } } + + fn append_batch_values( + values: &mut [Option], + entries: &[(u32, u32)], + array: &ArrayRef, + delimiter: &str, + emit_groups: usize, + ) -> Result<()> { + StringInputArray::try_new(array)?.append_batch_values( + values, + entries, + delimiter, + emit_groups, + ); + Ok(()) + } + + fn estimated_payload_len(array: &ArrayRef) -> Option { + let non_null_rows = array.len().saturating_sub(array.null_count()); + if non_null_rows == 0 { + return None; + } + + match array.data_type() { + DataType::Utf8 => { + Some(array.as_string::().value_data().len() / non_null_rows) + } + DataType::LargeUtf8 => { + Some(array.as_string::().value_data().len() / non_null_rows) + } + DataType::Utf8View => Some( + array + .as_string_view() + .data_buffers() + .iter() + .map(|buffer| buffer.len()) + .sum::() + / non_null_rows, + ), + _ => None, + } + } + + fn should_promote(&self, array: &ArrayRef, total_num_groups: usize) -> bool { + total_num_groups >= Self::DEFER_GROUP_THRESHOLD + && Self::estimated_payload_len(array) + .is_some_and(|len| len >= Self::DEFER_PAYLOAD_LEN_THRESHOLD) + } + + fn defer_batch(&mut self, array: ArrayRef, group_indices: &[usize]) -> Result<()> { + let input = StringInputArray::try_new(&array)?; + let entries = input.append_rows(group_indices); + if !entries.is_empty() { + let deferred = self.deferred.get_or_insert_with(DeferredRows::default); + deferred.batches.push(array); + deferred.batch_entries.push(entries); + } + Ok(()) + } } impl GroupsAccumulator for StringAggGroupsAccumulator { @@ -376,37 +629,54 @@ impl GroupsAccumulator for StringAggGroupsAccumulator { ) -> Result<()> { self.values.resize(total_num_groups, None); let array = apply_filter_as_nulls(&values[0], opt_filter)?; - match array.data_type() { - DataType::Utf8 => { - self.append_batch(array.as_string::().iter(), group_indices) - } - DataType::LargeUtf8 => { - self.append_batch(array.as_string::().iter(), group_indices) - } - DataType::Utf8View => { - self.append_batch(array.as_string_view().iter(), group_indices) - } - other => { - return internal_err!("string_agg unexpected data type: {other}"); - } + + if self.deferred.is_some() || self.should_promote(&array, total_num_groups) { + self.defer_batch(array, group_indices)?; + } else { + self.append_eager_batch(&array, group_indices)?; } + Ok(()) } fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let to_emit = emit_to.take_needed(&mut self.values); + let mut to_emit = emit_to.take_needed(&mut self.values); + let emit_groups = to_emit.len(); let emitted_bytes: usize = to_emit .iter() .filter_map(|opt| opt.as_ref().map(|s| s.len())) .sum(); self.total_data_bytes -= emitted_bytes; - let result: ArrayRef = Arc::new(LargeStringArray::from(to_emit)); - Ok(result) + if let Some(deferred) = &self.deferred { + for (batch, entries) in deferred.batches.iter().zip(&deferred.batch_entries) { + Self::append_batch_values( + &mut to_emit, + entries, + batch, + &self.delimiter, + emit_groups, + )?; + } + } + + match emit_to { + EmitTo::All => self.clear_state(), + EmitTo::First(_) => { + if let Some(deferred) = &mut self.deferred { + Self::retain_after_emit(deferred, emit_groups); + if deferred.batches.is_empty() { + self.deferred = None; + } + } + } + } + + Ok(Arc::new(LargeStringArray::from(to_emit))) } fn state(&mut self, emit_to: EmitTo) -> Result> { - self.evaluate(emit_to).map(|arr| vec![arr]) + Ok(vec![self.evaluate(emit_to)?]) } fn merge_batch( @@ -441,6 +711,26 @@ impl GroupsAccumulator for StringAggGroupsAccumulator { fn size(&self) -> usize { self.total_data_bytes + self.values.capacity() * size_of::>() + + self + .deferred + .as_ref() + .map(|deferred| { + deferred + .batches + .iter() + .map(|arr| { + arr.to_data().get_slice_memory_size().unwrap_or_default() + }) + .sum::() + + deferred.batches.capacity() * size_of::() + + deferred + .batch_entries + .iter() + .map(|entries| entries.capacity() * size_of::<(u32, u32)>()) + .sum::() + + deferred.batch_entries.capacity() * size_of::>() + }) + .unwrap_or_default() + self.delimiter.capacity() + size_of_val(self) } @@ -952,4 +1242,57 @@ mod tests { ); Ok(()) } + + #[test] + fn groups_mixed_eager_and_deferred_batches() -> Result<()> { + let mut acc = make_groups_acc(","); + + let eager_values: ArrayRef = + Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d"])); + acc.update_batch(&[eager_values], &[0, 1, 0, 1], None, 40)?; + assert!(acc.deferred.is_none()); + + let deferred_values: ArrayRef = Arc::new(LargeStringArray::from(vec![ + "large0_abcdefghijklmnopqrstuvwxyzabcdef", + "large1_bcdefghijklmnopqrstuvwxyzabcdefg", + "large2_cdefghijklmnopqrstuvwxyzabcdefgh", + ])); + acc.update_batch(&[deferred_values], &[0, 1, 39], None, 40)?; + assert!(acc.deferred.is_some()); + + let result = evaluate_groups(&mut acc, EmitTo::First(2)); + assert_eq!( + result, + vec![ + Some("a,c,large0_abcdefghijklmnopqrstuvwxyzabcdef".to_string()), + Some("b,d,large1_bcdefghijklmnopqrstuvwxyzabcdefg".to_string()), + ] + ); + + let remaining = evaluate_groups(&mut acc, EmitTo::All); + let mut expected = vec![None; 38]; + expected[37] = Some("large2_cdefghijklmnopqrstuvwxyzabcdefgh".to_string()); + assert_eq!(remaining, expected); + Ok(()) + } + + #[test] + fn groups_short_payloads_do_not_promote_to_deferred() -> Result<()> { + let mut acc = make_groups_acc(","); + let values: ArrayRef = Arc::new(LargeStringArray::from(vec![ + "aaa", "bbb", "ccc", "ddd", "eee", "fff", + ])); + + acc.update_batch(&[values], &[0, 1, 39, 38, 0, 1], None, 40)?; + + assert!(acc.deferred.is_none()); + let result = evaluate_groups(&mut acc, EmitTo::All); + let mut expected = vec![None; 40]; + expected[0] = Some("aaa,eee".to_string()); + expected[1] = Some("bbb,fff".to_string()); + expected[38] = Some("ddd".to_string()); + expected[39] = Some("ccc".to_string()); + assert_eq!(result, expected); + Ok(()) + } }