diff --git a/rust/cubestore/cubestore-sql-tests/src/tests.rs b/rust/cubestore/cubestore-sql-tests/src/tests.rs index bd14875b42504..b8186f252bce2 100644 --- a/rust/cubestore/cubestore-sql-tests/src/tests.rs +++ b/rust/cubestore/cubestore-sql-tests/src/tests.rs @@ -233,6 +233,26 @@ pub fn sql_tests(prefix: &str) -> Vec<(&'static str, TestFn)> { "unique_key_and_multi_partitions_hash_aggregate", unique_key_and_multi_partitions_hash_aggregate, ), + t( + "group_by_prefix_sorted_aggregate_multi_partition", + group_by_prefix_sorted_aggregate_multi_partition, + ), + t( + "group_by_prefix_limit_high_cardinality", + group_by_prefix_limit_high_cardinality, + ), + t( + "planning_aggregate_below_merge_with_limit", + planning_aggregate_below_merge_with_limit, + ), + t( + "global_aggregate_no_chunk_merge", + global_aggregate_no_chunk_merge, + ), + t( + "global_aggregate_unique_key_keeps_merge", + global_aggregate_unique_key_keeps_merge, + ), t("divide_by_zero", divide_by_zero), t( "filter_multiple_in_for_decimal", @@ -384,6 +404,11 @@ lazy_static::lazy_static! 
{ "create_table_with_csv_no_header", "create_table_with_csv_no_header_and_delimiter", "create_table_with_csv_no_header_and_quotes", + "group_by_prefix_sorted_aggregate_multi_partition", + "group_by_prefix_limit_high_cardinality", + "planning_aggregate_below_merge_with_limit", + "global_aggregate_no_chunk_merge", + "global_aggregate_unique_key_keeps_merge", ].into_iter().map(ToOwned::to_owned).collect(); } @@ -3628,8 +3653,8 @@ async fn planning_simple(service: Box) -> Result<(), CubeError> { pp_phys_plan(p.worker.as_ref()), "InlineFinalAggregate\ \n Worker\ - \n InlinePartialAggregate\ - \n MergeSort\ + \n MergeSort\ + \n InlinePartialAggregate\ \n Union\ \n Scan, index: default:1:[1]:sort_on[id], fields: [id, amount]\ \n Sort\ @@ -7238,9 +7263,9 @@ async fn unique_key_and_multi_partitions(service: Box) -> Result< \n InlineFinalAggregate, partitions: 1\ \n MergeSort, partitions: 1\ \n Worker, partitions: 2\ - \n GlobalLimit, n: 100, partitions: 1\ - \n InlinePartialAggregate, partitions: 1\ - \n MergeSort, partitions: 1\ + \n MergeSort, partitions: 1\ + \n LocalLimit, n: 100, partitions: 2\ + \n InlinePartialAggregate, limit: 100, partitions: 2\ \n Union, partitions: 2\ \n Projection, [a, b], partitions: 1\ \n LastRowByUniqueKey, partitions: 1\ @@ -7314,6 +7339,382 @@ async fn unique_key_and_multi_partitions_hash_aggregate( Ok(()) } +/// Correctness of the sorted partial aggregate executed per partition below the merge: group +/// keys present in several partitions must combine to the same totals as without the +/// optimization, with and without LIMIT. +async fn group_by_prefix_sorted_aggregate_multi_partition( + service: Box, +) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Data1 (a int, b int, val int, fval double)") + .await?; + service + .exec_query("CREATE TABLE s.Data2 (a int, b int, val int, fval double)") + .await?; + + // Group keys with `a` in 4..8 get rows from both tables, i.e. 
from both partitions of the + // union. + let mut raw_rows = Vec::new(); + let mut values1 = Vec::new(); + for a in 0i64..8 { + for b in 0i64..3 { + for r in 0i64..2 { + let val = a * 31 + b * 17 + r; + values1.push(format!("({}, {}, {}, {}.5)", a, b, val, val)); + raw_rows.push((a, b, val)); + } + } + } + let mut values2 = Vec::new(); + for a in 4i64..12 { + for b in 0i64..3 { + for r in 0i64..2 { + let val = a * 13 + b * 7 + r; + values2.push(format!("({}, {}, {}, {}.5)", a, b, val, val)); + raw_rows.push((a, b, val)); + } + } + } + service + .exec_query(&format!( + "INSERT INTO s.Data1 (a, b, val, fval) VALUES {}", + values1.join(", ") + )) + .await?; + service + .exec_query(&format!( + "INSERT INTO s.Data2 (a, b, val, fval) VALUES {}", + values2.join(", ") + )) + .await?; + + // Expected aggregates per group; fval = val + 0.5 keeps float sums exactly representable, + // so the assertions stay byte-exact regardless of the summation order. + let mut groups = std::collections::BTreeMap::<(i64, i64), (i64, i64, i64, i64)>::new(); + for (a, b, val) in raw_rows { + let group = groups.entry((a, b)).or_insert((0, i64::MAX, i64::MIN, 0)); + group.0 += val; + group.1 = group.1.min(val); + group.2 = group.2.max(val); + group.3 += 1; + } + let group_row = |((a, b), (sum, min, max, count)): (&(i64, i64), &(i64, i64, i64, i64))| { + vec![ + TableValue::Int(*a), + TableValue::Int(*b), + TableValue::Int(*sum), + TableValue::Int(*min), + TableValue::Int(*max), + TableValue::Float((*sum as f64 + 0.5 * *count as f64).into()), + ] + }; + let expected: Vec> = groups.iter().map(group_row).collect(); + + let query = "SELECT a, b, sum(val), min(val), max(val), sum(fval) FROM (\ + SELECT * FROM s.Data1 UNION ALL SELECT * FROM s.Data2\ + ) `t` GROUP BY 1, 2"; + + let full = service + .exec_query(&format!("{} ORDER BY 1, 2", query)) + .await?; + assert_eq!(to_rows(&full), expected); + + // LIMIT must return the same prefix of the full result + let limited = service + 
.exec_query(&format!("{} ORDER BY 1, 2 LIMIT 5", query)) + .await?; + assert_eq!(to_rows(&limited), expected[..5]); + + // DESC ordering takes the tail groups + let tail = service + .exec_query(&format!("{} ORDER BY 1 DESC, 2 DESC LIMIT 5", query)) + .await?; + let expected_tail: Vec<_> = expected.iter().rev().take(5).cloned().collect(); + assert_eq!(to_rows(&tail), expected_tail); + + // ORDER BY an aggregate must not use the group-order shortcut + let by_sum = service + .exec_query(&format!("{} ORDER BY 3, 1, 2 LIMIT 4", query)) + .await?; + let mut by_sum_keys: Vec<(i64, i64, i64)> = groups + .iter() + .map(|((a, b), (sum, ..))| (*sum, *a, *b)) + .collect(); + by_sum_keys.sort(); + let expected_by_sum: Vec> = by_sum_keys + .iter() + .take(4) + .map(|(_, a, b)| group_row(((&(*a, *b)), &groups[&(*a, *b)]))) + .collect(); + assert_eq!(to_rows(&by_sum), expected_by_sum); + Ok(()) +} + +/// LIMIT short-circuit correctness on group counts below and above the execution batch size +/// (4096), with duplicate group keys across partitions. +async fn group_by_prefix_limit_high_cardinality( + service: Box, +) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Data1 (a int, val int)") + .await?; + service + .exec_query("CREATE TABLE s.Data2 (a int, val int)") + .await?; + + // 7500 groups in total, above the 4096 execution batch size; keys 2500..5000 are present + // in both tables. 
+ for chunk in (0i64..5000).collect::>().chunks(1000) { + let values = chunk + .iter() + .map(|a| format!("({}, {})", a, a * 3 + 1)) + .join(", "); + service + .exec_query(&format!("INSERT INTO s.Data1 (a, val) VALUES {}", values)) + .await?; + } + for chunk in (2500i64..7500).collect::>().chunks(1000) { + let values = chunk + .iter() + .map(|a| format!("({}, {})", a, a * 5 + 2)) + .join(", "); + service + .exec_query(&format!("INSERT INTO s.Data2 (a, val) VALUES {}", values)) + .await?; + } + + let expected = |range: std::ops::Range| -> Vec> { + range + .map(|a| { + let mut sum = 0; + if a < 5000 { + sum += a * 3 + 1; + } + if a >= 2500 { + sum += a * 5 + 2; + } + vec![TableValue::Int(a), TableValue::Int(sum)] + }) + .collect() + }; + + let query = "SELECT a, sum(val) FROM (\ + SELECT * FROM s.Data1 UNION ALL SELECT * FROM s.Data2\ + ) `t` GROUP BY 1"; + + // LIMIT far below the batch size + let r = service + .exec_query(&format!("{} ORDER BY 1 LIMIT 10", query)) + .await?; + assert_eq!(to_rows(&r), expected(0..10)); + + // LIMIT above the batch size + let r = service + .exec_query(&format!("{} ORDER BY 1 LIMIT 5000", query)) + .await?; + assert_eq!(to_rows(&r), expected(0..5000)); + + // DESC takes the last groups (per-partition tail path) + let r = service + .exec_query(&format!("{} ORDER BY 1 DESC LIMIT 10", query)) + .await?; + let mut expected_tail = expected(7490..7500); + expected_tail.reverse(); + assert_eq!(to_rows(&r), expected_tail); + + // No limit: all the groups + let r = service.exec_query(&format!("{} ORDER BY 1", query)).await?; + assert_eq!(to_rows(&r), expected(0..7500)); + Ok(()) +} + +/// The sorted partial aggregate runs per partition below the merge; the worker limit becomes a +/// group limit on the aggregate plus a widened row budget on the merge (duplicate group keys +/// from different partitions make a plain row limit incorrect). 
+async fn planning_aggregate_below_merge_with_limit( + service: Box, +) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Orders(a int, b int, amount int)") + .await?; + + let p = service + .plan_query( + "SELECT a, b, SUM(amount) FROM (\ + SELECT * FROM s.Orders UNION ALL SELECT * FROM s.Orders\ + ) `t` GROUP BY 1, 2", + ) + .await?; + assert_eq!( + pp_phys_plan(p.worker.as_ref()), + "InlineFinalAggregate\ + \n Worker\ + \n MergeSort\ + \n InlinePartialAggregate\ + \n Union\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty" + ); + + let p = service + .plan_query( + "SELECT a, b, SUM(amount) FROM (\ + SELECT * FROM s.Orders UNION ALL SELECT * FROM s.Orders\ + ) `t` GROUP BY 1, 2 ORDER BY 1, 2 LIMIT 5", + ) + .await?; + assert_eq!( + pp_phys_plan(p.worker.as_ref()), + "Sort, fetch: 5\ + \n InlineFinalAggregate\ + \n Worker\ + \n MergeSort\ + \n LocalLimit, n: 5\ + \n InlinePartialAggregate, limit: 5\ + \n Union\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty" + ); + + // Reverse limit: a per-partition tail below the merge instead of a group limit (the last + // groups are unknown until the input ends, so the aggregate can't stop early) + let p = service + .plan_query( + "SELECT a, b, SUM(amount) FROM (\ + SELECT * FROM s.Orders UNION ALL SELECT * FROM s.Orders\ + ) `t` GROUP BY 1, 2 ORDER BY 1 DESC, 2 DESC LIMIT 5", + ) + .await?; + assert_eq!( + pp_phys_plan(p.worker.as_ref()), + "Sort, fetch: 5\ + \n InlineFinalAggregate\ + \n Worker\ + \n MergeSort\ + \n TailLimit, n: 5\ + \n InlinePartialAggregate\ + \n Union\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty\ + \n Scan, index: default:1:[1]:sort_on[a, b], fields: *\ + \n Sort\ + \n Empty" + ); + 
Ok(()) +} + +/// A no-GROUP BY (hash) aggregate over a multi-chunk scan: the chunks must be coalesced, not +/// merge-sorted -- the aggregate doesn't use the order, and the equality filters on the sort key +/// prefix make every merge comparison a full-length tie. +async fn global_aggregate_no_chunk_merge(service: Box) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query( + "CREATE TABLE s.Batch (tenant_id int, deployment_id int, ts timestamp, hits int, hits_failed int)", + ) + .await?; + service + .exec_query( + "CREATE TABLE s.Stream (tenant_id int, deployment_id int, ts timestamp, hits int, hits_failed int)", + ) + .await?; + + // Several inserts = several chunks, like a streaming table + for i in 0..3 { + let values = (0..50) + .map(|j| { + format!( + "(25358, 5, '2026-06-03T0{}:00:{:02}.000', {}, {})", + i, + j, + j, + j % 3 + ) + }) + .join(", "); + service + .exec_query(&format!( + "INSERT INTO s.Stream (tenant_id, deployment_id, ts, hits, hits_failed) VALUES {}", + values + )) + .await?; + } + + let query = "SELECT sum(hits) hits, sum(hits_failed) hits_failed FROM (\ + SELECT * FROM s.Batch WHERE 1 = 0 \ + UNION ALL \ + SELECT * FROM s.Stream\ + ) `t` WHERE tenant_id = 25358 AND deployment_id = 5 \ + AND ts >= to_timestamp('2026-06-03T00:00:00.000') \ + AND ts <= to_timestamp('2026-06-03T23:59:59.999') \ + LIMIT 10000"; + + let r = service.exec_query(query).await?; + assert_eq!(to_rows(&r), rows(&[(3675, 147)])); + + let p = service.plan_query(query).await?; + let worker_plan = pp_phys_plan(p.worker.as_ref()); + assert!( + !worker_plan.contains("MergeSort"), + "hash aggregate must not merge-sort the chunks:\n{}", + worker_plan + ); + let coalesce_pos = worker_plan.find("CoalescePartitions"); + let aggregate_pos = worker_plan.find("LinearPartialAggregate"); + assert!( + coalesce_pos.is_some() && aggregate_pos.is_some() && coalesce_pos < aggregate_pos, + "the per-partition aggregation must be coalesced 
above:\n{}", + worker_plan + ); + Ok(()) +} + +/// A unique key table deduplicates row versions via LastRowByUniqueKey, which needs its input +/// merge-sorted to keep the versions of a key adjacent. That merge must survive the +/// no-ordering-needed rewrite under a global aggregate. +async fn global_aggregate_unique_key_keeps_merge( + service: Box, +) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Versions (a int, b int, val int) unique key (a, b)") + .await?; + + service + .exec_query("INSERT INTO s.Versions (a, b, val, __seq) VALUES (1, 1, 10, 1), (2, 2, 20, 2)") + .await?; + // A newer version of (1, 1) in another chunk + service + .exec_query("INSERT INTO s.Versions (a, b, val, __seq) VALUES (1, 1, 30, 3)") + .await?; + + let query = "SELECT sum(val) FROM s.Versions"; + let r = service.exec_query(query).await?; + // Only the last version of each key counts: 30 + 20, not 10 + 30 + 20 + assert_eq!(to_rows(&r), rows(&[(50)])); + + let p = service.plan_query(query).await?; + let worker_plan = pp_phys_plan(p.worker.as_ref()); + assert!( + worker_plan.contains("LastRowByUniqueKey") && worker_plan.contains("MergeSort"), + "the deduplicating merge must stay in place:\n{}", + worker_plan + ); + Ok(()) +} + async fn divide_by_zero(service: Box) -> Result<(), CubeError> { service.exec_query("CREATE SCHEMA s").await?; service.exec_query("CREATE TABLE s.t(i int, z int)").await?; @@ -8053,12 +8454,12 @@ async fn build_range_end(service: Box) -> Result<(), CubeError> { Ok(()) } -async fn assert_limit_pushdown_using_search_string( +async fn assert_limit_pushdown_using_search_strings( service: &Box, query: &str, expected_index: Option<&str>, is_limit_expected: bool, - search_string: &str, + search_strings: &[&str], ) -> Result, CubeError> { let res = service .exec_query(&format!("EXPLAIN ANALYZE {}", query)) @@ -8073,18 +8474,17 @@ async fn assert_limit_pushdown_using_search_string( ))); } } - let 
expected_limit = search_string; if is_limit_expected { - if !s.contains(expected_limit) { + if !search_strings.iter().any(|expected| s.contains(expected)) { return Err(CubeError::internal(format!( "{} expected but not found", - expected_limit + search_strings.join(" or ") ))); } - } else if s.contains(expected_limit) { + } else if let Some(found) = search_strings.iter().find(|e| s.contains(*e)) { return Err(CubeError::internal(format!( "{} unexpected but found", - expected_limit + found ))); } } @@ -8095,6 +8495,23 @@ async fn assert_limit_pushdown_using_search_string( Ok(res.get_rows().clone()) } +async fn assert_limit_pushdown_using_search_string( + service: &Box, + query: &str, + expected_index: Option<&str>, + is_limit_expected: bool, + search_string: &str, +) -> Result, CubeError> { + assert_limit_pushdown_using_search_strings( + service, + query, + expected_index, + is_limit_expected, + &[search_string], + ) + .await +} + async fn assert_limit_pushdown( service: &Box, query: &str, @@ -8102,15 +8519,17 @@ async fn assert_limit_pushdown( is_limit_expected: bool, is_tail_limit: bool, ) -> Result, CubeError> { - assert_limit_pushdown_using_search_string( + assert_limit_pushdown_using_search_strings( service, query, expected_index, is_limit_expected, if is_tail_limit { - "TailLimit" + &["TailLimit"] } else { - "GlobalLimit" + // The worker limit is either a plain row limit or, for a partial aggregate running + // per partition below the merge, a group limit on the aggregate. 
+ &["GlobalLimit", "InlinePartialAggregate, limit:"] }, ) .await diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs index 5b2e6c4c38df1..bee04149514cc 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs @@ -43,6 +43,11 @@ pub(crate) struct InlineAggregateStream { batch_size: usize, + /// Per-partition limit on the number of emitted groups, see [`InlineAggregateExec::limit`] + limit: Option, + + groups_emitted: usize, + exec_state: ExecutionState, input_done: bool, @@ -100,6 +105,8 @@ impl InlineAggregateStream { group_by: agg_group_by, exec_state, batch_size, + limit: agg.limit(), + groups_emitted: 0, current_group_indices, group_values, input_done: false, @@ -175,26 +182,33 @@ impl Stream for InlineAggregateStream { loop { match &self.exec_state { ExecutionState::ReadingInput => { + // All needed groups are emitted, skip the rest of the input + if self.limit_reached() { + self.exec_state = ExecutionState::Done; + continue; + } + + // Drain groups accumulated beyond the emit threshold (a single input batch + // can bring many) before reading further input + match self.emit_early_if_ready() { + Ok(Some(batch)) => { + self.exec_state = ExecutionState::ProducingOutput(batch); + continue; + } + Ok(None) => { + // Not enough groups, read further + } + Err(e) => { + return Poll::Ready(Some(Err(e))); + } + } + match ready!(self.input.poll_next_unpin(cx)) { - // New input batch to aggregate + // New input batch to aggregate; emitting happens at the top of the loop Some(Ok(batch)) => { - // Aggregate the batch if let Err(e) = self.group_aggregate_batch(batch) { return Poll::Ready(Some(Err(e))); } - - // Try to emit a batch if we have enough groups - match self.emit_early_if_ready() { - Ok(Some(batch)) => 
{ - self.exec_state = ExecutionState::ProducingOutput(batch); - } - Ok(None) => { - // Not enough groups yet, continue reading - } - Err(e) => { - return Poll::Ready(Some(Err(e))); - } - } } // Error from input stream @@ -202,11 +216,11 @@ impl Stream for InlineAggregateStream { return Poll::Ready(Some(Err(e))); } - // Input stream exhausted - emit all remaining groups + // Input stream exhausted - emit the remaining groups, up to the limit None => { self.input_done = true; - match self.emit(EmitTo::All) { + match self.emit_remaining() { Ok(Some(batch)) => { self.exec_state = ExecutionState::ProducingOutput(batch); } @@ -244,9 +258,6 @@ impl Stream for InlineAggregateStream { } impl InlineAggregateStream { - /// Emit groups based on EmitTo strategy. - /// - /// Returns None if there are no groups to emit. /// Emit groups based on EmitTo strategy. /// /// Returns None if there are no groups to emit. @@ -283,25 +294,61 @@ impl InlineAggregateStream { Ok(Some(batch)) } - /// Check if we have enough groups to emit a batch, keeping the last (potentially incomplete) group. - /// - /// For sorted aggregation, we emit batches of size batch_size when we have accumulated - /// more than batch_size groups. We always keep the last group as it may continue in the next input batch. - fn should_emit_early(&self) -> bool { - // Need at least (batch_size + 1) groups to emit batch_size and keep 1 - self.group_values.len() > self.batch_size + fn limit_reached(&self) -> bool { + self.limit.is_some_and(|limit| self.groups_emitted >= limit) + } + + /// How many groups to emit in the next early batch: full batches until the limit (if any) + /// leaves fewer groups to emit. + fn emit_early_threshold(&self) -> usize { + match self.limit { + Some(limit) => self + .batch_size + .min(limit.saturating_sub(self.groups_emitted)), + None => self.batch_size, + } } /// Emit a batch of groups if we have enough accumulated, keeping the last group. 
/// + /// For sorted aggregation, we emit when we have accumulated more than threshold groups: the + /// last group is always kept as it may continue in the next input batch, so only closed + /// groups are emitted. + /// /// Returns Some(batch) if emitted, None otherwise. fn emit_early_if_ready(&mut self) -> DFResult> { - if !self.should_emit_early() { + let threshold = self.emit_early_threshold(); + // Need at least (threshold + 1) groups to emit threshold closed groups and keep 1. + // The threshold == 0 check is defensive: the poll loop checks limit_reached() before + // calling this, so the threshold is at least 1 there. + if threshold == 0 || self.group_values.len() <= threshold { return Ok(None); } - // Emit exactly batch_size groups, keeping the rest (including last incomplete group) - self.emit(EmitTo::First(self.batch_size)) + let batch = self.emit(EmitTo::First(threshold))?; + self.groups_emitted += threshold; + Ok(batch) + } + + /// Emit the groups left at the end of the input: all of them are closed at this point, but + /// no more than the limit allows. 
+ fn emit_remaining(&mut self) -> DFResult> { + let len = self.group_values.len(); + let emit_count = match self.limit { + Some(limit) => len.min(limit.saturating_sub(self.groups_emitted)), + None => len, + }; + if emit_count == 0 { + return Ok(None); + } + let emit_to = if emit_count < len { + EmitTo::First(emit_count) + } else { + EmitTo::All + }; + let batch = self.emit(emit_to)?; + self.groups_emitted += emit_count; + Ok(batch) } fn group_aggregate_batch(&mut self, batch: RecordBatch) -> DFResult<()> { diff --git a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs index e8ea319ec4605..15099c6a69d8e 100644 --- a/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/inline_aggregate/mod.rs @@ -40,7 +40,8 @@ pub struct InlineAggregateExec { aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, - /// Set if the output of this aggregation is truncated by a upstream sort/limit clause + /// Per-partition limit on the number of emitted groups: each partition emits at most this + /// many first (in group order) complete groups and stops reading its input limit: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, @@ -111,6 +112,17 @@ impl InlineAggregateExec { self.limit } + /// Returns a copy of this aggregate with the per-partition group limit set. Each partition + /// emits at most `limit` first (in group order) complete groups and stops reading its input. + /// + /// Plan properties and statistics are intentionally not narrowed by the limit: the limit is + /// set after all the optimizers that could consume them have run. 
+ pub fn with_limit(&self, limit: Option) -> Self { + let mut result = self.clone(); + result.limit = limit; + result + } + pub fn aggr_expr(&self) -> &[Arc] { &self.aggr_expr } @@ -289,3 +301,281 @@ fn supported_type(data_type: &DataType) -> bool { | DataType::BinaryView ) } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::{Int64Array, RecordBatch}; + use datafusion::arrow::datatypes::{Field, Schema}; + use datafusion::common::arrow::compute::concat_batches; + use datafusion::functions_aggregate::sum::sum_udaf; + use datafusion::physical_expr::aggregate::AggregateExprBuilder; + use datafusion::physical_expr::expressions::col; + use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion::physical_plan::collect; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use datafusion::prelude::{SessionConfig, SessionContext}; + use datafusion_datasource::memory::MemorySourceConfig; + use datafusion_datasource::source::DataSourceExec; + use futures::StreamExt; + use std::sync::atomic::{AtomicUsize, Ordering}; + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])) + } + + fn make_batch(schema: &SchemaRef, rows: &[(i64, i64)]) -> RecordBatch { + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from_iter_values(rows.iter().map(|r| r.0))), + Arc::new(Int64Array::from_iter_values(rows.iter().map(|r| r.1))), + ], + ) + .unwrap() + } + + fn sorted_source( + schema: &SchemaRef, + partitions: Vec>, + ) -> Arc { + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", schema).unwrap(), + )]); + let source = MemorySourceConfig::try_new(&partitions, schema.clone(), None) + .unwrap() + .try_with_sort_information(vec![ordering]) + .unwrap(); + Arc::new(DataSourceExec::new(Arc::new(source))) + } + + fn partial_sum_inline_aggregate( + input: Arc, + limit: Option, + ) -> Arc { + 
let schema = input.schema(); + let group_by = + PhysicalGroupBy::new_single(vec![(col("k", &schema).unwrap(), "k".to_string())]); + let sum = AggregateExprBuilder::new(sum_udaf(), vec![col("v", &schema).unwrap()]) + .schema(schema.clone()) + .alias("sum_v") + .build() + .unwrap(); + let agg = AggregateExec::try_new( + AggregateMode::Partial, + group_by, + vec![Arc::new(sum)], + vec![None], + input, + schema, + ) + .unwrap(); + assert!( + matches!(agg.input_order_mode(), InputOrderMode::Sorted), + "test setup must produce a sorted aggregate" + ); + let inline = InlineAggregateExec::try_new_from_aggregate(&agg).unwrap(); + Arc::new(inline.with_limit(limit)) + } + + fn run(plan: Arc, batch_size: usize) -> Vec<(i64, i64)> { + let session = + SessionContext::new_with_config(SessionConfig::new().with_batch_size(batch_size)); + let batches = futures::executor::block_on(collect(plan, session.task_ctx())).unwrap(); + let schema = batches[0].schema(); + let batch = concat_batches(&schema, &batches).unwrap(); + let keys = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let sums = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + keys.iter() + .zip(sums.iter()) + .map(|(k, s)| (k.unwrap(), s.unwrap())) + .collect() + } + + /// A group continuing in the next input batch must not be emitted early with a partial sum. 
+ #[test] + fn limit_emits_only_closed_groups() { + let schema = test_schema(); + let input = sorted_source( + &schema, + vec![vec![ + make_batch(&schema, &[(1, 10), (2, 20), (3, 30), (3, 31)]), + make_batch(&schema, &[(3, 32), (4, 40)]), + ]], + ); + let agg = partial_sum_inline_aggregate(input, Some(3)); + assert_eq!(run(agg, 4096), vec![(1, 10), (2, 20), (3, 93)]); + } + + #[test] + fn limit_results_match_no_limit_prefix() { + let schema = test_schema(); + let rows: Vec<(i64, i64)> = (0..1000).map(|i| (i / 3, i)).collect(); + let batches: Vec = rows.chunks(97).map(|c| make_batch(&schema, c)).collect(); + let source = sorted_source(&schema, vec![batches]); + + let no_limit = run(partial_sum_inline_aggregate(source.clone(), None), 4096); + let limited = run(partial_sum_inline_aggregate(source, Some(5)), 4096); + assert_eq!(limited, no_limit[..5]); + } + + #[test] + fn limit_larger_than_group_count_emits_all() { + let schema = test_schema(); + let input = sorted_source( + &schema, + vec![vec![make_batch(&schema, &[(1, 10), (2, 20), (3, 30)])]], + ); + let agg = partial_sum_inline_aggregate(input, Some(100)); + assert_eq!(run(agg, 4096), vec![(1, 10), (2, 20), (3, 30)]); + } + + /// Emitting in batch_size chunks until the limit is reached. + #[test] + fn limit_above_batch_size_emits_incrementally() { + let schema = test_schema(); + let rows: Vec<(i64, i64)> = (0..16).map(|i| (i / 2, i)).collect(); + let batches: Vec = rows.chunks(3).map(|c| make_batch(&schema, c)).collect(); + let source = sorted_source(&schema, vec![batches]); + + let no_limit = run(partial_sum_inline_aggregate(source.clone(), None), 2); + let limited = run(partial_sum_inline_aggregate(source, Some(5)), 2); + assert_eq!(limited, no_limit[..5]); + } + + /// A single input batch can bring more groups than the limit; the stream must still emit + /// exactly `limit` groups, draining the backlog in emit threshold chunks. 
+ #[test] + fn limit_holds_when_one_batch_overshoots_it() { + let schema = test_schema(); + let rows: Vec<(i64, i64)> = (0..20).map(|k| (k, k + 100)).collect(); + let input = sorted_source(&schema, vec![vec![make_batch(&schema, &rows)]]); + let agg = partial_sum_inline_aggregate(input, Some(10)); + assert_eq!( + run(agg, 4), + (0..10).map(|k| (k, k + 100)).collect::>() + ); + } + + /// Wraps a plan and counts batches its streams produce. + #[derive(Debug)] + struct CountingExec { + inner: Arc, + batches_polled: Arc, + } + + impl DisplayAs for CountingExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "CountingExec") + } + } + + impl ExecutionPlan for CountingExec { + fn name(&self) -> &'static str { + "CountingExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.inner.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + Ok(Arc::new(CountingExec { + inner: children[0].clone(), + batches_polled: self.batches_polled.clone(), + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DFResult { + let stream = self.inner.execute(partition, context)?; + let counter = self.batches_polled.clone(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + stream.schema(), + stream.inspect(move |_| { + counter.fetch_add(1, Ordering::SeqCst); + }), + ))) + } + } + + /// Once the limit is reached the aggregate must stop reading its input, so a downstream + /// LIMIT short-circuits the scan. 
+ #[test] + fn limit_stops_reading_input() { + let schema = test_schema(); + let batches: Vec = (0..100) + .map(|i| { + let rows: Vec<(i64, i64)> = (0..10).map(|j| (i * 10 + j, 1)).collect(); + make_batch(&schema, &rows) + }) + .collect(); + let source = sorted_source(&schema, vec![batches]); + let batches_polled = Arc::new(AtomicUsize::new(0)); + let counting = Arc::new(CountingExec { + inner: source, + batches_polled: batches_polled.clone(), + }); + + let result = run(partial_sum_inline_aggregate(counting, Some(5)), 4096); + assert_eq!(result.len(), 5); + assert!( + batches_polled.load(Ordering::SeqCst) < 10, + "aggregate must stop polling input after the limit is reached, polled {} batches", + batches_polled.load(Ordering::SeqCst) + ); + } + + /// When one batch brings enough groups to cover the limit, the stream must drain them + /// without reading further input. + #[test] + fn limit_drains_backlog_without_reading_input() { + let schema = test_schema(); + let batches: Vec = (0..100) + .map(|i| { + let rows: Vec<(i64, i64)> = (0..20).map(|j| (i * 20 + j, 1)).collect(); + make_batch(&schema, &rows) + }) + .collect(); + let source = sorted_source(&schema, vec![batches]); + let batches_polled = Arc::new(AtomicUsize::new(0)); + let counting = Arc::new(CountingExec { + inner: source, + batches_polled: batches_polled.clone(), + }); + + // The first batch alone brings 20 groups > limit 10 + let result = run(partial_sum_inline_aggregate(counting, Some(10)), 4); + assert_eq!(result.len(), 10); + assert_eq!( + batches_polled.load(Ordering::SeqCst), + 1, + "the first input batch covers the limit, no further reads needed" + ); + } +} diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs index e670d4be6e945..8fbe92c6fa55f 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs +++ 
b/rust/cubestore/cubestore/src/queryplanner/optimizations/distributed_partial_aggregate.rs @@ -1,4 +1,5 @@ use crate::cluster::WorkerPlanningParams; +use crate::queryplanner::inline_aggregate::{InlineAggregateExec, InlineAggregateMode}; use crate::queryplanner::planning::WorkerExec; use crate::queryplanner::query_executor::ClusterSendExec; use crate::queryplanner::tail_limit::TailLimitExec; @@ -12,12 +13,15 @@ use datafusion::physical_optimizer::limit_pushdown::LimitPushdown; use datafusion::physical_optimizer::PhysicalOptimizerRule as _; use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion::physical_plan::limit::GlobalLimitExec; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::union::UnionExec; -use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, PhysicalExpr}; +use datafusion::physical_plan::{ + ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, +}; use itertools::Itertools as _; use std::collections::HashSet; use std::sync::Arc; @@ -120,6 +124,137 @@ pub fn push_aggregate_to_workers( )?)) } +/// A global (no GROUP BY) aggregate doesn't use the input ordering, but the scan still merges +/// its partitions with a SortPreservingMergeExec when the index has a sort key (e.g. picked +/// from the filters for partition pruning). Replace such merges under the aggregate with plain +/// partition coalescing: the per-row key comparisons are pure waste there, and particularly +/// bad when the filters above make the merge keys constant, turning every comparison into a +/// full-length tie. 
+/// +/// Restricted to aggregates without group expressions: those hold a single accumulator set per +/// partition even when later optimizations make them run per partition. A grouped hash +/// aggregate over unmerged partitions could end up with a hash table per partition, +/// multiplying its memory by the partition count. +pub fn drop_sort_merge_under_global_aggregate( + p: Arc, +) -> Result, DataFusionError> { + let Some(agg) = p.as_any().downcast_ref::() else { + return Ok(p); + }; + if !matches!(agg.input_order_mode(), InputOrderMode::Linear) + || !agg.group_expr().expr().is_empty() + { + return Ok(p); + } + // Order-sensitive aggregates (first_value, array_agg(ORDER BY), ...) stay Linear with an + // empty GROUP BY but still need their input ordered + if agg.required_input_ordering()[0].is_some() { + return Ok(p); + } + let new_input = replace_merges_with_coalesce(agg.input())?; + if Arc::ptr_eq(&new_input, agg.input()) { + return Ok(p); + } + p.with_new_children(vec![new_input]) +} + +/// Replaces sort preserving merges with plain partition coalescing, looking through the nodes +/// that don't require an input ordering of their own. 
+fn replace_merges_with_coalesce( + p: &Arc, +) -> Result, DataFusionError> { + let p_any = p.as_any(); + if let Some(merge) = p_any.downcast_ref::() { + if merge.fetch().is_some() { + return Ok(p.clone()); + } + let child = replace_merges_with_coalesce(merge.input())?; + return Ok(Arc::new(CoalescePartitionsExec::new(child))); + } + if p_any.is::() || p_any.is::() || p_any.is::() { + let new_children = p + .children() + .into_iter() + .map(replace_merges_with_coalesce) + .collect::, _>>()?; + if p.children() + .into_iter() + .zip(new_children.iter()) + .all(|(old, new)| Arc::ptr_eq(old, new)) + { + return Ok(p.clone()); + } + return p.clone().with_new_children(new_children); + } + Ok(p.clone()) +} + +/// Transforms from: +/// AggregatePartial, Sorted +/// `- SortPreservingMerge +/// `- source(N partitions) +/// to: +/// SortPreservingMerge +/// `- AggregatePartial, Sorted (executed per partition) +/// `- source(N partitions) +/// +/// The merge then carries one row per group per partition instead of all raw rows. Duplicate +/// group keys from different partitions are adjacent in the merged stream and get combined by +/// the Final aggregate, the same way partial states from different workers are. +/// +/// Only sorted (streaming) partial aggregates are pushed: they hold O(1) accumulators per +/// partition, while a hash aggregate holds O(num_groups) and would multiply its memory usage by +/// the partition count. 
+pub fn push_sorted_partial_aggregate_below_merge( + p: Arc, +) -> Result, DataFusionError> { + let Some(agg) = p.as_any().downcast_ref::() else { + return Ok(p); + }; + if *agg.mode() != AggregateMode::Partial + || !matches!(agg.input_order_mode(), InputOrderMode::Sorted) + // Restrict to aggregates convertible to InlineAggregateExec (no grouping sets) + || !agg.group_expr().is_single() + { + return Ok(p); + } + let Some(merge) = agg + .input() + .as_any() + .downcast_ref::() + else { + return Ok(p); + }; + if merge.fetch().is_some() { + return Ok(p); + } + let merge_input = merge.input(); + if merge_input.output_partitioning().partition_count() <= 1 { + return Ok(p); + } + + let new_agg = AggregateExec::try_new( + AggregateMode::Partial, + agg.group_expr().clone(), + agg.aggr_expr().to_vec(), + agg.filter_expr().to_vec(), + merge_input.clone(), + agg.input_schema(), + )?; + // Per-partition input must still be sorted on the group keys, otherwise the aggregate + // becomes hash-based and must stay above the merge. + if !matches!(new_agg.input_order_mode(), InputOrderMode::Sorted) { + return Ok(p); + } + let Some(ordering) = new_agg.properties().output_ordering().cloned() else { + return Ok(p); + }; + Ok(Arc::new(SortPreservingMergeExec::new( + ordering, + Arc::new(new_agg), + ))) +} + pub fn ensure_partition_merge_helper( p: Arc, new_child: &mut bool, @@ -224,6 +359,29 @@ pub fn add_limit_to_workers( let Some((limit, reverse)) = limit_and_reverse else { return Ok(p); }; + + // A row limit must not be placed above a merge of per-partition aggregates: the merged + // stream carries duplicate group keys from different partitions, and cutting it by rows + // could cut off part of some group's partial states, silently corrupting that group's + // total. 
Instead the limit descends through the merges and lands directly above the first + // aggregate, where one row is one complete (within its partition) group: the union of + // per-partition first (last, for reverse) `limit` groups covers the global first (last) + // `limit` groups, and groups beyond that arrive complete or not at all -- the router + // orders by the group key, not by the totals, and its own limit drops them. + if first_aggregate_below_merges(input).is_some() { + let new_input = limit_above_first_aggregate(input, limit, reverse)?; + return p.with_new_children(vec![new_input]); + } + + // Tripwire for shapes the descent doesn't recognize: a per-partition partial aggregate in + // the subtree means a duplicate-bearing merge somewhere above it, and the row limit below + // (LimitPushdown descends through projections into merges) could land on it. The worker + // limit is an optimization, skipping it is always correct. + if contains_multi_partition_partial_aggregate(input) { + return Ok(p); + } + + // No aggregation: plain rows, each one self-contained, a row limit is exact anywhere. if reverse { let limit = Arc::new(TailLimitExec::new(input.clone(), limit)); p.with_new_children(vec![limit]) @@ -234,6 +392,84 @@ pub fn add_limit_to_workers( } } +/// The first aggregate reachable from `p` looking only through sort preserving merges and +/// projections. +fn first_aggregate_below_merges(mut p: &Arc) -> Option<&Arc> { + loop { + if let Some(merge) = p.as_any().downcast_ref::() { + p = merge.input(); + } else if let Some(projection) = p.as_any().downcast_ref::() { + p = projection.input(); + } else if p.as_any().is::() || p.as_any().is::() { + return Some(p); + } else { + return None; + } + } +} + +/// Rebuilds the chain of merges and projections with a per-partition row limit placed directly +/// above the first aggregate. Must only be called when [first_aggregate_below_merges] found one. 
+fn limit_above_first_aggregate( + p: &Arc, + limit: usize, + reverse: bool, +) -> Result, DataFusionError> { + if let Some(merge) = p.as_any().downcast_ref::() { + let child = limit_above_first_aggregate(merge.input(), limit, reverse)?; + return Ok(Arc::new( + SortPreservingMergeExec::new(merge.expr().clone(), child).with_fetch(merge.fetch()), + )); + } + if let Some(projection) = p.as_any().downcast_ref::() { + // The limit is a plain row count, it commutes with column renames + let child = limit_above_first_aggregate(projection.input(), limit, reverse)?; + return Ok(Arc::new(ProjectionExec::try_new( + projection.expr().to_vec(), + child, + )?)); + } + + // The sorted streaming aggregate can additionally stop reading its input early once it has + // emitted `limit` groups; for reverse the last groups are unknown until the input ends, so + // there is nothing to pass into the aggregate. + let mut node = p.clone(); + if !reverse { + if let Some(agg) = p.as_any().downcast_ref::() { + if *agg.mode() == InlineAggregateMode::Partial { + let agg_limit = agg.limit().map_or(limit, |l| l.min(limit)); + node = Arc::new(agg.with_limit(Some(agg_limit))); + } + } + } + + Ok(if reverse { + Arc::new(TailLimitExec::new(node, limit)) + } else if node.output_partitioning().partition_count() == 1 { + Arc::new(GlobalLimitExec::new(node, 0, Some(limit))) + } else { + Arc::new(LocalLimitExec::new(node, limit)) + }) +} + +/// True if the plan contains a partial aggregate executed per partition: its merged output +/// carries duplicate group keys, so a row limit above such a merge is not group-aligned. 
+fn contains_multi_partition_partial_aggregate(p: &Arc) -> bool { + let is_one = if let Some(agg) = p.as_any().downcast_ref::() { + *agg.mode() == InlineAggregateMode::Partial + } else if let Some(agg) = p.as_any().downcast_ref::() { + *agg.mode() == AggregateMode::Partial + } else { + false + }; + if is_one && p.output_partitioning().partition_count() > 1 { + return true; + } + p.children() + .into_iter() + .any(contains_multi_partition_partial_aggregate) +} + /// Because we disable `EnforceDistribution`, and because we add `SortPreservingMergeExec` in /// `ensure_partition_merge_with_acceptable_parent` so that Sorted ("inplace") aggregates work /// properly (which reduces memory usage), we in some cases have unnecessary @@ -435,3 +671,657 @@ fn replace_columns( .data, ) } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::{Int64Array, RecordBatch}; + use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion::functions_aggregate::sum::sum_udaf; + use datafusion::physical_expr::aggregate::AggregateExprBuilder; + use datafusion::physical_expr::expressions::col; + use datafusion::physical_expr::PhysicalSortExpr; + use datafusion::physical_plan::aggregates::PhysicalGroupBy; + use datafusion::physical_plan::collect; + use datafusion::prelude::SessionContext; + use datafusion_datasource::memory::MemorySourceConfig; + use datafusion_datasource::source::DataSourceExec; + use std::collections::BTreeMap; + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])) + } + + fn make_batch(schema: &SchemaRef, rows: &[(i64, i64)]) -> RecordBatch { + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from_iter_values(rows.iter().map(|r| r.0))), + Arc::new(Int64Array::from_iter_values(rows.iter().map(|r| r.1))), + ], + ) + .unwrap() + } + + /// Memory source with each partition sorted by `k`. 
+ fn sorted_source( + schema: &SchemaRef, + partitions: Vec>, + ) -> Arc { + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", schema).unwrap(), + )]); + let source = MemorySourceConfig::try_new(&partitions, schema.clone(), None) + .unwrap() + .try_with_sort_information(vec![ordering]) + .unwrap(); + Arc::new(DataSourceExec::new(Arc::new(source))) + } + + fn merge_by_k(input: Arc) -> Arc { + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", &input.schema()).unwrap(), + )]); + Arc::new(SortPreservingMergeExec::new(ordering, input)) + } + + fn sum_aggregate( + mode: AggregateMode, + group_col: &str, + input: Arc, + ) -> Arc { + let schema = input.schema(); + let group_by = PhysicalGroupBy::new_single(vec![( + col(group_col, &schema).unwrap(), + group_col.to_string(), + )]); + let sum = AggregateExprBuilder::new(sum_udaf(), vec![col("v", &schema).unwrap()]) + .schema(schema.clone()) + .alias("sum_v") + .build() + .unwrap(); + Arc::new( + AggregateExec::try_new( + mode, + group_by, + vec![Arc::new(sum)], + vec![None], + input, + schema, + ) + .unwrap(), + ) + } + + /// Collects plan output into per-key sums, combining duplicate keys the way a Final + /// aggregate would. 
+ async fn collect_summed(plan: Arc) -> BTreeMap { + let session = SessionContext::new(); + let batches = collect(plan, session.task_ctx()).await.unwrap(); + let mut result = BTreeMap::new(); + for batch in batches { + let keys = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let sums = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + for (k, s) in keys.iter().zip(sums.iter()) { + *result.entry(k.unwrap()).or_insert(0) += s.unwrap(); + } + } + result + } + + fn two_partition_source(schema: &SchemaRef) -> Arc { + // Duplicate keys 2 and 3 across partitions + sorted_source( + schema, + vec![ + vec![make_batch(schema, &[(1, 10), (2, 20), (3, 30)])], + vec![make_batch(schema, &[(2, 21), (3, 31), (4, 40)])], + ], + ) + } + + #[tokio::test(flavor = "multi_thread")] + async fn pushes_sorted_partial_aggregate_below_merge() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let original = sum_aggregate(AggregateMode::Partial, "k", merge_by_k(source)); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + + let merge = rewritten + .as_any() + .downcast_ref::() + .expect("merge must become the root"); + let agg = merge + .input() + .as_any() + .downcast_ref::() + .expect("partial aggregate must move below the merge"); + assert_eq!(*agg.mode(), AggregateMode::Partial); + assert!(matches!(agg.input_order_mode(), InputOrderMode::Sorted)); + assert_eq!( + agg.properties().output_partitioning().partition_count(), + 2, + "aggregate must run per partition" + ); + assert_eq!(rewritten.schema(), original.schema()); + + // Cross-partition duplicate keys combine to the same totals as in the original plan + assert_eq!( + collect_summed(rewritten).await, + collect_summed(original).await + ); + } + + #[test] + fn does_not_push_hash_aggregate() { + let schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int64, false), + Field::new("g", DataType::Int64, false), + Field::new("v", 
DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(Int64Array::from(vec![5, 4])), + Arc::new(Int64Array::from(vec![10, 20])), + ], + ) + .unwrap(); + let source = sorted_source(&schema, vec![vec![batch.clone()], vec![batch]]); + // Grouping by `g` while the input is sorted by `k` makes the aggregate hash-based + let original = sum_aggregate(AggregateMode::Partial, "g", merge_by_k(source)); + assert!(matches!( + original + .as_any() + .downcast_ref::() + .unwrap() + .input_order_mode(), + InputOrderMode::Linear + )); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + #[test] + fn does_not_push_non_partial_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let original = sum_aggregate(AggregateMode::Single, "k", merge_by_k(source)); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + #[test] + fn does_not_push_below_merge_with_fetch() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let merge = Arc::new(merge_by_k(source).as_ref().clone().with_fetch(Some(3))); + let original = sum_aggregate(AggregateMode::Partial, "k", merge); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + #[test] + fn does_not_push_below_single_partition_merge() { + let schema = test_schema(); + let source = sorted_source( + &schema, + vec![vec![make_batch(&schema, &[(1, 10), (2, 20)])]], + ); + let original = sum_aggregate(AggregateMode::Partial, "k", merge_by_k(source)); + + let rewritten = push_sorted_partial_aggregate_below_merge(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + fn inline(plan: Arc) -> Arc { + let agg 
= plan.as_any().downcast_ref::().unwrap(); + Arc::new(InlineAggregateExec::try_new_from_aggregate(agg).unwrap()) + } + + fn worker( + input: Arc, + limit_and_reverse: Option<(usize, bool)>, + ) -> Arc { + Arc::new(WorkerExec::new( + input, + 4096, + limit_and_reverse, + None, + WorkerPlanningParams { + worker_partition_count: 1, + }, + )) + } + + /// Worker plan with the partial aggregate below the merge: merge of per-partition partial + /// states can contain duplicate group keys, so the row limit must land on each partition + /// below the merge, while the aggregates take the same value as their group limit. + #[test] + fn worker_limit_above_merged_partial_aggregate_limits_groups_per_partition() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let merged = push_sorted_partial_aggregate_below_merge_shape(agg); + let p = worker(merged, Some((3, false))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker = rewritten.as_any().downcast_ref::().unwrap(); + let merge = worker + .input + .as_any() + .downcast_ref::() + .expect("merge must stay on top of per-partition aggregates"); + assert_eq!(merge.fetch(), None); + let local_limit = merge + .input() + .as_any() + .downcast_ref::() + .expect("the limit must become a per-partition row limit above the aggregate"); + assert_eq!(local_limit.fetch(), 3); + let agg = local_limit + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(agg.limit(), Some(3)); + } + + /// The limit descends below the merge even for an aggregate that is not an + /// InlineAggregateExec: only the early-stop absorption is specific to it, the row limit + /// placement is not.
+ #[test] + fn worker_limit_above_merged_raw_partial_aggregate_stays_below_merge() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let agg = sum_aggregate(AggregateMode::Partial, "k", source); + let merged = merge_by_k(agg); + let p = worker(merged, Some((3, false))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker_node = rewritten.as_any().downcast_ref::().unwrap(); + let merge = worker_node + .input + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(merge.fetch(), None); + let local_limit = merge + .input() + .as_any() + .downcast_ref::() + .expect("a row limit above the merge would truncate duplicate group keys"); + assert_eq!(local_limit.fetch(), 3); + assert!(local_limit.input().as_any().is::()); + + let merged = merge_by_k(sum_aggregate( + AggregateMode::Partial, + "k", + two_partition_source(&schema), + )); + let p = worker(merged, Some((3, true))); + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + let worker_node = rewritten.as_any().downcast_ref::().unwrap(); + let merge = worker_node + .input + .as_any() + .downcast_ref::() + .unwrap(); + let tail = merge + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(tail.limit, 3); + assert!(tail.input.as_any().is::()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn worker_reverse_limit_above_merged_partial_aggregate_tails_each_partition() { + let schema = test_schema(); + // Keys 4..=6 are present in both partitions; per-partition tails of 3 groups are + // {4, 5, 6} and {7, 8, 9}. 
+ let source = sorted_source( + &schema, + vec![ + vec![make_batch( + &schema, + &(1..=6).map(|k| (k, k * 10)).collect::>(), + )], + vec![make_batch( + &schema, + &(4..=9).map(|k| (k, k * 100 + 1)).collect::>(), + )], + ], + ); + let baseline = collect_summed(sum_aggregate( + AggregateMode::Partial, + "k", + merge_by_k(source.clone()), + )) + .await; + + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let merged = push_sorted_partial_aggregate_below_merge_shape(agg); + let p = worker(merged, Some((3, true))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker = rewritten.as_any().downcast_ref::().unwrap(); + let merge = worker + .input + .as_any() + .downcast_ref::() + .expect("merge must stay on top of per-partition tails"); + assert_eq!(merge.fetch(), None); + let tail = merge + .input() + .as_any() + .downcast_ref::() + .expect("reverse limit must become a per-partition tail below the merge"); + assert_eq!(tail.limit, 3); + let agg = tail + .input + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + agg.limit(), + None, + "tail limit can not stop the aggregate early" + ); + + // The merged stream must carry complete totals for the last 3 group keys; earlier keys + // may arrive partial and are dropped by the router's own limit. + let summed = collect_summed(worker.input.clone()).await; + for key in [7, 8, 9] { + assert_eq!(summed[&key], baseline[&key], "complete total for key {key}"); + } + } + + /// A global aggregate doesn't use the input order: the merge below it becomes a plain + /// coalesce, through the filter, with identical results. 
+ #[tokio::test(flavor = "multi_thread")] + async fn drops_sort_merge_under_global_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let filter: Arc = Arc::new( + FilterExec::try_new( + Arc::new(datafusion::physical_plan::expressions::BinaryExpr::new( + col("k", &schema).unwrap(), + datafusion::logical_expr::Operator::Gt, + Arc::new(datafusion::physical_plan::expressions::Literal::new( + datafusion::scalar::ScalarValue::Int64(Some(1)), + )), + )), + merge_by_k(source), + ) + .unwrap(), + ); + let original = global_sum_aggregate(filter); + + let rewritten = drop_sort_merge_under_global_aggregate(original.clone()).unwrap(); + + let agg = rewritten.as_any().downcast_ref::().unwrap(); + let filter = agg + .input() + .as_any() + .downcast_ref::() + .expect("filter must stay in place"); + assert!( + filter.input().as_any().is::(), + "the merge must become a plain coalesce" + ); + + assert_eq!( + collect_global_sum(rewritten).await, + collect_global_sum(original).await + ); + } + + /// An order-sensitive aggregate stays Linear with an empty GROUP BY but still needs its + /// input ordered, so its merge stays. 
+ #[test] + fn keeps_sort_merge_under_order_sensitive_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let merged = merge_by_k(source); + let first = AggregateExprBuilder::new( + datafusion::functions_aggregate::array_agg::array_agg_udaf(), + vec![col("v", &schema).unwrap()], + ) + .schema(schema.clone()) + .alias("vals") + .order_by(LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", &schema).unwrap(), + )])) + .build() + .unwrap(); + let original: Arc = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(Vec::new()), + vec![Arc::new(first)], + vec![None], + merged, + schema, + ) + .unwrap(), + ); + assert!(original + .as_any() + .downcast_ref::() + .unwrap() + .required_input_ordering()[0] + .is_some()); + + let rewritten = drop_sort_merge_under_global_aggregate(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + /// A grouped hash aggregate over unmerged partitions could build a hash table per + /// partition, so its merge stays. 
+ #[test] + fn keeps_sort_merge_under_grouped_hash_aggregate() { + let schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int64, false), + Field::new("g", DataType::Int64, false), + Field::new("v", DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(Int64Array::from(vec![5, 4])), + Arc::new(Int64Array::from(vec![10, 20])), + ], + ) + .unwrap(); + let source = sorted_source(&schema, vec![vec![batch.clone()], vec![batch]]); + let original = sum_aggregate(AggregateMode::Partial, "g", merge_by_k(source)); + assert!(matches!( + original + .as_any() + .downcast_ref::() + .unwrap() + .input_order_mode(), + InputOrderMode::Linear + )); + + let rewritten = drop_sort_merge_under_global_aggregate(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + fn global_sum_aggregate(input: Arc) -> Arc { + let schema = input.schema(); + let sum = AggregateExprBuilder::new(sum_udaf(), vec![col("v", &schema).unwrap()]) + .schema(schema.clone()) + .alias("sum_v") + .build() + .unwrap(); + Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(Vec::new()), + vec![Arc::new(sum)], + vec![None], + input, + schema, + ) + .unwrap(), + ) + } + + /// Sums the partial states of a global aggregate across partitions. 
+ async fn collect_global_sum(plan: Arc) -> i64 { + let session = SessionContext::new(); + let batches = collect(plan, session.task_ctx()).await.unwrap(); + batches + .iter() + .flat_map(|b| { + b.column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .flatten() + .collect::>() + }) + .sum() + } + + #[test] + fn keeps_sort_merge_under_sorted_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let original = sum_aggregate(AggregateMode::Partial, "k", merge_by_k(source)); + + let rewritten = drop_sort_merge_under_global_aggregate(original.clone()).unwrap(); + assert!(Arc::ptr_eq(&rewritten, &original)); + } + + /// The descent looks through projections: the row limit is a plain count and commutes with + /// column renames. + #[test] + fn worker_limit_descends_through_projection() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let merged = push_sorted_partial_aggregate_below_merge_shape(agg); + let projection = Arc::new( + ProjectionExec::try_new( + merged + .schema() + .fields() + .iter() + .map(|f| (col(f.name(), &merged.schema()).unwrap(), f.name().clone())) + .collect(), + merged, + ) + .unwrap(), + ); + let p = worker(projection, Some((3, false))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker_node = rewritten.as_any().downcast_ref::().unwrap(); + let projection = worker_node + .input + .as_any() + .downcast_ref::() + .expect("projection must stay on top"); + let merge = projection + .input() + .as_any() + .downcast_ref::() + .unwrap(); + let local_limit = merge + .input() + .as_any() + .downcast_ref::() + .expect("the limit must descend through the projection to the aggregate"); + assert_eq!(local_limit.fetch(), 3); + let agg = local_limit + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(agg.limit(), Some(3)); + } + + /// A shape the 
descent doesn't recognize but containing per-partition partial aggregates + /// must not get a worker limit at all: a row limit could land above a duplicate-bearing + /// merge. + #[test] + fn worker_limit_skipped_for_unrecognized_shape_with_per_partition_aggregate() { + let schema = test_schema(); + let source = two_partition_source(&schema); + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let coalesce: Arc = Arc::new(CoalescePartitionsExec::new(agg)); + let p = worker(coalesce, Some((3, false))); + + let rewritten = add_limit_to_workers(p.clone(), &ConfigOptions::default()).unwrap(); + assert!( + Arc::ptr_eq(&rewritten, &p), + "no limit must be added to an unrecognized shape with a per-partition aggregate" + ); + } + + /// Single partition partial aggregate emits unique group keys, the row limit stays exact + /// and also lets the aggregate stop early. + #[test] + fn worker_limit_above_single_partition_partial_aggregate_sets_aggregate_limit() { + let schema = test_schema(); + let source = sorted_source( + &schema, + vec![vec![make_batch(&schema, &[(1, 10), (2, 20)])]], + ); + let agg = inline(sum_aggregate(AggregateMode::Partial, "k", source)); + let p = worker(agg, Some((3, false))); + + let rewritten = add_limit_to_workers(p, &ConfigOptions::default()).unwrap(); + + let worker = rewritten.as_any().downcast_ref::().unwrap(); + let limit = worker + .input + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(limit.fetch(), Some(3)); + let agg = limit + .input() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(agg.limit(), Some(3)); + } + + /// Builds the post-rewrite shape merge-above-aggregate for an already converted + /// InlineAggregateExec. 
+ fn push_sorted_partial_aggregate_below_merge_shape( + agg: Arc, + ) -> Arc { + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("k", &agg.schema()).unwrap(), + )]); + Arc::new(SortPreservingMergeExec::new(ordering, agg)) + } +} diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs index e4ee5eb698b3c..ec642715545b5 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs @@ -9,7 +9,8 @@ mod trace_data_loaded; use super::serialized_plan::PreSerializedPlan; use crate::cluster::{Cluster, WorkerPlanningParams}; use crate::queryplanner::optimizations::distributed_partial_aggregate::{ - add_limit_to_workers, ensure_partition_merge, push_aggregate_to_workers, + add_limit_to_workers, drop_sort_merge_under_global_aggregate, ensure_partition_merge, + push_aggregate_to_workers, push_sorted_partial_aggregate_below_merge, replace_suboptimal_merge_sorts, }; use crate::queryplanner::optimizations::inline_aggregate_rewriter::replace_with_inline_aggregate; @@ -145,6 +146,12 @@ fn pre_optimize_physical_plan( // Handles the root node case let p = ensure_partition_merge(p)?; + // Make the merge carry partial aggregate states instead of all raw rows + let p = rewrite_physical_plan(p, &mut |p| push_sorted_partial_aggregate_below_merge(p))?; + + // Global (no GROUP BY) aggregates don't need their input merged in the sort order + let p = rewrite_physical_plan(p, &mut |p| drop_sort_merge_under_global_aggregate(p))?; + // Replace sorted AggregateExec with InlineAggregateExec for better performance let p = rewrite_physical_plan(p, &mut |p| replace_with_inline_aggregate(p))?; diff --git a/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs b/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs index 4f64a28a45d83..dfe940a101510 100644 --- 
a/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs +++ b/rust/cubestore/cubestore/src/queryplanner/tail_limit.rs @@ -6,21 +6,21 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::cube_ext; use datafusion::error::DataFusionError; use datafusion::execution::TaskContext; -use datafusion::physical_plan::common::collect; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; use futures::stream::Stream; -use futures::Future; +use futures::{Future, StreamExt}; use pin_project_lite::pin_project; use std::any::Any; +use std::collections::VecDeque; use std::fmt::Formatter; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -///Return n last rows in input +/// Returns the last `limit` rows of each input partition. #[derive(Debug)] pub struct TailLimitExec { pub input: Arc, @@ -58,6 +58,10 @@ impl ExecutionPlan for TailLimitExec { vec![&self.input] } + fn maintains_input_order(&self) -> Vec { + vec![true] + } + fn with_new_children( self: Arc, children: Vec>, @@ -74,19 +78,6 @@ impl ExecutionPlan for TailLimitExec { partition: usize, context: Arc, ) -> Result { - if 0 != partition { - return Err(DataFusionError::Internal(format!( - "TailLimitExec invalid partition {}", - partition - ))); - } - - if 1 != self.input.properties().partitioning.partition_count() { - return Err(DataFusionError::Internal( - "TailLimitExec requires a single input partition".to_owned(), - )); - } - let input = self.input.execute(partition, context)?; Ok(Box::pin(TailLimitStream::new(input, self.limit))) } @@ -109,7 +100,7 @@ impl TailLimitStream { let schema = input.schema(); let task = async move { let schema = input.schema(); - let data = collect(input).await?; + let data = collect_tail_window(input, n).await?; batches_tail(data, n, schema.clone()) }; cube_ext::spawn_oneshot_with_catch_unwind(task, tx); @@ -123,6 +114,40 @@ impl TailLimitStream { } } +/// Collects a 
sliding tail window of the input: keeps only the trailing batches needed to cover +/// `limit` rows, newer rows displace older ones. Every stored batch is cut to at most `limit` +/// rows on arrival and the eviction keeps the minimal covering suffix, so the window never holds +/// more than 2 * `limit` rows. The front batch may overshoot the window, it is sliced later by +/// [batches_tail]. +async fn collect_tail_window( + mut input: SendableRecordBatchStream, + limit: usize, +) -> Result, DataFusionError> { + let mut window = VecDeque::new(); + let mut total_rows = 0; + while let Some(batch) = input.next().await { + let batch = batch?; + let rows = batch.num_rows(); + if rows >= limit { + // The batch alone covers the whole window + window.clear(); + total_rows = limit; + window.push_back(if rows > limit { + skip_first_rows(&batch, rows - limit) + } else { + batch + }); + continue; + } + total_rows += rows; + window.push_back(batch); + while window.len() > 1 && total_rows - window.front().unwrap().num_rows() >= limit { + total_rows -= window.pop_front().unwrap().num_rows(); + } + } + Ok(window.into()) +} + fn batches_tail( mut batches: Vec, limit: usize, @@ -293,6 +318,86 @@ mod tests { assert!(to_ints(r).into_iter().flatten().collect_vec().is_empty()); } + #[tokio::test] + async fn batches_larger_than_limit() { + // 20-row batch followed by a 3-row batch, limit 5: last 2 rows of the big batch + 3 + let big: Vec = (0..20).collect(); + let input = vec![ints(big), ints(vec![100, 101, 102])]; + let inp = try_make_memory_data_source(&vec![input], ints_schema(), None).unwrap(); + let r = result_collect( + Arc::new(TailLimitExec::new(inp, 5)), + Arc::new(TaskContext::default()), + ) + .await + .unwrap(); + assert_eq!( + to_ints(r).into_iter().flatten().collect_vec(), + vec![18, 19, 100, 101, 102], + ); + } + + /// The window must stay bounded by the limit, not by the largest input batch. 
+ #[tokio::test] + async fn window_stays_bounded() { + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + + // A large batch first, then a trickle of small ones that never covers the limit + let mut batches = vec![ints((0..1000).collect())]; + batches.extend((0..10).map(|i| ints(vec![i]))); + + let stream = Box::pin(RecordBatchStreamAdapter::new( + ints_schema(), + futures::stream::iter(batches.into_iter().map(Ok)), + )); + let window = collect_tail_window(stream, 20).await.unwrap(); + + let window_rows: usize = window.iter().map(|b| b.num_rows()).sum(); + assert!( + window_rows <= 2 * 20, + "window holds {} rows, must stay within 2 * limit", + window_rows + ); + let result = batches_tail(window, 20, ints_schema()).unwrap(); + let last_20: Vec = (990..1000).chain(0..10).collect(); + assert_eq!( + to_ints(vec![result]).into_iter().flatten().collect_vec(), + last_20 + ); + } + + #[tokio::test] + async fn empty_partition() { + let partitions = vec![vec![], vec![ints(vec![1, 2])]]; + let inp = try_make_memory_data_source(&partitions, ints_schema(), None).unwrap(); + let r = result_collect( + Arc::new(TailLimitExec::new(inp, 2)), + Arc::new(TaskContext::default()), + ) + .await + .unwrap(); + assert_eq!(to_ints(r).into_iter().flatten().collect_vec(), vec![1, 2]); + } + + #[tokio::test] + async fn multiple_partitions() { + let partitions = vec![ + vec![ints(vec![1, 2, 3]), ints(vec![4, 5])], + vec![ints(vec![10, 20])], + ]; + let inp = try_make_memory_data_source(&partitions, ints_schema(), None).unwrap(); + let r = result_collect( + Arc::new(TailLimitExec::new(inp, 2)), + Arc::new(TaskContext::default()), + ) + .await + .unwrap(); + // The last 2 rows of each partition + assert_eq!( + to_ints(r).into_iter().flatten().collect_vec(), + vec![4, 5, 10, 20], + ); + } + #[tokio::test] async fn several_batches() { let input = vec![ diff --git a/rust/cubestore/cubestore/src/sql/mod.rs b/rust/cubestore/cubestore/src/sql/mod.rs index 2ab0f11032b9b..50f20eacd4ce6 
100644 --- a/rust/cubestore/cubestore/src/sql/mod.rs +++ b/rust/cubestore/cubestore/src/sql/mod.rs @@ -3619,14 +3619,13 @@ mod tests { \n CoalescePartitions\ \n LinearPartialAggregate\ \n Filter\ - \n MergeSort\ - \n Scan, index: default:1:[1]:sort_on[num], fields: *\ - \n FilterByKeyRange\ - \n CheckMemoryExec\ - \n ParquetScan\ - \n FilterByKeyRange\ - \n CheckMemoryExec\ - \n ParquetScan"; + \n Scan, index: default:1:[1]:sort_on[num], fields: *\ + \n FilterByKeyRange\ + \n CheckMemoryExec\ + \n ParquetScan\ + \n FilterByKeyRange\ + \n CheckMemoryExec\ + \n ParquetScan"; let plan = pp_phys_plan_ext(plans.worker.as_ref(), &opts); let p = plan_regexp.replace_all(&plan, "ParquetScan"); println!("pp {}", p);