From 8ff95613f23b4f9b4996e9e9c130c185978da2f6 Mon Sep 17 00:00:00 2001 From: Aleksandr Romanenko Date: Tue, 2 Jun 2026 14:29:54 +0200 Subject: [PATCH 1/7] feat(cubestore): push sorted partial aggregate below merge, propagate LIMIT into it Rewrite the worker plan PartialAggregate(Sorted) -> SortPreservingMerge into SortPreservingMerge -> per-partition PartialAggregate(Sorted), so the merge carries reduced partial states instead of all raw rows. Only sorted streaming aggregates are pushed: they hold O(1) accumulators per partition, while a hash aggregate would multiply its O(num_groups) memory by the partition count. Because the merged stream now contains duplicate group keys from different partitions, a plain worker row limit could truncate a group's partial states and silently corrupt its total. add_limit_to_workers turns the worker limit into a per-partition group limit on the aggregate plus a widened row budget (limit * partitions) on the merge (TailLimit for the reverse case). InlineAggregateStream honors the group limit: the emit-early threshold becomes min(batch_size, remaining limit) instead of a hard batch_size (4096), only closed groups are emitted, and input reading stops once the limit is reached, so a downstream LIMIT short-circuits the scan. --- .../cubestore-sql-tests/src/tests.rs | 329 +++++++++++-- .../inline_aggregate_stream.rs | 46 +- .../src/queryplanner/inline_aggregate/mod.rs | 247 +++++++++- .../distributed_partial_aggregate.rs | 459 +++++++++++++++++- .../src/queryplanner/optimizations/mod.rs | 5 +- 5 files changed, 1045 insertions(+), 41 deletions(-) diff --git a/rust/cubestore/cubestore-sql-tests/src/tests.rs b/rust/cubestore/cubestore-sql-tests/src/tests.rs index bd14875b42504..8d210a131a2bf 100644 --- a/rust/cubestore/cubestore-sql-tests/src/tests.rs +++ b/rust/cubestore/cubestore-sql-tests/src/tests.rs @@ -233,6 +233,18 @@ pub fn sql_tests(prefix: &str) -> Vec<(&'static str, TestFn)> { "unique_key_and_multi_partitions_hash_aggregate", unique_key_and_multi_partitions_hash_aggregate, ), + t( + "group_by_prefix_sorted_aggregate_multi_partition", + group_by_prefix_sorted_aggregate_multi_partition, + ), + t( + "group_by_prefix_limit_high_cardinality", + group_by_prefix_limit_high_cardinality, + ), + t( + "planning_aggregate_below_merge_with_limit", + planning_aggregate_below_merge_with_limit, + ), t("divide_by_zero", divide_by_zero), t( "filter_multiple_in_for_decimal", @@ -384,6 +396,9 @@ lazy_static::lazy_static! { "create_table_with_csv_no_header", "create_table_with_csv_no_header_and_delimiter", "create_table_with_csv_no_header_and_quotes", + "group_by_prefix_sorted_aggregate_multi_partition", + "group_by_prefix_limit_high_cardinality", + "planning_aggregate_below_merge_with_limit", ].into_iter().map(ToOwned::to_owned).collect(); } @@ -3628,8 +3643,8 @@ async fn planning_simple(service: Box) -> Result<(), CubeError> { pp_phys_plan(p.worker.as_ref()), "InlineFinalAggregate\ \n Worker\ - \n InlinePartialAggregate\ - \n MergeSort\ + \n MergeSort\ + \n InlinePartialAggregate\ \n Union\ \n Scan, index: default:1:[1]:sort_on[id], fields: [id, amount]\ \n Sort\ @@ -7238,23 +7253,22 @@ async fn unique_key_and_multi_partitions(service: Box) -> Result< \n InlineFinalAggregate, partitions: 1\ \n MergeSort, partitions: 1\ \n Worker, partitions: 2\ - \n GlobalLimit, n: 100, partitions: 1\ - \n InlinePartialAggregate, partitions: 1\ - \n MergeSort, partitions: 1\ - \n Union, partitions: 2\ - \n Projection, [a, b], partitions: 1\ - \n LastRowByUniqueKey, partitions: 1\ - \n MergeSort, partitions: 1\ - \n Scan, index: default:1:[1]:sort_on[a, b], fields: [a, b, c, e, __seq], partitions: 2\ - \n FilterByKeyRange, partitions: 1\ - \n MemoryScan, partitions: 1\ - \n FilterByKeyRange, partitions: 1\ - \n MemoryScan, partitions: 1\ - \n Projection, [a, b], partitions: 1\ - \n LastRowByUniqueKey, partitions: 1\ - \n Scan, index: default:2:[2]:sort_on[a, b], fields: [a, b, c, e, __seq], partitions: 1\ + \n MergeSort, fetch: 200, partitions: 1\ + \n InlinePartialAggregate, limit: 100, partitions: 2\ + \n Union, partitions: 2\ + \n Projection, [a, b], partitions: 1\ + \n LastRowByUniqueKey, partitions: 1\ + \n MergeSort, partitions: 1\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: [a, b, c, e, __seq], partitions: 2\ \n FilterByKeyRange, partitions: 1\ - \n MemoryScan, partitions: 1"); + \n MemoryScan, partitions: 1\ + \n FilterByKeyRange, partitions: 1\ + \n MemoryScan, partitions: 1\ + \n Projection, [a, b], partitions: 1\ + \n LastRowByUniqueKey, partitions: 1\ + \n Scan, index: default:2:[2]:sort_on[a, b], fields: [a, b, c, e, __seq], partitions: 1\ + \n FilterByKeyRange, partitions: 1\ + \n MemoryScan, partitions: 1"); } Ok(()) } @@ -7314,6 +7328,247 @@ async fn unique_key_and_multi_partitions_hash_aggregate( Ok(()) } +/// Correctness of the sorted partial aggregate executed per partition below the merge: group +/// keys present in several partitions must combine to the same totals as without the +/// optimization, with and without LIMIT. +async fn group_by_prefix_sorted_aggregate_multi_partition( + service: Box, +) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Data1 (a int, b int, val int, fval double)") + .await?; + service + .exec_query("CREATE TABLE s.Data2 (a int, b int, val int, fval double)") + .await?; + + // Group keys with `a` in 4..8 get rows from both tables, i.e. from both partitions of the + // union. + let mut raw_rows = Vec::new(); + let mut values1 = Vec::new(); + for a in 0i64..8 { + for b in 0i64..3 { + for r in 0i64..2 { + let val = a * 31 + b * 17 + r; + values1.push(format!("({}, {}, {}, {}.5)", a, b, val, val)); + raw_rows.push((a, b, val)); + } + } + } + let mut values2 = Vec::new(); + for a in 4i64..12 { + for b in 0i64..3 { + for r in 0i64..2 { + let val = a * 13 + b * 7 + r; + values2.push(format!("({}, {}, {}, {}.5)", a, b, val, val)); + raw_rows.push((a, b, val)); + } + } + } + service + .exec_query(&format!( + "INSERT INTO s.Data1 (a, b, val, fval) VALUES {}", + values1.join(", ") + )) + .await?; + service + .exec_query(&format!( + "INSERT INTO s.Data2 (a, b, val, fval) VALUES {}", + values2.join(", ") + )) + .await?; + + // Expected aggregates per group; fval = val + 0.5 keeps float sums exactly representable, + // so the assertions stay byte-exact regardless of the summation order. + let mut groups = std::collections::BTreeMap::<(i64, i64), (i64, i64, i64, i64)>::new(); + for (a, b, val) in raw_rows { + let group = groups.entry((a, b)).or_insert((0, i64::MAX, i64::MIN, 0)); + group.0 += val; + group.1 = group.1.min(val); + group.2 = group.2.max(val); + group.3 += 1; + } + let group_row = |((a, b), (sum, min, max, count)): (&(i64, i64), &(i64, i64, i64, i64))| { + vec![ + TableValue::Int(*a), + TableValue::Int(*b), + TableValue::Int(*sum), + TableValue::Int(*min), + TableValue::Int(*max), + TableValue::Float((*sum as f64 + 0.5 * *count as f64).into()), + ] + }; + let expected: Vec> = groups.iter().map(group_row).collect(); + + let query = "SELECT a, b, sum(val), min(val), max(val), sum(fval) FROM (\ + SELECT * FROM s.Data1 UNION ALL SELECT * FROM s.Data2\ + ) `t` GROUP BY 1, 2"; + + let full = service + .exec_query(&format!("{} ORDER BY 1, 2", query)) + .await?; + assert_eq!(to_rows(&full), expected); + + // LIMIT must return the same prefix of the full result + let limited = service + .exec_query(&format!("{} ORDER BY 1, 2 LIMIT 5", query)) + .await?; + assert_eq!(to_rows(&limited), expected[..5]); + + // DESC ordering takes the tail groups + let tail = service + .exec_query(&format!("{} ORDER BY 1 DESC, 2 DESC LIMIT 5", query)) + .await?; + let expected_tail: Vec<_> = expected.iter().rev().take(5).cloned().collect(); + assert_eq!(to_rows(&tail), expected_tail); + + // ORDER BY an aggregate must not use the group-order shortcut + let by_sum = service + .exec_query(&format!("{} ORDER BY 3, 1, 2 LIMIT 4", query)) + .await?; + let mut by_sum_keys: Vec<(i64, i64, i64)> = groups + .iter() + .map(|((a, b), (sum, ..))| (*sum, *a, *b)) + .collect(); + by_sum_keys.sort(); + let expected_by_sum: Vec> = by_sum_keys + .iter() + .take(4) + .map(|(_, a, b)| group_row(((&(*a, *b)), &groups[&(*a, *b)]))) + .collect(); + assert_eq!(to_rows(&by_sum), expected_by_sum); + Ok(()) +} + +/// LIMIT short-circuit correctness on group counts below and above the execution batch size +/// (4096), with duplicate group keys across partitions. +async fn group_by_prefix_limit_high_cardinality( + service: Box, +) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Data1 (a int, val int)") + .await?; + service + .exec_query("CREATE TABLE s.Data2 (a int, val int)") + .await?; + + // 7500 groups in total, above the 4096 execution batch size; keys 2500..5000 are present + // in both tables. + for chunk in (0i64..5000).collect::>().chunks(1000) { + let values = chunk + .iter() + .map(|a| format!("({}, {})", a, a * 3 + 1)) + .join(", "); + service + .exec_query(&format!("INSERT INTO s.Data1 (a, val) VALUES {}", values)) + .await?; + } + for chunk in (2500i64..7500).collect::>().chunks(1000) { + let values = chunk + .iter() + .map(|a| format!("({}, {})", a, a * 5 + 2)) + .join(", "); + service + .exec_query(&format!("INSERT INTO s.Data2 (a, val) VALUES {}", values)) + .await?; + } + + let expected = |range: std::ops::Range| -> Vec> { + range + .map(|a| { + let mut sum = 0; + if a < 5000 { + sum += a * 3 + 1; + } + if a >= 2500 { + sum += a * 5 + 2; + } + vec![TableValue::Int(a), TableValue::Int(sum)] + }) + .collect() + }; + + let query = "SELECT a, sum(val) FROM (\ + SELECT * FROM s.Data1 UNION ALL SELECT * FROM s.Data2\ + ) `t` GROUP BY 1"; + + // LIMIT far below the batch size + let r = service + .exec_query(&format!("{} ORDER BY 1 LIMIT 10", query)) + .await?; + assert_eq!(to_rows(&r), expected(0..10)); + + // LIMIT above the batch size + let r = service + .exec_query(&format!("{} ORDER BY 1 LIMIT 5000", query)) + .await?; + assert_eq!(to_rows(&r), expected(0..5000)); + + // No limit: all the groups + let r = service.exec_query(&format!("{} ORDER BY 1", query)).await?; + assert_eq!(to_rows(&r), expected(0..7500)); + Ok(()) +} + +/// The sorted partial aggregate runs per partition below the merge; the worker limit becomes a +/// group limit on the aggregate plus a widened row budget on the merge (duplicate group keys +/// from different partitions make a plain row limit incorrect). +async fn planning_aggregate_below_merge_with_limit( + service: Box, +) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Orders(a int, b int, amount int)") + .await?; + + let p = service + .plan_query( + "SELECT a, b, SUM(amount) FROM (\ + SELECT * FROM s.Orders UNION ALL SELECT * FROM s.Orders\ + ) `t` GROUP BY 1, 2", + ) + .await?; + assert_eq!( + pp_phys_plan(p.worker.as_ref()), + "InlineFinalAggregate\ + \n Worker\ + \n MergeSort\ + \n InlinePartialAggregate\ + \n Union\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty" + ); + + let p = service + .plan_query( + "SELECT a, b, SUM(amount) FROM (\ + SELECT * FROM s.Orders UNION ALL SELECT * FROM s.Orders\ + ) `t` GROUP BY 1, 2 ORDER BY 1, 2 LIMIT 5", + ) + .await?; + assert_eq!( + pp_phys_plan(p.worker.as_ref()), + "Sort, fetch: 5\ + \n InlineFinalAggregate\ + \n Worker\ + \n MergeSort, fetch: 10\ + \n InlinePartialAggregate, limit: 5\ + \n Union\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty" + ); + Ok(()) +} + async fn divide_by_zero(service: Box) -> Result<(), CubeError> { service.exec_query("CREATE SCHEMA s").await?; service.exec_query("CREATE TABLE s.t(i int, z int)").await?; @@ -8053,12 +8308,12 @@ async fn build_range_end(service: Box) -> Result<(), CubeError> { Ok(()) } -async fn assert_limit_pushdown_using_search_string( +async fn assert_limit_pushdown_using_search_strings( service: &Box, query: &str, expected_index: Option<&str>, is_limit_expected: bool, - search_string: &str, + search_strings: &[&str], ) -> Result, CubeError> { let res = service .exec_query(&format!("EXPLAIN ANALYZE {}", query)) @@ -8073,18 +8328,17 @@ async fn assert_limit_pushdown_using_search_string( ))); } } - let expected_limit = search_string; if is_limit_expected { - if !s.contains(expected_limit) { + if !search_strings.iter().any(|expected| s.contains(expected)) { return Err(CubeError::internal(format!( "{} expected but not found", - expected_limit + search_strings.join(" or ") ))); } - } else if s.contains(expected_limit) { + } else if let Some(found) = search_strings.iter().find(|e| s.contains(*e)) { return Err(CubeError::internal(format!( "{} unexpected but found", - expected_limit + found ))); } } @@ -8095,6 +8349,23 @@ async fn assert_limit_pushdown_using_search_string( Ok(res.get_rows().clone()) } +async fn assert_limit_pushdown_using_search_string( + service: &Box, + query: &str, + expected_index: Option<&str>, + is_limit_expected: bool, + search_string: &str, +) -> Result, CubeError> { + assert_limit_pushdown_using_search_strings( + service, + query, + expected_index, + is_limit_expected, + &[search_string], + ) + .await +} + async fn assert_limit_pushdown( service: &Box, query: &str, @@ -8102,15 +8373,17 @@ async fn assert_limit_pushdown( is_limit_expected: bool, is_tail_limit: bool, ) -> Result, CubeError> { - assert_limit_pushdown_using_search_string( + assert_limit_pushdown_using_search_strings( service, query, expected_index, is_limit_expected, if is_tail_limit { - "TailLimit" + &["TailLimit"] } else { - "GlobalLimit" + // The worker limit is either a plain row limit or, for a partial aggregate running + // per partition below the merge, a group limit on the aggregate. + &["GlobalLimit", "InlinePartialAggregate, limit:"] }, ) .await diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs index 5b2e6c4c38df1..a63d34c1d363e 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs @@ -43,6 +43,11 @@ pub(crate) struct InlineAggregateStream { batch_size: usize, + /// Per-partition limit on the number of emitted groups, see [`InlineAggregateExec::limit`] + limit: Option, + + groups_emitted: usize, + exec_state: ExecutionState, input_done: bool, @@ -100,6 +105,8 @@ impl InlineAggregateStream { group_by: agg_group_by, exec_state, batch_size, + limit: agg.limit(), + groups_emitted: 0, current_group_indices, group_values, input_done: false, @@ -175,6 +182,12 @@ impl Stream for InlineAggregateStream { loop { match &self.exec_state { ExecutionState::ReadingInput => { + // All needed groups are emitted, skip the rest of the input + if self.limit_reached() { + self.exec_state = ExecutionState::Done; + continue; + } + match ready!(self.input.poll_next_unpin(cx)) { // New input batch to aggregate Some(Ok(batch)) => { @@ -283,25 +296,38 @@ impl InlineAggregateStream { Ok(Some(batch)) } - /// Check if we have enough groups to emit a batch, keeping the last (potentially incomplete) group. - /// - /// For sorted aggregation, we emit batches of size batch_size when we have accumulated - /// more than batch_size groups. We always keep the last group as it may continue in the next input batch. - fn should_emit_early(&self) -> bool { - // Need at least (batch_size + 1) groups to emit batch_size and keep 1 - self.group_values.len() > self.batch_size + fn limit_reached(&self) -> bool { + self.limit.is_some_and(|limit| self.groups_emitted >= limit) + } + + /// How many groups to emit in the next early batch: full batches until the limit (if any) + /// leaves fewer groups to emit. + fn emit_early_threshold(&self) -> usize { + match self.limit { + Some(limit) => self + .batch_size + .min(limit.saturating_sub(self.groups_emitted)), + None => self.batch_size, + } } /// Emit a batch of groups if we have enough accumulated, keeping the last group. /// + /// For sorted aggregation, we emit when we have accumulated more than threshold groups: the + /// last group is always kept as it may continue in the next input batch, so only closed + /// groups are emitted. + /// /// Returns Some(batch) if emitted, None otherwise. fn emit_early_if_ready(&mut self) -> DFResult> { - if !self.should_emit_early() { + let threshold = self.emit_early_threshold(); + // Need at least (threshold + 1) groups to emit threshold closed groups and keep 1 + if threshold == 0 || self.group_values.len() <= threshold { return Ok(None); } - // Emit exactly batch_size groups, keeping the rest (including last incomplete group) - self.emit(EmitTo::First(self.batch_size)) + let batch = self.emit(EmitTo::First(threshold))?; + self.groups_emitted += threshold; + Ok(batch) } fn group_aggregate_batch(&mut self, batch: RecordBatch) -> DFResult<()> { diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs index e8ea319ec4605..57b2fd2e28355 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs @@ -40,7 +40,8 @@ pub struct InlineAggregateExec { aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, - /// Set if the output of this aggregation is truncated by a upstream sort/limit clause + /// Per-partition limit on the number of emitted groups: each partition emits at most this + /// many first (in group order) complete groups and stops reading its input limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, @@ -111,6 +112,14 @@ impl InlineAggregateExec { self.limit } + /// Returns a copy of this aggregate with the per-partition group limit set. Each partition + /// emits at most `limit` first (in group order) complete groups and stops reading its input. + pub fn with_limit(&self, limit: Option) -> Self { + let mut result = self.clone(); + result.limit = limit; + result + } + pub fn aggr_expr(&self) -> &[Arc] { &self.aggr_expr } @@ -289,3 +298,239 @@ fn supported_type(data_type: &DataType) -> bool { | DataType::BinaryView ) } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::{Int64Array, RecordBatch}; + use datafusion::arrow::datatypes::{Field, Schema}; + use datafusion::common::arrow::compute::concat_batches; + use datafusion::functions_aggregate::sum::sum_udaf; + use datafusion::physical_expr::aggregate::AggregateExprBuilder; + use datafusion::physical_expr::expressions::col; + use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion::physical_plan::collect; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use datafusion::prelude::{SessionConfig, SessionContext}; + use datafusion_datasource::memory::MemorySourceConfig; + use datafusion_datasource::source::DataSourceExec; + use futures::StreamExt; + use std::sync::atomic::{AtomicUsize, Ordering}; + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])) + } + + fn make_batch(schema: &SchemaRef, rows: &[(i64, i64)]) -> RecordBatch { + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from_iter_values(rows.iter().map(|r| r.0))), + Arc::new(Int64Array::from_iter_values(rows.iter().map(|r| r.1))), + ], + ) + .unwrap() + } + + fn sorted_source( + schema: &SchemaRef, + partitions: Vec>, + ) -> Arc { + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", schema).unwrap(), + )]); + let source = MemorySourceConfig::try_new(&partitions, schema.clone(), None) + .unwrap() + .try_with_sort_information(vec![ordering]) + .unwrap(); + Arc::new(DataSourceExec::new(Arc::new(source))) + } + + fn partial_sum_inline_aggregate( + input: Arc, + limit: Option, + ) -> Arc { + let schema = input.schema(); + let group_by = + PhysicalGroupBy::new_single(vec![(col("k", &schema).unwrap(), "k".to_string())]); + let sum = AggregateExprBuilder::new(sum_udaf(), vec![col("v", &schema).unwrap()]) + .schema(schema.clone()) + .alias("sum_v") + .build() + .unwrap(); + let agg = AggregateExec::try_new( + AggregateMode::Partial, + group_by, + vec![Arc::new(sum)], + vec![None], + input, + schema, + ) + .unwrap(); + assert!( + matches!(agg.input_order_mode(), InputOrderMode::Sorted), + "test setup must produce a sorted aggregate" + ); + let inline = InlineAggregateExec::try_new_from_aggregate(&agg).unwrap(); + Arc::new(inline.with_limit(limit)) + } + + fn run(plan: Arc, batch_size: usize) -> Vec<(i64, i64)> { + let session = + SessionContext::new_with_config(SessionConfig::new().with_batch_size(batch_size)); + let batches = futures::executor::block_on(collect(plan, session.task_ctx())).unwrap(); + let schema = batches[0].schema(); + let batch = concat_batches(&schema, &batches).unwrap(); + let keys = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let sums = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + keys.iter() + .zip(sums.iter()) + .map(|(k, s)| (k.unwrap(), s.unwrap())) + .collect() + } + + /// A group continuing in the next input batch must not be emitted early with a partial sum. + #[test] + fn limit_emits_only_closed_groups() { + let schema = test_schema(); + let input = sorted_source( + &schema, + vec![vec![ + make_batch(&schema, &[(1, 10), (2, 20), (3, 30), (3, 31)]), + make_batch(&schema, &[(3, 32), (4, 40)]), + ]], + ); + let agg = partial_sum_inline_aggregate(input, Some(3)); + assert_eq!(run(agg, 4096), vec![(1, 10), (2, 20), (3, 93)]); + } + + #[test] + fn limit_results_match_no_limit_prefix() { + let schema = test_schema(); + let rows: Vec<(i64, i64)> = (0..1000).map(|i| (i / 3, i)).collect(); + let batches: Vec = rows.chunks(97).map(|c| make_batch(&schema, c)).collect(); + let source = sorted_source(&schema, vec![batches]); + + let no_limit = run(partial_sum_inline_aggregate(source.clone(), None), 4096); + let limited = run(partial_sum_inline_aggregate(source, Some(5)), 4096); + assert_eq!(limited, no_limit[..5]); + } + + #[test] + fn limit_larger_than_group_count_emits_all() { + let schema = test_schema(); + let input = sorted_source( + &schema, + vec![vec![make_batch(&schema, &[(1, 10), (2, 20), (3, 30)])]], + ); + let agg = partial_sum_inline_aggregate(input, Some(100)); + assert_eq!(run(agg, 4096), vec![(1, 10), (2, 20), (3, 30)]); + } + + /// Emitting in batch_size chunks until the limit is reached. + #[test] + fn limit_above_batch_size_emits_incrementally() { + let schema = test_schema(); + let rows: Vec<(i64, i64)> = (0..16).map(|i| (i / 2, i)).collect(); + let batches: Vec = rows.chunks(3).map(|c| make_batch(&schema, c)).collect(); + let source = sorted_source(&schema, vec![batches]); + + let no_limit = run(partial_sum_inline_aggregate(source.clone(), None), 2); + let limited = run(partial_sum_inline_aggregate(source, Some(5)), 2); + assert_eq!(limited, no_limit[..5]); + } + + /// Wraps a plan and counts batches its streams produce. + #[derive(Debug)] + struct CountingExec { + inner: Arc, + batches_polled: Arc, + } + + impl DisplayAs for CountingExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "CountingExec") + } + } + + impl ExecutionPlan for CountingExec { + fn name(&self) -> &'static str { + "CountingExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.inner.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + Ok(Arc::new(CountingExec { + inner: children[0].clone(), + batches_polled: self.batches_polled.clone(), + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DFResult { + let stream = self.inner.execute(partition, context)?; + let counter = self.batches_polled.clone(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + stream.schema(), + stream.inspect(move |_| { + counter.fetch_add(1, Ordering::SeqCst); + }), + ))) + } + } + + /// Once the limit is reached the aggregate must stop reading its input, so a downstream + /// LIMIT short-circuits the scan. + #[test] + fn limit_stops_reading_input() { + let schema = test_schema(); + let batches: Vec = (0..100) + .map(|i| { + let rows: Vec<(i64, i64)> = (0..10).map(|j| (i * 10 + j, 1)).collect(); + make_batch(&schema, &rows) + }) + .collect(); + let source = sorted_source(&schema, vec![batches]); + let batches_polled = Arc::new(AtomicUsize::new(0)); + let counting = Arc::new(CountingExec { + inner: source, + batches_polled: batches_polled.clone(), + }); + + let result = run(partial_sum_inline_aggregate(counting, Some(5)), 4096); + assert_eq!(result.len(), 5); + assert!( + batches_polled.load(Ordering::SeqCst) < 10, + "aggregate must stop polling input after the limit is reached, polled {} batches", + batches_polled.load(Ordering::SeqCst) + ); + } +} diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs index e670d4be6e945..a1e1cb075ccc9 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs @@ -1,4 +1,5 @@ use crate::cluster::WorkerPlanningParams; +use crate::queryplanner::inline_aggregate::{InlineAggregateExec, InlineAggregateMode}; use crate::queryplanner::planning::WorkerExec; use crate::queryplanner::query_executor::ClusterSendExec; use crate::queryplanner::tail_limit::TailLimitExec; @@ -17,7 +18,9 @@ use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::union::UnionExec; -use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, PhysicalExpr}; +use datafusion::physical_plan::{ + ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, +}; use itertools::Itertools as _; use std::collections::HashSet; use std::sync::Arc; @@ -120,6 +123,74 @@ pub fn push_aggregate_to_workers( )?)) } +/// Transforms from: +/// AggregatePartial, Sorted +/// `- SortPreservingMerge +/// `- source(N partitions) +/// to: +/// SortPreservingMerge +/// `- AggregatePartial, Sorted (executed per partition) +/// `- source(N partitions) +/// +/// The merge then carries one row per group per partition instead of all raw rows. Duplicate +/// group keys from different partitions are adjacent in the merged stream and get combined by +/// the Final aggregate, the same way partial states from different workers are. +/// +/// Only sorted (streaming) partial aggregates are pushed: they hold O(1) accumulators per +/// partition, while a hash aggregate holds O(num_groups) and would multiply its memory usage by +/// the partition count. +pub fn push_sorted_partial_aggregate_below_merge( + p: Arc, +) -> Result, DataFusionError> { + let Some(agg) = p.as_any().downcast_ref::() else { + return Ok(p); + }; + if *agg.mode() != AggregateMode::Partial + || !matches!(agg.input_order_mode(), InputOrderMode::Sorted) + // Restrict to aggregates convertible to InlineAggregateExec: `add_limit_to_workers` + // relies on every merge-above-partial-aggregate pair being an InlineAggregateExec to + // apply row limits without truncating duplicate group keys. + || !agg.group_expr().is_single() + { + return Ok(p); + } + let Some(merge) = agg + .input() + .as_any() + .downcast_ref::() + else { + return Ok(p); + }; + if merge.fetch().is_some() { + return Ok(p); + } + let merge_input = merge.input(); + if merge_input.output_partitioning().partition_count() <= 1 { + return Ok(p); + } + + let new_agg = AggregateExec::try_new( + AggregateMode::Partial, + agg.group_expr().clone(), + agg.aggr_expr().to_vec(), + agg.filter_expr().to_vec(), + merge_input.clone(), + agg.input_schema(), + )?; + // Per-partition input must still be sorted on the group keys, otherwise the aggregate + // becomes hash-based and must stay above the merge. + if !matches!(new_agg.input_order_mode(), InputOrderMode::Sorted) { + return Ok(p); + } + let Some(ordering) = new_agg.properties().output_ordering().cloned() else { + return Ok(p); + }; + Ok(Arc::new(SortPreservingMergeExec::new( + ordering, + Arc::new(new_agg), + ))) +} + pub fn ensure_partition_merge_helper( p: Arc, new_child: &mut bool, @@ -224,6 +295,49 @@ pub fn add_limit_to_workers( let Some((limit, reverse)) = limit_and_reverse else { return Ok(p); }; + + // The merged per-partition partial aggregate stream may contain duplicate group keys from + // different partitions, and a plain row limit could cut off part of some group's partial + // states, silently corrupting that group's total. Limit groups per partition instead, and + // widen the row budget to limit * partitions rows: that is guaranteed to cover all rows of + // the first (last, for reverse) `limit` complete groups. + if let Some(merge) = input.as_any().downcast_ref::() { + if let Some(agg) = merge.input().as_any().downcast_ref::() { + if *agg.mode() == InlineAggregateMode::Partial { + let partitions = agg.properties().output_partitioning().partition_count(); + let row_budget = limit.saturating_mul(partitions); + let new_input: Arc = if reverse { + // The last groups are unknown until the input is exhausted, so the + // aggregates can't stop early; only widen the row limit. + Arc::new(TailLimitExec::new(input.clone(), row_budget)) + } else { + let agg_limit = agg.limit().map_or(limit, |l| l.min(limit)); + let new_agg = Arc::new(agg.with_limit(Some(agg_limit))); + Arc::new( + SortPreservingMergeExec::new(merge.expr().clone(), new_agg) + .with_fetch(Some(row_budget)), + ) + }; + return p.with_new_children(vec![new_input]); + } + } + } + + // A single-partition sorted partial aggregate emits one row per group, so a row limit is + // exact; pass it into the aggregate so it can stop reading its input early. + if !reverse { + if let Some(agg) = input.as_any().downcast_ref::() { + if *agg.mode() == InlineAggregateMode::Partial + && agg.properties().output_partitioning().partition_count() == 1 + { + let agg_limit = agg.limit().map_or(limit, |l| l.min(limit)); + let new_agg = Arc::new(agg.with_limit(Some(agg_limit))); + let limit_node = Arc::new(GlobalLimitExec::new(new_agg, 0, Some(limit))); + return p.with_new_children(vec![limit_node]); + } + } + } + if reverse { let limit = Arc::new(TailLimitExec::new(input.clone(), limit)); p.with_new_children(vec![limit]) @@ -435,3 +549,346 @@ fn replace_columns( .data, ) } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::{Int64Array, RecordBatch}; + use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion::functions_aggregate::sum::sum_udaf; + use datafusion::physical_expr::aggregate::AggregateExprBuilder; + use datafusion::physical_expr::expressions::col; + use datafusion::physical_expr::PhysicalSortExpr; + use datafusion::physical_plan::aggregates::PhysicalGroupBy; + use datafusion::physical_plan::collect; + use datafusion::prelude::SessionContext; + use datafusion_datasource::memory::MemorySourceConfig; + use datafusion_datasource::source::DataSourceExec; + use std::collections::BTreeMap; + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])) + } + + fn make_batch(schema: &SchemaRef, rows: &[(i64, i64)]) -> RecordBatch { + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from_iter_values(rows.iter().map(|r| r.0))), + Arc::new(Int64Array::from_iter_values(rows.iter().map(|r| r.1))), + ], + ) + .unwrap() + } + + /// Memory source with each partition sorted by `k`. + fn sorted_source( + schema: &SchemaRef, + partitions: Vec>, + ) -> Arc { + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", schema).unwrap(), + )]); + let source = MemorySourceConfig::try_new(&partitions, schema.clone(), None) + .unwrap() + .try_with_sort_information(vec![ordering]) + .unwrap(); + Arc::new(DataSourceExec::new(Arc::new(source))) + } + + fn merge_by_k(input: Arc) -> Arc { + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", &input.schema()).unwrap(), + )]); + Arc::new(SortPreservingMergeExec::new(ordering, input)) + } + + fn sum_aggregate( + mode: AggregateMode, + group_col: &str, + input: Arc, + ) -> Arc { + let schema = input.schema(); + let group_by = PhysicalGroupBy::new_single(vec![( + col(group_col, &schema).unwrap(), + group_col.to_string(), + )]); + let sum = AggregateExprBuilder::new(sum_udaf(), vec![col("v", &schema).unwrap()]) + .schema(schema.clone()) + .alias("sum_v") + .build() + .unwrap(); + Arc::new( + AggregateExec::try_new( + mode, + group_by, + vec![Arc::new(sum)], + vec![None], + input, + schema, + ) + .unwrap(), + ) + } + + /// Collects plan output into per-key sums, combining duplicate keys the way a Final + /// aggregate would. + fn collect_summed(plan: Arc) -> BTreeMap { + let session = SessionContext::new(); + let batches = futures::executor::block_on(collect(plan, session.task_ctx())).unwrap(); + let mut result = BTreeMap::new(); + for batch in batches { + let keys = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let sums = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + for (k, s) in keys.iter().zip(sums.iter()) { + *result.entry(k.unwrap()).or_insert(0) += s.unwrap(); + } + } + result + } + + fn two_partition_source(schema: &SchemaRef) -> Arc { + // Duplicate keys 2 and 3 across partitions + sorted_source( + schema, + vec![ + vec![make_batch(schema, &[(1, 10), (2, 20), (3, 30)])], + vec![make_batch(schema, &[(2, 21), (3, 31), (4, 40)])], + ], + ) + } + + #[test] + fn pushes_sorted_partial_aggregate_below_merge() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let original = sum_aggregate(AggregateMode::Partial, "k", merge_by_k(source)); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + + let merge = rewritten + .as_any() + .downcast_ref::() + .expect("merge must become the root"); + let agg = merge + .input() + .as_any() + .downcast_ref::() + .expect("partial aggregate must move below the merge"); + assert_eq!(*agg.mode(), AggregateMode::Partial); + assert!(matches!(agg.input_order_mode(), InputOrderMode::Sorted)); + assert_eq!( + agg.properties().output_partitioning().partition_count(), + 2, + "aggregate must run per partition" + ); + assert_eq!(rewritten.schema(), original.schema()); + + // Cross-partition duplicate keys combine to the same totals as in the original plan + assert_eq!(collect_summed(rewritten), collect_summed(original)); + } + + #[test] + fn does_not_push_hash_aggregate() { + let schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int64, false), + Field::new("g", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(Int64Array::from(vec![5, 4])), + Arc::new(Int64Array::from(vec![10, 20])), + ], + ) + .unwrap(); + let source = sorted_source(&schema, vec![vec![batch.clone()], vec![batch]]); + // Grouping by `g` while the input is sorted by `k` makes the aggregate hash-based + let original = sum_aggregate(AggregateMode::Partial, "g", merge_by_k(source)); + assert!(matches!( + original + .as_any() + .downcast_ref::() + .unwrap() + .input_order_mode(), + InputOrderMode::Linear + )); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + #[test] + fn does_not_push_non_partial_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let original = sum_aggregate(AggregateMode::Single, "k", merge_by_k(source)); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + #[test] + fn does_not_push_below_merge_with_fetch() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let merge = Arc::new(merge_by_k(source).as_ref().clone().with_fetch(Some(3))); + let original = sum_aggregate(AggregateMode::Partial, "k", merge); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + #[test] + fn does_not_push_below_single_partition_merge() { + let schema = test_schema(); + let source = sorted_source( + &schema, + vec![vec![make_batch(&schema, &[(1, 10), (2, 20)])]], + ); + let original = sum_aggregate(AggregateMode::Partial, "k", merge_by_k(source)); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + fn inline(plan: Arc) -> Arc { + let agg = plan.as_any().downcast_ref::().unwrap(); + Arc::new(InlineAggregateExec::try_new_from_aggregate(agg).unwrap()) + } + + fn worker( + input: Arc, + limit_and_reverse: Option<(usize, bool)>, + ) -> Arc { + Arc::new(WorkerExec::new( + input, + 4096, + limit_and_reverse, + None, + WorkerPlanningParams { + worker_partition_count: 1, + }, + )) + } + + /// Worker plan with the partial aggregate below the merge: merge of per-partition partial + /// states can contain duplicate group keys, so the row limit must be limit * partitions + /// while the aggregates take the group limit. + #[test] + fn worker_limit_above_merged_partial_aggregate_limits_groups_per_partition() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let merged = push_sorted_partial_aggregate_below_merge_shape(agg); + let p = worker(merged, Some((3, false))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker = rewritten.as_any().downcast_ref::().unwrap(); + let merge = worker + .input + .as_any() + .downcast_ref::() + .expect("merge must stay on top of per-partition aggregates"); + assert_eq!( + merge.fetch(), + Some(6), + "row budget must be limit * partitions" + ); + let agg = merge + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(agg.limit(), Some(3)); + } + + #[test] + fn worker_reverse_limit_above_merged_partial_aggregate_widens_tail_limit() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let merged = push_sorted_partial_aggregate_below_merge_shape(agg); + let p = worker(merged, Some((3, true))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker = rewritten.as_any().downcast_ref::().unwrap(); + let tail = worker + .input + .as_any() + .downcast_ref::() + .expect("reverse limit must stay a tail limit"); + assert_eq!(tail.limit, 6, "row budget must be limit * partitions"); + let merge = tail + .input + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(merge.fetch(), None); + let agg = merge + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + agg.limit(), + None, + "tail limit can not stop the aggregate early" + ); + } + + /// Single partition partial aggregate emits unique group keys, the row limit stays exact + /// and also lets the aggregate stop early. + #[test] + fn worker_limit_above_single_partition_partial_aggregate_sets_aggregate_limit() { + let schema = test_schema(); + let source = sorted_source( + &schema, + vec![vec![make_batch(&schema, &[(1, 10), (2, 20)])]], + ); + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let p = worker(agg, Some((3, false))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker = rewritten.as_any().downcast_ref::().unwrap(); + let limit = worker + .input + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(limit.fetch(), Some(3)); + let agg = limit + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(agg.limit(), Some(3)); + } + + /// Builds the post-rewrite shape merge-above-aggregate for an already converted + /// InlineAggregateExec. + fn push_sorted_partial_aggregate_below_merge_shape( + agg: Arc, + ) -> Arc { + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", &agg.schema()).unwrap(), + )]); + Arc::new(SortPreservingMergeExec::new(ordering, agg)) + } +} diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs index e4ee5eb698b3c..b828c6a8e6f53 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs @@ -10,7 +10,7 @@ use super::serialized_plan::PreSerializedPlan; use crate::cluster::{Cluster, WorkerPlanningParams}; use crate::queryplanner::optimizations::distributed_partial_aggregate::{ add_limit_to_workers, ensure_partition_merge, push_aggregate_to_workers, - replace_suboptimal_merge_sorts, + push_sorted_partial_aggregate_below_merge, replace_suboptimal_merge_sorts, }; use crate::queryplanner::optimizations::inline_aggregate_rewriter::replace_with_inline_aggregate; use crate::queryplanner::planning::CubeExtensionPlanner; @@ -145,6 +145,9 @@ fn pre_optimize_physical_plan( // Handles the root node case let p = ensure_partition_merge(p)?; + // Make the merge carry partial aggregate states instead of all raw rows + let p = rewrite_physical_plan(p, &mut |p| push_sorted_partial_aggregate_below_merge(p))?; + // Replace sorted AggregateExec with InlineAggregateExec for better performance let p = rewrite_physical_plan(p, &mut |p| replace_with_inline_aggregate(p))?; From a01cb84049e26f9287ab75b528f1eda3fdeaff7e Mon Sep 17 00:00:00 2001 From: Aleksandr Romanenko Date: Tue, 2 Jun 2026 14:53:00 +0200 Subject: [PATCH 2/7] feat(cubestore): per-partition TailLimit below merge, sliding tail window TailLimitStream collected its whole input to take the tail, materializing all worker rows for reverse limits; now it keeps a sliding window of trailing batches covering 'limit' rows, newer rows displace older ones. TailLimitExec returns the last 'limit' rows of each input partition instead of requiring a single one. The reverse worker limit above merged per-partition partial aggregates becomes a per-partition tail below the merge: within a partition the aggregate emits unique group keys, so 'limit' rows there are 'limit' complete groups and the merge carries at most 'limit' rows per partition instead of all groups with cross-partition duplicates. Groups beyond the last 'limit' may arrive with partial totals, but the router orders by the group key and its own limit drops them. --- .../cubestore-sql-tests/src/tests.rs | 34 +++++++ .../distributed_partial_aggregate.rs | 89 +++++++++++++------ .../cubestore/src/queryplanner/tail_limit.rs | 74 +++++++++++---- 3 files changed, 153 insertions(+), 44 deletions(-) diff --git a/rust/cubestore/cubestore-sql-tests/src/tests.rs b/rust/cubestore/cubestore-sql-tests/src/tests.rs index 8d210a131a2bf..dca2e0972528a 100644 --- a/rust/cubestore/cubestore-sql-tests/src/tests.rs +++ b/rust/cubestore/cubestore-sql-tests/src/tests.rs @@ -7505,6 +7505,14 @@ async fn group_by_prefix_limit_high_cardinality( .await?; assert_eq!(to_rows(&r), expected(0..5000)); + // DESC takes the last groups (per-partition tail path) + let r = service + .exec_query(&format!("{} ORDER BY 1 DESC LIMIT 10", query)) + .await?; + let mut expected_tail = expected(7490..7500); + expected_tail.reverse(); + assert_eq!(to_rows(&r), expected_tail); + // No limit: all the groups let r = service.exec_query(&format!("{} ORDER BY 1", query)).await?; assert_eq!(to_rows(&r), expected(0..7500)); @@ -7566,6 +7574,32 @@ async fn planning_aggregate_below_merge_with_limit( \n Sort\ \n Empty" ); + + // Reverse limit: a per-partition tail below the merge instead of a group limit (the last + // groups are unknown until the input ends, so the aggregate can't stop early) + let p = service + .plan_query( + "SELECT a, b, SUM(amount) FROM (\ + SELECT * FROM s.Orders UNION ALL SELECT * FROM s.Orders\ + ) `t` GROUP BY 1, 2 ORDER BY 1 DESC, 2 DESC LIMIT 5", + ) + .await?; + assert_eq!( + pp_phys_plan(p.worker.as_ref()), + "Sort, fetch: 5\ + \n InlineFinalAggregate\ + \n Worker\ + \n MergeSort\ + \n TailLimit, n: 5\ + \n InlinePartialAggregate\ + \n Union\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty" + ); Ok(()) } diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs index a1e1cb075ccc9..4017d38e4cc52 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs @@ -297,25 +297,29 @@ pub fn add_limit_to_workers( }; // The merged per-partition partial aggregate stream may contain duplicate group keys from - // different partitions, and a plain row limit could cut off part of some group's partial - // states, silently corrupting that group's total. Limit groups per partition instead, and - // widen the row budget to limit * partitions rows: that is guaranteed to cover all rows of - // the first (last, for reverse) `limit` complete groups. + // different partitions, and a row limit above the merge could cut off part of some group's + // partial states, silently corrupting that group's total. Apply the limit per partition + // instead, below the merge: within a partition the aggregate emits unique group keys, so + // `limit` rows there are `limit` complete groups, and the union of per-partition first + // (last, for reverse) `limit` groups covers the global first (last) `limit` groups. Groups + // beyond that may arrive with partial totals, but the router orders by the group key, not + // by the totals, and its own limit drops them. if let Some(merge) = input.as_any().downcast_ref::() { if let Some(agg) = merge.input().as_any().downcast_ref::() { if *agg.mode() == InlineAggregateMode::Partial { - let partitions = agg.properties().output_partitioning().partition_count(); - let row_budget = limit.saturating_mul(partitions); let new_input: Arc = if reverse { - // The last groups are unknown until the input is exhausted, so the - // aggregates can't stop early; only widen the row limit. - Arc::new(TailLimitExec::new(input.clone(), row_budget)) + // The last groups are unknown until a partition is exhausted, so the + // aggregate can't stop early; a per-partition tail keeps the merge input + // and the tail window at `limit` rows per partition. + let tail = Arc::new(TailLimitExec::new(merge.input().clone(), limit)); + Arc::new(SortPreservingMergeExec::new(merge.expr().clone(), tail)) } else { + let partitions = agg.properties().output_partitioning().partition_count(); let agg_limit = agg.limit().map_or(limit, |l| l.min(limit)); let new_agg = Arc::new(agg.with_limit(Some(agg_limit))); Arc::new( SortPreservingMergeExec::new(merge.expr().clone(), new_agg) - .with_fetch(Some(row_budget)), + .with_fetch(Some(limit.saturating_mul(partitions))), ) }; return p.with_new_children(vec![new_input]); @@ -636,9 +640,9 @@ mod tests { /// Collects plan output into per-key sums, combining duplicate keys the way a Final /// aggregate would. - fn collect_summed(plan: Arc) -> BTreeMap { + async fn collect_summed(plan: Arc) -> BTreeMap { let session = SessionContext::new(); - let batches = futures::executor::block_on(collect(plan, session.task_ctx())).unwrap(); + let batches = collect(plan, session.task_ctx()).await.unwrap(); let mut result = BTreeMap::new(); for batch in batches { let keys = batch @@ -669,8 +673,8 @@ mod tests { ) } - #[test] - fn pushes_sorted_partial_aggregate_below_merge() { + #[tokio::test(flavor = "multi_thread")] + async fn pushes_sorted_partial_aggregate_below_merge() { let schema = test_schema(); let source = two_partition_source(&schema); let original = sum_aggregate(AggregateMode::Partial, "k", merge_by_k(source)); @@ -696,7 +700,10 @@ mod tests { assert_eq!(rewritten.schema(), original.schema()); // Cross-partition duplicate keys combine to the same totals as in the original plan - assert_eq!(collect_summed(rewritten), collect_summed(original)); + assert_eq!( + collect_summed(rewritten).await, + collect_summed(original).await + ); } #[test] @@ -817,10 +824,31 @@ mod tests { assert_eq!(agg.limit(), Some(3)); } - #[test] - fn worker_reverse_limit_above_merged_partial_aggregate_widens_tail_limit() { + #[tokio::test(flavor = "multi_thread")] + async fn worker_reverse_limit_above_merged_partial_aggregate_tails_each_partition() { let schema = test_schema(); - let source = two_partition_source(&schema); + // Keys 4..=6 are present in both partitions; per-partition tails of 3 groups are + // {4, 5, 6} and {7, 8, 9}. + let source = sorted_source( + &schema, + vec![ + vec![make_batch( + &schema, + &(1..=6).map(|k| (k, k * 10)).collect::>(), + )], + vec![make_batch( + &schema, + &(4..=9).map(|k| (k, k * 100 + 1)).collect::>(), + )], + ], + ); + let baseline = collect_summed(sum_aggregate( + AggregateMode::Partial, + "k", + merge_by_k(source.clone()), + )) + .await; + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); let merged = push_sorted_partial_aggregate_below_merge_shape(agg); let p = worker(merged, Some((3, true))); @@ -828,21 +856,21 @@ mod tests { let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); let worker = rewritten.as_any().downcast_ref::().unwrap(); - let tail = worker - .input - .as_any() - .downcast_ref::() - .expect("reverse limit must stay a tail limit"); - assert_eq!(tail.limit, 6, "row budget must be limit * partitions"); - let merge = tail + let merge = worker .input .as_any() .downcast_ref::() - .unwrap(); + .expect("merge must stay on top of per-partition tails"); assert_eq!(merge.fetch(), None); - let agg = merge + let tail = merge .input() .as_any() + .downcast_ref::() + .expect("reverse limit must become a per-partition tail below the merge"); + assert_eq!(tail.limit, 3); + let agg = tail + .input + .as_any() .downcast_ref::() .unwrap(); assert_eq!( @@ -850,6 +878,13 @@ mod tests { None, "tail limit can not stop the aggregate early" ); + + // The merged stream must carry complete totals for the last 3 group keys; earlier keys + // may arrive partial and are dropped by the router's own limit. + let summed = collect_summed(worker.input.clone()).await; + for key in [7, 8, 9] { + assert_eq!(summed[&key], baseline[&key], "complete total for key {key}"); + } } /// Single partition partial aggregate emits unique group keys, the row limit stays exact diff --git a/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs b/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs index 4f64a28a45d83..fbc188056672d 100644 --- a/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs +++ b/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs @@ -6,21 +6,21 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::cube_ext; use datafusion::error::DataFusionError; use datafusion::execution::TaskContext; -use datafusion::physical_plan::common::collect; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; use futures::stream::Stream; -use futures::Future; +use futures::{Future, StreamExt}; use pin_project_lite::pin_project; use std::any::Any; +use std::collections::VecDeque; use std::fmt::Formatter; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -///Return n last rows in input +/// Returns the last `limit` rows of each input partition. #[derive(Debug)] pub struct TailLimitExec { pub input: Arc, @@ -74,19 +74,6 @@ impl ExecutionPlan for TailLimitExec { partition: usize, context: Arc, ) -> Result { - if 0 != partition { - return Err(DataFusionError::Internal(format!( - "TailLimitExec invalid partition {}", - partition - ))); - } - - if 1 != self.input.properties().partitioning.partition_count() { - return Err(DataFusionError::Internal( - "TailLimitExec requires a single input partition".to_owned(), - )); - } - let input = self.input.execute(partition, context)?; Ok(Box::pin(TailLimitStream::new(input, self.limit))) } @@ -109,7 +96,7 @@ impl TailLimitStream { let schema = input.schema(); let task = async move { let schema = input.schema(); - let data = collect(input).await?; + let data = collect_tail_window(input, n).await?; batches_tail(data, n, schema.clone()) }; cube_ext::spawn_oneshot_with_catch_unwind(task, tx); @@ -123,6 +110,26 @@ impl TailLimitStream { } } +/// Collects a sliding tail window of the input: keeps only the trailing batches needed to cover +/// `limit` rows, newer rows displace older ones. The front batch may overshoot the window, it is +/// sliced later by [batches_tail]. +async fn collect_tail_window( + mut input: SendableRecordBatchStream, + limit: usize, +) -> Result, DataFusionError> { + let mut window = VecDeque::new(); + let mut total_rows = 0; + while let Some(batch) = input.next().await { + let batch = batch?; + total_rows += batch.num_rows(); + window.push_back(batch); + while window.len() > 1 && total_rows - window.front().unwrap().num_rows() >= limit { + total_rows -= window.pop_front().unwrap().num_rows(); + } + } + Ok(window.into()) +} + fn batches_tail( mut batches: Vec, limit: usize, @@ -293,6 +300,39 @@ mod tests { assert!(to_ints(r).into_iter().flatten().collect_vec().is_empty()); } + #[tokio::test] + async fn empty_partition() { + let partitions = vec![vec![], vec![ints(vec![1, 2])]]; + let inp = try_make_memory_data_source(&partitions, ints_schema(), None).unwrap(); + let r = result_collect( + Arc::new(TailLimitExec::new(inp, 2)), + Arc::new(TaskContext::default()), + ) + .await + .unwrap(); + assert_eq!(to_ints(r).into_iter().flatten().collect_vec(), vec![1, 2]); + } + + #[tokio::test] + async fn multiple_partitions() { + let partitions = vec![ + vec![ints(vec![1, 2, 3]), ints(vec![4, 5])], + vec![ints(vec![10, 20])], + ]; + let inp = try_make_memory_data_source(&partitions, ints_schema(), None).unwrap(); + let r = result_collect( + Arc::new(TailLimitExec::new(inp, 2)), + Arc::new(TaskContext::default()), + ) + .await + .unwrap(); + // The last 2 rows of each partition + assert_eq!( + to_ints(r).into_iter().flatten().collect_vec(), + vec![4, 5, 10, 20], + ); + } + #[tokio::test] async fn several_batches() { let input = vec![ From 7535ad301d35009bf708ad6ad757ca9644c87224 Mon Sep 17 00:00:00 2001 From: Aleksandr Romanenko Date: Tue, 2 Jun 2026 17:09:33 +0200 Subject: [PATCH 3/7] fix(cubestore): strict aggregate limit contract, limit descent to the first aggregate InlineAggregateStream could emit more groups than its limit: emit-early ran once per input batch while a single batch can bring an arbitrary group backlog, and the final emit was unclamped. Now the backlog is drained in emit threshold chunks at the top of the poll loop (also stops reading input as early as possible) and the final emit is clamped by the remaining limit. add_limit_to_workers is rewritten as a limit descent: probe from the worker input through sort preserving merges to the first aggregate and place a per-partition row limit directly above it (LocalLimit forward, TailLimit reverse, GlobalLimit for a single partition), additionally passing the limit into InlineAggregateExec for the early input stop. Within a partition the aggregate emits unique group keys, so the row limit cuts at group boundaries; correctness no longer depends on which pass produced the plan shape, and a merge of per-partition partial aggregates never gets a row limit above it -- the fetch = limit * partitions widening is gone. Plans without aggregation keep the plain top-level limit. TailLimitExec now declares maintains_input_order, matching its passthrough properties. --- .../cubestore-sql-tests/src/tests.rs | 50 ++--- .../inline_aggregate_stream.rs | 59 ++++-- .../src/queryplanner/inline_aggregate/mod.rs | 42 +++++ .../distributed_partial_aggregate.rs | 171 ++++++++++++------ .../cubestore/src/queryplanner/tail_limit.rs | 4 + 5 files changed, 230 insertions(+), 96 deletions(-) diff --git a/rust/cubestore/cubestore-sql-tests/src/tests.rs b/rust/cubestore/cubestore-sql-tests/src/tests.rs index dca2e0972528a..2a08c7665b981 100644 --- a/rust/cubestore/cubestore-sql-tests/src/tests.rs +++ b/rust/cubestore/cubestore-sql-tests/src/tests.rs @@ -7253,22 +7253,23 @@ async fn unique_key_and_multi_partitions(service: Box) -> Result< \n InlineFinalAggregate, partitions: 1\ \n MergeSort, partitions: 1\ \n Worker, partitions: 2\ - \n MergeSort, fetch: 200, partitions: 1\ - \n InlinePartialAggregate, limit: 100, partitions: 2\ - \n Union, partitions: 2\ - \n Projection, [a, b], partitions: 1\ - \n LastRowByUniqueKey, partitions: 1\ - \n MergeSort, partitions: 1\ - \n Scan, index: default:1:[1]:sort_on[a, b], fields: [a, b, c, e, __seq], partitions: 2\ + \n MergeSort, partitions: 1\ + \n LocalLimit, n: 100, partitions: 2\ + \n InlinePartialAggregate, limit: 100, partitions: 2\ + \n Union, partitions: 2\ + \n Projection, [a, b], partitions: 1\ + \n LastRowByUniqueKey, partitions: 1\ + \n MergeSort, partitions: 1\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: [a, b, c, e, __seq], partitions: 2\ + \n FilterByKeyRange, partitions: 1\ + \n MemoryScan, partitions: 1\ + \n FilterByKeyRange, partitions: 1\ + \n MemoryScan, partitions: 1\ + \n Projection, [a, b], partitions: 1\ + \n LastRowByUniqueKey, partitions: 1\ + \n Scan, index: default:2:[2]:sort_on[a, b], fields: [a, b, c, e, __seq], partitions: 1\ \n FilterByKeyRange, partitions: 1\ - \n MemoryScan, partitions: 1\ - \n FilterByKeyRange, partitions: 1\ - \n MemoryScan, partitions: 1\ - \n Projection, [a, b], partitions: 1\ - \n LastRowByUniqueKey, partitions: 1\ - \n Scan, index: default:2:[2]:sort_on[a, b], fields: [a, b, c, e, __seq], partitions: 1\ - \n FilterByKeyRange, partitions: 1\ - \n MemoryScan, partitions: 1"); + \n MemoryScan, partitions: 1"); } Ok(()) } @@ -7564,15 +7565,16 @@ async fn planning_aggregate_below_merge_with_limit( "Sort, fetch: 5\ \n InlineFinalAggregate\ \n Worker\ - \n MergeSort, fetch: 10\ - \n InlinePartialAggregate, limit: 5\ - \n Union\ - \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ - \n Sort\ - \n Empty\ - \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ - \n Sort\ - \n Empty" + \n MergeSort\ + \n LocalLimit, n: 5\ + \n InlinePartialAggregate, limit: 5\ + \n Union\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty" ); // Reverse limit: a per-partition tail below the merge instead of a group limit (the last diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs index a63d34c1d363e..c9252b88a3c88 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs @@ -188,26 +188,27 @@ impl Stream for InlineAggregateStream { continue; } + // Drain groups accumulated beyond the emit threshold (a single input batch + // can bring many) before reading further input + match self.emit_early_if_ready() { + Ok(Some(batch)) => { + self.exec_state = ExecutionState::ProducingOutput(batch); + continue; + } + Ok(None) => { + // Not enough groups, read further + } + Err(e) => { + return Poll::Ready(Some(Err(e))); + } + } + match ready!(self.input.poll_next_unpin(cx)) { - // New input batch to aggregate + // New input batch to aggregate; emitting happens at the top of the loop Some(Ok(batch)) => { - // Aggregate the batch if let Err(e) = self.group_aggregate_batch(batch) { return Poll::Ready(Some(Err(e))); } - - // Try to emit a batch if we have enough groups - match self.emit_early_if_ready() { - Ok(Some(batch)) => { - self.exec_state = ExecutionState::ProducingOutput(batch); - } - Ok(None) => { - // Not enough groups yet, continue reading - } - Err(e) => { - return Poll::Ready(Some(Err(e))); - } - } } // Error from input stream @@ -215,11 +216,11 @@ impl Stream for InlineAggregateStream { return Poll::Ready(Some(Err(e))); } - // Input stream exhausted - emit all remaining groups + // Input stream exhausted - emit the remaining groups, up to the limit None => { self.input_done = true; - match self.emit(EmitTo::All) { + match self.emit_remaining() { Ok(Some(batch)) => { self.exec_state = ExecutionState::ProducingOutput(batch); } @@ -257,9 +258,6 @@ impl Stream for InlineAggregateStream { } impl InlineAggregateStream { - /// Emit groups based on EmitTo strategy. - /// - /// Returns None if there are no groups to emit. /// Emit groups based on EmitTo strategy. /// /// Returns None if there are no groups to emit. @@ -330,6 +328,27 @@ impl InlineAggregateStream { Ok(batch) } + /// Emit the groups left at the end of the input: all of them are closed at this point, but + /// no more than the limit allows. + fn emit_remaining(&mut self) -> DFResult> { + let len = self.group_values.len(); + let emit_count = match self.limit { + Some(limit) => len.min(limit.saturating_sub(self.groups_emitted)), + None => len, + }; + if emit_count == 0 { + return Ok(None); + } + let emit_to = if emit_count < len { + EmitTo::First(emit_count) + } else { + EmitTo::All + }; + let batch = self.emit(emit_to)?; + self.groups_emitted += emit_count; + Ok(batch) + } + fn group_aggregate_batch(&mut self, batch: RecordBatch) -> DFResult<()> { // Evaluate the grouping expressions let group_by_values = evaluate_group_by(&self.group_by, &batch)?; diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs index 57b2fd2e28355..a3a9ea50420cb 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs @@ -451,6 +451,20 @@ mod tests { assert_eq!(limited, no_limit[..5]); } + /// A single input batch can bring more groups than the limit; the stream must still emit + /// exactly `limit` groups, draining the backlog in emit threshold chunks. + #[test] + fn limit_holds_when_one_batch_overshoots_it() { + let schema = test_schema(); + let rows: Vec<(i64, i64)> = (0..20).map(|k| (k, k + 100)).collect(); + let input = sorted_source(&schema, vec![vec![make_batch(&schema, &rows)]]); + let agg = partial_sum_inline_aggregate(input, Some(10)); + assert_eq!( + run(agg, 4), + (0..10).map(|k| (k, k + 100)).collect::>() + ); + } + /// Wraps a plan and counts batches its streams produce. #[derive(Debug)] struct CountingExec { @@ -533,4 +547,32 @@ mod tests { batches_polled.load(Ordering::SeqCst) ); } + + /// When one batch brings enough groups to cover the limit, the stream must drain them + /// without reading further input. + #[test] + fn limit_drains_backlog_without_reading_input() { + let schema = test_schema(); + let batches: Vec = (0..100) + .map(|i| { + let rows: Vec<(i64, i64)> = (0..20).map(|j| (i * 20 + j, 1)).collect(); + make_batch(&schema, &rows) + }) + .collect(); + let source = sorted_source(&schema, vec![batches]); + let batches_polled = Arc::new(AtomicUsize::new(0)); + let counting = Arc::new(CountingExec { + inner: source, + batches_polled: batches_polled.clone(), + }); + + // The first batch alone brings 20 groups > limit 10 + let result = run(partial_sum_inline_aggregate(counting, Some(10)), 4); + assert_eq!(result.len(), 10); + assert_eq!( + batches_polled.load(Ordering::SeqCst), + 1, + "the first input batch covers the limit, no further reads needed" + ); + } } diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs index 4017d38e4cc52..480940f10a2b5 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs @@ -13,7 +13,7 @@ use datafusion::physical_optimizer::limit_pushdown::LimitPushdown; use datafusion::physical_optimizer::PhysicalOptimizerRule as _; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion::physical_plan::limit::GlobalLimitExec; +use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -147,9 +147,7 @@ pub fn push_sorted_partial_aggregate_below_merge( }; if *agg.mode() != AggregateMode::Partial || !matches!(agg.input_order_mode(), InputOrderMode::Sorted) - // Restrict to aggregates convertible to InlineAggregateExec: `add_limit_to_workers` - // relies on every merge-above-partial-aggregate pair being an InlineAggregateExec to - // apply row limits without truncating duplicate group keys. + // Restrict to aggregates convertible to InlineAggregateExec (no grouping sets) || !agg.group_expr().is_single() { return Ok(p); @@ -296,59 +294,76 @@ pub fn add_limit_to_workers( return Ok(p); }; - // The merged per-partition partial aggregate stream may contain duplicate group keys from - // different partitions, and a row limit above the merge could cut off part of some group's - // partial states, silently corrupting that group's total. Apply the limit per partition - // instead, below the merge: within a partition the aggregate emits unique group keys, so - // `limit` rows there are `limit` complete groups, and the union of per-partition first - // (last, for reverse) `limit` groups covers the global first (last) `limit` groups. Groups - // beyond that may arrive with partial totals, but the router orders by the group key, not - // by the totals, and its own limit drops them. - if let Some(merge) = input.as_any().downcast_ref::() { - if let Some(agg) = merge.input().as_any().downcast_ref::() { - if *agg.mode() == InlineAggregateMode::Partial { - let new_input: Arc = if reverse { - // The last groups are unknown until a partition is exhausted, so the - // aggregate can't stop early; a per-partition tail keeps the merge input - // and the tail window at `limit` rows per partition. - let tail = Arc::new(TailLimitExec::new(merge.input().clone(), limit)); - Arc::new(SortPreservingMergeExec::new(merge.expr().clone(), tail)) - } else { - let partitions = agg.properties().output_partitioning().partition_count(); - let agg_limit = agg.limit().map_or(limit, |l| l.min(limit)); - let new_agg = Arc::new(agg.with_limit(Some(agg_limit))); - Arc::new( - SortPreservingMergeExec::new(merge.expr().clone(), new_agg) - .with_fetch(Some(limit.saturating_mul(partitions))), - ) - }; - return p.with_new_children(vec![new_input]); - } + // A row limit must not be placed above a merge of per-partition aggregates: the merged + // stream carries duplicate group keys from different partitions, and cutting it by rows + // could cut off part of some group's partial states, silently corrupting that group's + // total. Instead the limit descends through the merges and lands directly above the first + // aggregate, where one row is one complete (within its partition) group: the union of + // per-partition first (last, for reverse) `limit` groups covers the global first (last) + // `limit` groups, and groups beyond that arrive complete or not at all -- the router + // orders by the group key, not by the totals, and its own limit drops them. + if first_aggregate_below_merges(input).is_some() { + let new_input = limit_above_first_aggregate(input, limit, reverse); + return p.with_new_children(vec![new_input]); + } + + // No aggregation: plain rows, each one self-contained, a row limit is exact anywhere. + if reverse { + let limit = Arc::new(TailLimitExec::new(input.clone(), limit)); + p.with_new_children(vec![limit]) + } else { + let limit = Arc::new(GlobalLimitExec::new(input.clone(), 0, Some(limit))); + let limit_optimized = LimitPushdown::new().optimize(limit, config)?; + p.with_new_children(vec![limit_optimized]) + } +} + +/// The first aggregate reachable from `p` looking only through sort preserving merges. +fn first_aggregate_below_merges(mut p: &Arc) -> Option<&Arc> { + loop { + if let Some(merge) = p.as_any().downcast_ref::() { + p = merge.input(); + } else if p.as_any().is::() || p.as_any().is::() { + return Some(p); + } else { + return None; } } +} + +/// Rebuilds the chain of merges with a per-partition row limit placed directly above the first +/// aggregate. Must only be called when [first_aggregate_below_merges] found one. +fn limit_above_first_aggregate( + p: &Arc, + limit: usize, + reverse: bool, +) -> Arc { + if let Some(merge) = p.as_any().downcast_ref::() { + let child = limit_above_first_aggregate(merge.input(), limit, reverse); + return Arc::new( + SortPreservingMergeExec::new(merge.expr().clone(), child).with_fetch(merge.fetch()), + ); + } - // A single-partition sorted partial aggregate emits one row per group, so a row limit is - // exact; pass it into the aggregate so it can stop reading its input early. + // The sorted streaming aggregate can additionally stop reading its input early once it has + // emitted `limit` groups; for reverse the last groups are unknown until the input ends, so + // there is nothing to pass into the aggregate. + let mut node = p.clone(); if !reverse { - if let Some(agg) = input.as_any().downcast_ref::() { - if *agg.mode() == InlineAggregateMode::Partial - && agg.properties().output_partitioning().partition_count() == 1 - { + if let Some(agg) = p.as_any().downcast_ref::() { + if *agg.mode() == InlineAggregateMode::Partial { let agg_limit = agg.limit().map_or(limit, |l| l.min(limit)); - let new_agg = Arc::new(agg.with_limit(Some(agg_limit))); - let limit_node = Arc::new(GlobalLimitExec::new(new_agg, 0, Some(limit))); - return p.with_new_children(vec![limit_node]); + node = Arc::new(agg.with_limit(Some(agg_limit))); } } } if reverse { - let limit = Arc::new(TailLimitExec::new(input.clone(), limit)); - p.with_new_children(vec![limit]) + Arc::new(TailLimitExec::new(node, limit)) + } else if node.output_partitioning().partition_count() == 1 { + Arc::new(GlobalLimitExec::new(node, 0, Some(limit))) } else { - let limit = Arc::new(GlobalLimitExec::new(input.clone(), 0, Some(limit))); - let limit_optimized = LimitPushdown::new().optimize(limit, config)?; - p.with_new_children(vec![limit_optimized]) + Arc::new(LocalLimitExec::new(node, limit)) } } @@ -811,12 +826,14 @@ mod tests { .as_any() .downcast_ref::() .expect("merge must stay on top of per-partition aggregates"); - assert_eq!( - merge.fetch(), - Some(6), - "row budget must be limit * partitions" - ); - let agg = merge + assert_eq!(merge.fetch(), None); + let local_limit = merge + .input() + .as_any() + .downcast_ref::() + .expect("the limit must become a per-partition row limit above the aggregate"); + assert_eq!(local_limit.fetch(), 3); + let agg = local_limit .input() .as_any() .downcast_ref::() @@ -824,6 +841,56 @@ mod tests { assert_eq!(agg.limit(), Some(3)); } + /// The limit descends below the merge even for an aggregate that is not an + /// InlineAggregateExec: only the early-stop absorption is specific to it, the row limit + /// placement is not. + #[test] + fn worker_limit_above_merged_raw_partial_aggregate_stays_below_merge() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let agg = sum_aggregate(AggregateMode::Partial, "k", source); + let merged = merge_by_k(agg); + let p = worker(merged, Some((3, false))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker_node = rewritten.as_any().downcast_ref::().unwrap(); + let merge = worker_node + .input + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(merge.fetch(), None); + let local_limit = merge + .input() + .as_any() + .downcast_ref::() + .expect("a row limit above the merge would truncate duplicate group keys"); + assert_eq!(local_limit.fetch(), 3); + assert!(local_limit.input().as_any().is::()); + + let merged = merge_by_k(sum_aggregate( + AggregateMode::Partial, + "k", + two_partition_source(&schema), + )); + let p = worker(merged, Some((3, true))); + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + let worker_node = rewritten.as_any().downcast_ref::().unwrap(); + let merge = worker_node + .input + .as_any() + .downcast_ref::() + .unwrap(); + let tail = merge + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(tail.limit, 3); + assert!(tail.input.as_any().is::()); + } + #[tokio::test(flavor = "multi_thread")] async fn worker_reverse_limit_above_merged_partial_aggregate_tails_each_partition() { let schema = test_schema(); diff --git a/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs b/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs index fbc188056672d..ab3e54c599c37 100644 --- a/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs +++ b/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs @@ -58,6 +58,10 @@ impl ExecutionPlan for TailLimitExec { vec![&self.input] } + fn maintains_input_order(&self) -> Vec { + vec![true] + } + fn with_new_children( self: Arc, children: Vec>, From 2b78a1960604a4da325ebf1711813e3a2f8d0c00 Mon Sep 17 00:00:00 2001 From: Aleksandr Romanenko Date: Tue, 2 Jun 2026 17:42:18 +0200 Subject: [PATCH 4/7] fix(cubestore): limit descent through projections, tripwire for unrecognized shapes The limit descent now looks through ProjectionExec: a row limit is a plain count and commutes with column renames. Without this, a projection between the worker and the merge would route the limit into the generic fallback, whose LimitPushdown descends through projections into merges and could place a row fetch above a duplicate-bearing merge of per-partition partial states. For shapes the descent doesn't recognize, the fallback now refuses to add a worker limit at all when the subtree contains a per-partition partial aggregate: skipping the limit is always correct, the router applies the real one. --- .../distributed_partial_aggregate.rs | 122 ++++++++++++++++-- 1 file changed, 113 insertions(+), 9 deletions(-) diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs index 480940f10a2b5..351dd43fa94b5 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs @@ -303,10 +303,18 @@ pub fn add_limit_to_workers( // `limit` groups, and groups beyond that arrive complete or not at all -- the router // orders by the group key, not by the totals, and its own limit drops them. if first_aggregate_below_merges(input).is_some() { - let new_input = limit_above_first_aggregate(input, limit, reverse); + let new_input = limit_above_first_aggregate(input, limit, reverse)?; return p.with_new_children(vec![new_input]); } + // Tripwire for shapes the descent doesn't recognize: a per-partition partial aggregate in + // the subtree means a duplicate-bearing merge somewhere above it, and the row limit below + // (LimitPushdown descends through projections into merges) could land on it. The worker + // limit is an optimization, skipping it is always correct. + if contains_multi_partition_partial_aggregate(input) { + return Ok(p); + } + // No aggregation: plain rows, each one self-contained, a row limit is exact anywhere. if reverse { let limit = Arc::new(TailLimitExec::new(input.clone(), limit)); @@ -318,11 +326,14 @@ pub fn add_limit_to_workers( } } -/// The first aggregate reachable from `p` looking only through sort preserving merges. +/// The first aggregate reachable from `p` looking only through sort preserving merges and +/// projections. fn first_aggregate_below_merges(mut p: &Arc) -> Option<&Arc> { loop { if let Some(merge) = p.as_any().downcast_ref::() { p = merge.input(); + } else if let Some(projection) = p.as_any().downcast_ref::() { + p = projection.input(); } else if p.as_any().is::() || p.as_any().is::() { return Some(p); } else { @@ -331,18 +342,26 @@ fn first_aggregate_below_merges(mut p: &Arc) -> Option<&Arc, limit: usize, reverse: bool, -) -> Arc { +) -> Result, DataFusionError> { if let Some(merge) = p.as_any().downcast_ref::() { - let child = limit_above_first_aggregate(merge.input(), limit, reverse); - return Arc::new( + let child = limit_above_first_aggregate(merge.input(), limit, reverse)?; + return Ok(Arc::new( SortPreservingMergeExec::new(merge.expr().clone(), child).with_fetch(merge.fetch()), - ); + )); + } + if let Some(projection) = p.as_any().downcast_ref::() { + // The limit is a plain row count, it commutes with column renames + let child = limit_above_first_aggregate(projection.input(), limit, reverse)?; + return Ok(Arc::new(ProjectionExec::try_new( + projection.expr().to_vec(), + child, + )?)); } // The sorted streaming aggregate can additionally stop reading its input early once it has @@ -358,13 +377,31 @@ fn limit_above_first_aggregate( } } - if reverse { + Ok(if reverse { Arc::new(TailLimitExec::new(node, limit)) } else if node.output_partitioning().partition_count() == 1 { Arc::new(GlobalLimitExec::new(node, 0, Some(limit))) } else { Arc::new(LocalLimitExec::new(node, limit)) + }) +} + +/// True if the plan contains a partial aggregate executed per partition: its merged output +/// carries duplicate group keys, so a row limit above such a merge is not group-aligned. +fn contains_multi_partition_partial_aggregate(p: &Arc) -> bool { + let is_one = if let Some(agg) = p.as_any().downcast_ref::() { + *agg.mode() == InlineAggregateMode::Partial + } else if let Some(agg) = p.as_any().downcast_ref::() { + *agg.mode() == AggregateMode::Partial + } else { + false + }; + if is_one && p.output_partitioning().partition_count() > 1 { + return true; } + p.children() + .into_iter() + .any(contains_multi_partition_partial_aggregate) } /// Because we disable `EnforceDistribution`, and because we add `SortPreservingMergeExec` in @@ -954,6 +991,73 @@ mod tests { } } + /// The descent looks through projections: the row limit is a plain count and commutes with + /// column renames. + #[test] + fn worker_limit_descends_through_projection() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let merged = push_sorted_partial_aggregate_below_merge_shape(agg); + let projection = Arc::new( + ProjectionExec::try_new( + merged + .schema() + .fields() + .iter() + .map(|f| (col(f.name(), &merged.schema()).unwrap(), f.name().clone())) + .collect(), + merged, + ) + .unwrap(), + ); + let p = worker(projection, Some((3, false))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker_node = rewritten.as_any().downcast_ref::().unwrap(); + let projection = worker_node + .input + .as_any() + .downcast_ref::() + .expect("projection must stay on top"); + let merge = projection + .input() + .as_any() + .downcast_ref::() + .unwrap(); + let local_limit = merge + .input() + .as_any() + .downcast_ref::() + .expect("the limit must descend through the projection to the aggregate"); + assert_eq!(local_limit.fetch(), 3); + let agg = local_limit + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(agg.limit(), Some(3)); + } + + /// A shape the descent doesn't recognize but containing per-partition partial aggregates + /// must not get a worker limit at all: a row limit could land above a duplicate-bearing + /// merge. + #[test] + fn worker_limit_skipped_for_unrecognized_shape_with_per_partition_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let coalesce: Arc = Arc::new(CoalescePartitionsExec::new(agg)); + let p = worker(coalesce, Some((3, false))); + + let rewritten = add_limit_to_workers(p.clone(), &ConfigOptions::default()).unwrap(); + assert!( + Arc::ptr_eq(&rewritten, &p), + "no limit must be added to an unrecognized shape with a per-partition aggregate" + ); + } + /// Single partition partial aggregate emits unique group keys, the row limit stays exact /// and also lets the aggregate stop early. #[test] From fe1ebe8b56703b2ea4cd071829202c7976ee5908 Mon Sep 17 00:00:00 2001 From: Aleksandr Romanenko Date: Tue, 2 Jun 2026 17:50:34 +0200 Subject: [PATCH 5/7] fix(cubestore): bound the tail limit window by the limit, not by batch size A batch covering the whole window replaces it on arrival, cut to the last 'limit' rows; smaller batches go through the usual suffix eviction. The window now never holds more than 2 * limit rows regardless of input batch sizes (previously a single large batch stayed in the window whole until the end of the input). --- .../cubestore/src/queryplanner/tail_limit.rs | 67 ++++++++++++++++++- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs b/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs index ab3e54c599c37..dfe940a101510 100644 --- a/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs +++ b/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs @@ -115,8 +115,10 @@ impl TailLimitStream { } /// Collects a sliding tail window of the input: keeps only the trailing batches needed to cover -/// `limit` rows, newer rows displace older ones. The front batch may overshoot the window, it is -/// sliced later by [batches_tail]. +/// `limit` rows, newer rows displace older ones. Every stored batch is cut to at most `limit` +/// rows on arrival and the eviction keeps the minimal covering suffix, so the window never holds +/// more than 2 * `limit` rows. The front batch may overshoot the window, it is sliced later by +/// [batches_tail]. async fn collect_tail_window( mut input: SendableRecordBatchStream, limit: usize, @@ -125,7 +127,19 @@ async fn collect_tail_window( let mut total_rows = 0; while let Some(batch) = input.next().await { let batch = batch?; - total_rows += batch.num_rows(); + let rows = batch.num_rows(); + if rows >= limit { + // The batch alone covers the whole window + window.clear(); + total_rows = limit; + window.push_back(if rows > limit { + skip_first_rows(&batch, rows - limit) + } else { + batch + }); + continue; + } + total_rows += rows; window.push_back(batch); while window.len() > 1 && total_rows - window.front().unwrap().num_rows() >= limit { total_rows -= window.pop_front().unwrap().num_rows(); @@ -304,6 +318,53 @@ mod tests { assert!(to_ints(r).into_iter().flatten().collect_vec().is_empty()); } + #[tokio::test] + async fn batches_larger_than_limit() { + // 20-row batch followed by a 3-row batch, limit 5: last 2 rows of the big batch + 3 + let big: Vec = (0..20).collect(); + let input = vec![ints(big), ints(vec![100, 101, 102])]; + let inp = try_make_memory_data_source(&vec![input], ints_schema(), None).unwrap(); + let r = result_collect( + Arc::new(TailLimitExec::new(inp, 5)), + Arc::new(TaskContext::default()), + ) + .await + .unwrap(); + assert_eq!( + to_ints(r).into_iter().flatten().collect_vec(), + vec![18, 19, 100, 101, 102], + ); + } + + /// The window must stay bounded by the limit, not by the largest input batch. + #[tokio::test] + async fn window_stays_bounded() { + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + + // A large batch first, then a trickle of small ones that never covers the limit + let mut batches = vec![ints((0..1000).collect())]; + batches.extend((0..10).map(|i| ints(vec![i]))); + + let stream = Box::pin(RecordBatchStreamAdapter::new( + ints_schema(), + futures::stream::iter(batches.into_iter().map(Ok)), + )); + let window = collect_tail_window(stream, 20).await.unwrap(); + + let window_rows: usize = window.iter().map(|b| b.num_rows()).sum(); + assert!( + window_rows <= 2 * 20, + "window holds {} rows, must stay within 2 * limit", + window_rows + ); + let result = batches_tail(window, 20, ints_schema()).unwrap(); + let last_20: Vec = (990..1000).chain(0..10).collect(); + assert_eq!( + to_ints(vec![result]).into_iter().flatten().collect_vec(), + last_20 + ); + } + #[tokio::test] async fn empty_partition() { let partitions = vec![vec![], vec![ints(vec![1, 2])]]; From 8962d63284655c56bc335e2265c2dac2ab23066e Mon Sep 17 00:00:00 2001 From: Aleksandr Romanenko Date: Tue, 2 Jun 2026 17:52:27 +0200 Subject: [PATCH 6/7] docs(cubestore): note intentionally unnarrowed statistics and the defensive threshold guard --- .../queryplanner/inline_aggregate/inline_aggregate_stream.rs | 4 +++- .../cubestore/src/queryplanner/inline_aggregate/mod.rs | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs index c9252b88a3c88..bee04149514cc 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs @@ -318,7 +318,9 @@ impl InlineAggregateStream { /// Returns Some(batch) if emitted, None otherwise. fn emit_early_if_ready(&mut self) -> DFResult> { let threshold = self.emit_early_threshold(); - // Need at least (threshold + 1) groups to emit threshold closed groups and keep 1 + // Need at least (threshold + 1) groups to emit threshold closed groups and keep 1. + // The threshold == 0 check is defensive: the poll loop checks limit_reached() before + // calling this, so the threshold is at least 1 there. if threshold == 0 || self.group_values.len() <= threshold { return Ok(None); } diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs index a3a9ea50420cb..15099c6a69d8e 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs @@ -114,6 +114,9 @@ impl InlineAggregateExec { /// Returns a copy of this aggregate with the per-partition group limit set. Each partition /// emits at most `limit` first (in group order) complete groups and stops reading its input. + /// + /// Plan properties and statistics are intentionally not narrowed by the limit: the limit is + /// set after all the optimizers that could consume them have run. pub fn with_limit(&self, limit: Option) -> Self { let mut result = self.clone(); result.limit = limit; From 21fb7d452b39b335a9c6c95d47d5f2779d1186d0 Mon Sep 17 00:00:00 2001 From: Aleksandr Romanenko Date: Wed, 3 Jun 2026 11:43:49 +0200 Subject: [PATCH 7/7] feat(cubestore): coalesce instead of sort merge under global aggregates A global (no GROUP BY) aggregate doesn't use the input ordering, but the scan still merges its partitions with a SortPreservingMergeExec when the index has a sort key. For such queries the sort key often comes from the equality filters (it is picked for index selection and partition pruning, not as an ordering requirement), so the filters make the merge keys constant and every merge comparison becomes a full-length tie across all chunks -- pure waste. The new pre-optimize pass replaces such merges under the aggregate with plain partition coalescing, descending through filters, projections and unions. Restricted to hash aggregates without group expressions (one accumulator set per partition even when later optimizations make them run per partition; a grouped hash aggregate would multiply its hash table by the partition count) and without ordering requirements (array_agg(ORDER BY) and other order-sensitive aggregates keep their merge). The deduplicating merge under LastRowByUniqueKey is out of the descent's reach and stays intact. --- .../cubestore-sql-tests/src/tests.rs | 110 +++++++++ .../distributed_partial_aggregate.rs | 227 ++++++++++++++++++ .../src/queryplanner/optimizations/mod.rs | 8 +- rust/cubestore/cubestore/src/sql/mod.rs | 15 +- 4 files changed, 350 insertions(+), 10 deletions(-) diff --git a/rust/cubestore/cubestore-sql-tests/src/tests.rs b/rust/cubestore/cubestore-sql-tests/src/tests.rs index 2a08c7665b981..b8186f252bce2 100644 --- a/rust/cubestore/cubestore-sql-tests/src/tests.rs +++ b/rust/cubestore/cubestore-sql-tests/src/tests.rs @@ -245,6 +245,14 @@ pub fn sql_tests(prefix: &str) -> Vec<(&'static str, TestFn)> { "planning_aggregate_below_merge_with_limit", planning_aggregate_below_merge_with_limit, ), + t( + "global_aggregate_no_chunk_merge", + global_aggregate_no_chunk_merge, + ), + t( + "global_aggregate_unique_key_keeps_merge", + global_aggregate_unique_key_keeps_merge, + ), t("divide_by_zero", divide_by_zero), t( "filter_multiple_in_for_decimal", @@ -399,6 +407,8 @@ lazy_static::lazy_static! { "group_by_prefix_sorted_aggregate_multi_partition", "group_by_prefix_limit_high_cardinality", "planning_aggregate_below_merge_with_limit", + "global_aggregate_no_chunk_merge", + "global_aggregate_unique_key_keeps_merge", ].into_iter().map(ToOwned::to_owned).collect(); } @@ -7605,6 +7615,106 @@ async fn planning_aggregate_below_merge_with_limit( Ok(()) } +/// A no-GROUP BY (hash) aggregate over a multi-chunk scan: the chunks must be coalesced, not +/// merge-sorted -- the aggregate doesn't use the order, and the equality filters on the sort key +/// prefix make every merge comparison a full-length tie. +async fn global_aggregate_no_chunk_merge(service: Box) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query( + "CREATE TABLE s.Batch (tenant_id int, deployment_id int, ts timestamp, hits int, hits_failed int)", + ) + .await?; + service + .exec_query( + "CREATE TABLE s.Stream (tenant_id int, deployment_id int, ts timestamp, hits int, hits_failed int)", + ) + .await?; + + // Several inserts = several chunks, like a streaming table + for i in 0..3 { + let values = (0..50) + .map(|j| { + format!( + "(25358, 5, '2026-06-03T0{}:00:{:02}.000', {}, {})", + i, + j, + j, + j % 3 + ) + }) + .join(", "); + service + .exec_query(&format!( + "INSERT INTO s.Stream (tenant_id, deployment_id, ts, hits, hits_failed) VALUES {}", + values + )) + .await?; + } + + let query = "SELECT sum(hits) hits, sum(hits_failed) hits_failed FROM (\ + SELECT * FROM s.Batch WHERE 1 = 0 \ + UNION ALL \ + SELECT * FROM s.Stream\ + ) `t` WHERE tenant_id = 25358 AND deployment_id = 5 \ + AND ts >= to_timestamp('2026-06-03T00:00:00.000') \ + AND ts <= to_timestamp('2026-06-03T23:59:59.999') \ + LIMIT 10000"; + + let r = service.exec_query(query).await?; + assert_eq!(to_rows(&r), rows(&[(3675, 147)])); + + let p = service.plan_query(query).await?; + let worker_plan = pp_phys_plan(p.worker.as_ref()); + assert!( + !worker_plan.contains("MergeSort"), + "hash aggregate must not merge-sort the chunks:\n{}", + worker_plan + ); + let coalesce_pos = worker_plan.find("CoalescePartitions"); + let aggregate_pos = worker_plan.find("LinearPartialAggregate"); + assert!( + coalesce_pos.is_some() && aggregate_pos.is_some() && coalesce_pos < aggregate_pos, + "the per-partition aggregation must be coalesced above:\n{}", + worker_plan + ); + Ok(()) +} + +/// A unique key table deduplicates row versions via LastRowByUniqueKey, which needs its input +/// merge-sorted to keep the versions of a key adjacent. That merge must survive the +/// no-ordering-needed rewrite under a global aggregate. +async fn global_aggregate_unique_key_keeps_merge( + service: Box, +) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Versions (a int, b int, val int) unique key (a, b)") + .await?; + + service + .exec_query("INSERT INTO s.Versions (a, b, val, __seq) VALUES (1, 1, 10, 1), (2, 2, 20, 2)") + .await?; + // A newer version of (1, 1) in another chunk + service + .exec_query("INSERT INTO s.Versions (a, b, val, __seq) VALUES (1, 1, 30, 3)") + .await?; + + let query = "SELECT sum(val) FROM s.Versions"; + let r = service.exec_query(query).await?; + // Only the last version of each key counts: 30 + 20, not 10 + 30 + 20 + assert_eq!(to_rows(&r), rows(&[(50)])); + + let p = service.plan_query(query).await?; + let worker_plan = pp_phys_plan(p.worker.as_ref()); + assert!( + worker_plan.contains("LastRowByUniqueKey") && worker_plan.contains("MergeSort"), + "the deduplicating merge must stay in place:\n{}", + worker_plan + ); + Ok(()) +} + async fn divide_by_zero(service: Box) -> Result<(), CubeError> { service.exec_query("CREATE SCHEMA s").await?; service.exec_query("CREATE TABLE s.t(i int, z int)").await?; diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs index 351dd43fa94b5..8fbe92c6fa55f 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs @@ -13,6 +13,7 @@ use datafusion::physical_optimizer::limit_pushdown::LimitPushdown; use datafusion::physical_optimizer::PhysicalOptimizerRule as _; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; @@ -123,6 +124,71 @@ pub fn push_aggregate_to_workers( )?)) } +/// A global (no GROUP BY) aggregate doesn't use the input ordering, but the scan still merges +/// its partitions with a SortPreservingMergeExec when the index has a sort key (e.g. picked +/// from the filters for partition pruning). Replace such merges under the aggregate with plain +/// partition coalescing: the per-row key comparisons are pure waste there, and particularly +/// bad when the filters above make the merge keys constant, turning every comparison into a +/// full-length tie. +/// +/// Restricted to aggregates without group expressions: those hold a single accumulator set per +/// partition even when later optimizations make them run per partition. A grouped hash +/// aggregate over unmerged partitions could end up with a hash table per partition, +/// multiplying its memory by the partition count. +pub fn drop_sort_merge_under_global_aggregate( + p: Arc, +) -> Result, DataFusionError> { + let Some(agg) = p.as_any().downcast_ref::() else { + return Ok(p); + }; + if !matches!(agg.input_order_mode(), InputOrderMode::Linear) + || !agg.group_expr().expr().is_empty() + { + return Ok(p); + } + // Order-sensitive aggregates (first_value, array_agg(ORDER BY), ...) stay Linear with an + // empty GROUP BY but still need their input ordered + if agg.required_input_ordering()[0].is_some() { + return Ok(p); + } + let new_input = replace_merges_with_coalesce(agg.input())?; + if Arc::ptr_eq(&new_input, agg.input()) { + return Ok(p); + } + p.with_new_children(vec![new_input]) +} + +/// Replaces sort preserving merges with plain partition coalescing, looking through the nodes +/// that don't require an input ordering of their own. +fn replace_merges_with_coalesce( + p: &Arc, +) -> Result, DataFusionError> { + let p_any = p.as_any(); + if let Some(merge) = p_any.downcast_ref::() { + if merge.fetch().is_some() { + return Ok(p.clone()); + } + let child = replace_merges_with_coalesce(merge.input())?; + return Ok(Arc::new(CoalescePartitionsExec::new(child))); + } + if p_any.is::() || p_any.is::() || p_any.is::() { + let new_children = p + .children() + .into_iter() + .map(replace_merges_with_coalesce) + .collect::, _>>()?; + if p.children() + .into_iter() + .zip(new_children.iter()) + .all(|(old, new)| Arc::ptr_eq(old, new)) + { + return Ok(p.clone()); + } + return p.clone().with_new_children(new_children); + } + Ok(p.clone()) +} + /// Transforms from: /// AggregatePartial, Sorted /// `- SortPreservingMerge @@ -991,6 +1057,167 @@ mod tests { } } + /// A global aggregate doesn't use the input order: the merge below it becomes a plain + /// coalesce, through the filter, with identical results. + #[tokio::test(flavor = "multi_thread")] + async fn drops_sort_merge_under_global_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let filter: Arc = Arc::new( + FilterExec::try_new( + Arc::new(datafusion::physical_plan::expressions::BinaryExpr::new( + col("k", &schema).unwrap(), + datafusion::logical_expr::Operator::Gt, + Arc::new(datafusion::physical_plan::expressions::Literal::new( + datafusion::scalar::ScalarValue::Int64(Some(1)), + )), + )), + merge_by_k(source), + ) + .unwrap(), + ); + let original = global_sum_aggregate(filter); + + let rewritten = drop_sort_merge_under_global_aggregate(original.clone()).unwrap(); + + let agg = rewritten.as_any().downcast_ref::().unwrap(); + let filter = agg + .input() + .as_any() + .downcast_ref::() + .expect("filter must stay in place"); + assert!( + filter.input().as_any().is::(), + "the merge must become a plain coalesce" + ); + + assert_eq!( + collect_global_sum(rewritten).await, + collect_global_sum(original).await + ); + } + + /// An order-sensitive aggregate stays Linear with an empty GROUP BY but still needs its + /// input ordered, so its merge stays. + #[test] + fn keeps_sort_merge_under_order_sensitive_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let merged = merge_by_k(source); + let first = AggregateExprBuilder::new( + datafusion::functions_aggregate::array_agg::array_agg_udaf(), + vec![col("v", &schema).unwrap()], + ) + .schema(schema.clone()) + .alias("vals") + .order_by(LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", &schema).unwrap(), + )])) + .build() + .unwrap(); + let original: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(Vec::new()), + vec![Arc::new(first)], + vec![None], + merged, + schema, + ) + .unwrap(), + ); + assert!(original + .as_any() + .downcast_ref::() + .unwrap() + .required_input_ordering()[0] + .is_some()); + + let rewritten = drop_sort_merge_under_global_aggregate(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + /// A grouped hash aggregate over unmerged partitions could build a hash table per + /// partition, so its merge stays. + #[test] + fn keeps_sort_merge_under_grouped_hash_aggregate() { + let schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int64, false), + Field::new("g", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(Int64Array::from(vec![5, 4])), + Arc::new(Int64Array::from(vec![10, 20])), + ], + ) + .unwrap(); + let source = sorted_source(&schema, vec![vec![batch.clone()], vec![batch]]); + let original = sum_aggregate(AggregateMode::Partial, "g", merge_by_k(source)); + assert!(matches!( + original + .as_any() + .downcast_ref::() + .unwrap() + .input_order_mode(), + InputOrderMode::Linear + )); + + let rewritten = drop_sort_merge_under_global_aggregate(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + fn global_sum_aggregate(input: Arc) -> Arc { + let schema = input.schema(); + let sum = AggregateExprBuilder::new(sum_udaf(), vec![col("v", &schema).unwrap()]) + .schema(schema.clone()) + .alias("sum_v") + .build() + .unwrap(); + Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(Vec::new()), + vec![Arc::new(sum)], + vec![None], + input, + schema, + ) + .unwrap(), + ) + } + + /// Sums the partial states of a global aggregate across partitions. + async fn collect_global_sum(plan: Arc) -> i64 { + let session = SessionContext::new(); + let batches = collect(plan, session.task_ctx()).await.unwrap(); + batches + .iter() + .flat_map(|b| { + b.column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .flatten() + .collect::>() + }) + .sum() + } + + #[test] + fn keeps_sort_merge_under_sorted_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let original = sum_aggregate(AggregateMode::Partial, "k", merge_by_k(source)); + + let rewritten = drop_sort_merge_under_global_aggregate(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + /// The descent looks through projections: the row limit is a plain count and commutes with /// column renames. #[test] diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs index b828c6a8e6f53..ec642715545b5 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs @@ -9,8 +9,9 @@ mod trace_data_loaded; use super::serialized_plan::PreSerializedPlan; use crate::cluster::{Cluster, WorkerPlanningParams}; use crate::queryplanner::optimizations::distributed_partial_aggregate::{ - add_limit_to_workers, ensure_partition_merge, push_aggregate_to_workers, - push_sorted_partial_aggregate_below_merge, replace_suboptimal_merge_sorts, + add_limit_to_workers, drop_sort_merge_under_global_aggregate, ensure_partition_merge, + push_aggregate_to_workers, push_sorted_partial_aggregate_below_merge, + replace_suboptimal_merge_sorts, }; use crate::queryplanner::optimizations::inline_aggregate_rewriter::replace_with_inline_aggregate; use crate::queryplanner::planning::CubeExtensionPlanner; @@ -148,6 +149,9 @@ fn pre_optimize_physical_plan( // Make the merge carry partial aggregate states instead of all raw rows let p = rewrite_physical_plan(p, &mut |p| push_sorted_partial_aggregate_below_merge(p))?; + // Global (no GROUP BY) aggregates don't need their input merged in the sort order + let p = rewrite_physical_plan(p, &mut |p| drop_sort_merge_under_global_aggregate(p))?; + // Replace sorted AggregateExec with InlineAggregateExec for better performance let p = rewrite_physical_plan(p, &mut |p| replace_with_inline_aggregate(p))?; diff --git a/rust/cubestore/cubestore/src/sql/mod.rs b/rust/cubestore/cubestore/src/sql/mod.rs index 2ab0f11032b9b..50f20eacd4ce6 100644 --- a/rust/cubestore/cubestore/src/sql/mod.rs +++ b/rust/cubestore/cubestore/src/sql/mod.rs @@ -3619,14 +3619,13 @@ mod tests { \n CoalescePartitions\ \n LinearPartialAggregate\ \n Filter\ - \n MergeSort\ - \n Scan, index: default:1:[1]:sort_on[num], fields: *\ - \n FilterByKeyRange\ - \n CheckMemoryExec\ - \n ParquetScan\ - \n FilterByKeyRange\ - \n CheckMemoryExec\ - \n ParquetScan"; + \n Scan, index: default:1:[1]:sort_on[num], fields: *\ + \n FilterByKeyRange\ + \n CheckMemoryExec\ + \n ParquetScan\ + \n FilterByKeyRange\ + \n CheckMemoryExec\ + \n ParquetScan"; let plan = pp_phys_plan_ext(plans.worker.as_ref(), &opts); let p = plan_regexp.replace_all(&plan, "ParquetScan"); println!("pp {}", p);