diff --git a/rust/cubestore/Cargo.lock b/rust/cubestore/Cargo.lock index ab224b78e4006..9990984372e29 100644 --- a/rust/cubestore/Cargo.lock +++ b/rust/cubestore/Cargo.lock @@ -1758,7 +1758,7 @@ checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" [[package]] name = "datafusion" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "arrow-ipc", @@ -1811,7 +1811,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "async-trait", @@ -1830,7 +1830,7 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "async-trait", @@ -1851,7 +1851,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -1874,7 +1874,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" version = "46.0.1" -source = 
"git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "log", "tokio", @@ -1883,7 +1883,7 @@ dependencies = [ [[package]] name = "datafusion-datasource" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "async-compression 0.4.17", @@ -1916,12 +1916,12 @@ dependencies = [ [[package]] name = "datafusion-doc" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" [[package]] name = "datafusion-execution" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "dashmap", @@ -1941,7 +1941,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "chrono", @@ -1961,7 +1961,7 @@ dependencies = [ [[package]] name = "datafusion-expr-common" version = "46.0.1" -source = 
"git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "datafusion-common", @@ -1973,7 +1973,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "arrow-buffer", @@ -2001,7 +2001,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2021,7 +2021,7 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2033,7 +2033,7 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "arrow-ord", @@ -2053,7 +2053,7 @@ dependencies = [ [[package]] name = 
"datafusion-functions-table" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "async-trait", @@ -2068,7 +2068,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "datafusion-common", "datafusion-doc", @@ -2084,7 +2084,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2093,7 +2093,7 @@ dependencies = [ [[package]] name = "datafusion-macros" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "datafusion-expr", "quote", @@ -2103,7 +2103,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ 
"arrow", "chrono", @@ -2121,7 +2121,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2142,7 +2142,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2155,7 +2155,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "datafusion-common", @@ -2173,7 +2173,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "ahash 0.8.11", "arrow", @@ -2205,7 +2205,7 @@ dependencies = [ [[package]] name = "datafusion-proto" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = 
"git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "chrono", @@ -2220,7 +2220,7 @@ dependencies = [ [[package]] name = "datafusion-proto-common" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "datafusion-common", @@ -2230,7 +2230,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "46.0.1" -source = "git+https://github.com/cube-js/arrow-datafusion?branch=cube-46.0.1#dc9015e290adbeaff1da80c9c052219c50312f77" +source = "git+https://github.com/cube-js/arrow-datafusion?branch=cubestore-hash-aggregate-limit#ef839b67f88734804bb2127c7e27b25122b55690" dependencies = [ "arrow", "bigdecimal 0.4.8", diff --git a/rust/cubestore/cubestore-sql-tests/src/tests.rs b/rust/cubestore/cubestore-sql-tests/src/tests.rs index aca2bc52bb58a..138e3a6b1cf1a 100644 --- a/rust/cubestore/cubestore-sql-tests/src/tests.rs +++ b/rust/cubestore/cubestore-sql-tests/src/tests.rs @@ -150,6 +150,8 @@ pub fn sql_tests(prefix: &str) -> Vec<(&'static str, TestFn)> { t("planning_inplace_aggregate", planning_inplace_aggregate), t("planning_hints", planning_hints), t("planning_inplace_aggregate2", planning_inplace_aggregate2), + t("planning_topk_hash_aggregate", planning_topk_hash_aggregate), + t("topk_hash_aggregate_trim", topk_hash_aggregate_trim), t("topk_large_inputs", topk_large_inputs), t("partitioned_index", partitioned_index), t( @@ -386,6 +388,7 @@ lazy_static::lazy_static! 
{ "create_table_with_csv_no_header_and_delimiter", "create_table_with_csv_no_header_and_quotes", "filter_pushdown_unique_key", + "planning_topk_hash_aggregate", ].into_iter().map(ToOwned::to_owned).collect(); } @@ -3162,6 +3165,146 @@ async fn planning_inplace_aggregate(service: Box) -> Result<(), C Ok(()) } +async fn planning_topk_hash_aggregate(service: Box) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Data(url text, day int, hits int)") + .await?; + service + .exec_query("CREATE TABLE s.D3(a int, b int, c int, h int)") + .await?; + + // GROUP BY a non-indexed column -> hash (Linear) partial aggregate; ORDER BY the group + // column with a LIMIT -> the worker partial aggregate is replaced by TopKHashAggregate. + let p = service + .plan_query("SELECT day, SUM(hits) FROM s.Data GROUP BY 1 ORDER BY 1 LIMIT 10") + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + pp.contains("TopKHashAggregate, k: 10, factor: 2,"), + "expected TopKHashAggregate on the worker, got:\n{}", + pp + ); + + // LIMIT + OFFSET -> k = limit + offset. + let p = service + .plan_query("SELECT day, SUM(hits) FROM s.Data GROUP BY 1 ORDER BY 1 LIMIT 10 OFFSET 5") + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + pp.contains("TopKHashAggregate, k: 15, factor: 2,"), + "expected k=15 (limit+offset), got:\n{}", + pp + ); + + // ORDER BY an aggregate (not a group-by column) -> no trim. + let p = service + .plan_query("SELECT day, SUM(hits) FROM s.Data GROUP BY 1 ORDER BY 2 DESC LIMIT 10") + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + !pp.contains("TopKHashAggregate"), + "did not expect TopKHashAggregate when ordering by an aggregate, got:\n{}", + pp + ); + + // No LIMIT -> no trim. 
+ let p = service + .plan_query("SELECT day, SUM(hits) FROM s.Data GROUP BY 1 ORDER BY 1") + .await?; + let pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + !pp.contains("TopKHashAggregate"), + "did not expect TopKHashAggregate without a limit, got:\n{}", + pp + ); + + // ORDER BY a proper SUBSET of GROUP BY (b out of b, c). The worker cut and the router sort must + // both use the total order T = [b, c]: the worker trim order carries the tie-break column c, and + // the router's global Sort is extended with c so its top-k matches the global top-k by T. + let p = service + .plan_query("SELECT b, c, SUM(h) FROM s.D3 GROUP BY 1, 2 ORDER BY 1 LIMIT 3") + .await?; + let worker_pp = pp_phys_plan_ext(p.worker.as_ref(), &PPOptions::none()); + assert!( + worker_pp.contains("TopKHashAggregate, k: 3, factor: 2,") + && worker_pp.contains("(0, SortOptions { descending: false, nulls_first: false })") + && worker_pp.contains("(1, SortOptions { descending: false, nulls_first: true })"), + "expected worker trim order [b, c] totalized, got:\n{}", + worker_pp + ); + let router_pp = pp_phys_plan_ext( + p.router.as_ref(), + &PPOptions { + show_sort_by: true, + ..PPOptions::none() + }, + ); + assert!( + router_pp.contains("b@0") && router_pp.contains("c@1"), + "expected router Sort extended with the tie-break column c, got:\n{}", + router_pp + ); + + Ok(()) +} + +async fn topk_hash_aggregate_trim(service: Box) -> Result<(), CubeError> { + service.exec_query("CREATE SCHEMA s").await?; + service + .exec_query("CREATE TABLE s.Data(a int, b int, hits int)") + .await?; + // 12 distinct (a, b) groups, each with two rows so partial aggregation actually groups. + // With k=3 and factor=2 the trim activates (g=12 > 6) but the result must match a full + // top-k. ORDER BY a (a proper subset of GROUP BY a, b) exercises totalization: the worker + // breaks ties on a by b so the router still receives every needed partial state. 
+ service + .exec_query( + "INSERT INTO s.Data(a, b, hits) VALUES \ + (1,1,10),(1,1,5),(1,2,1),(1,2,2),\ + (2,1,7),(2,1,3),(2,2,4),(2,2,6),\ + (3,1,8),(3,1,2),(3,2,9),(3,2,1),\ + (4,1,1),(4,1,1),(4,2,1),(4,2,1),\ + (5,1,1),(5,1,1),(5,2,1),(5,2,1),\ + (6,1,1),(6,1,1),(6,2,1),(6,2,1)", + ) + .await?; + + // ORDER BY a, b LIMIT 3 (ascending): smallest three groups by (a, b). + let r = service + .exec_query("SELECT a, b, SUM(hits) FROM s.Data GROUP BY 1, 2 ORDER BY 1, 2 LIMIT 3") + .await?; + assert_eq!(to_rows(&r), rows(&[(1, 1, 15), (1, 2, 3), (2, 1, 10)])); + + // ORDER BY a DESC, b DESC LIMIT 3: largest three groups by (a, b). + let r = service + .exec_query( + "SELECT a, b, SUM(hits) FROM s.Data GROUP BY 1, 2 ORDER BY 1 DESC, 2 DESC LIMIT 3", + ) + .await?; + assert_eq!(to_rows(&r), rows(&[(6, 2, 2), (6, 1, 2), (5, 2, 2)])); + + // ORDER BY a only (a proper subset of GROUP BY a, b), LIMIT 2. The selected group SET is + // deterministic (both groups of a=1), but the intra-tie row order is not, so assert as a set. + // Each returned group must carry its complete sum regardless of cross-worker tie-breaking, + // which is what totalization (append b to the cut order) guarantees. 
+ let r = service + .exec_query("SELECT a, b, SUM(hits) FROM s.Data GROUP BY 1, 2 ORDER BY 1 LIMIT 2") + .await?; + let got = to_rows(&r); + assert_eq!(got.len(), 2, "expected 2 rows, got: {:?}", got); + for expected in rows(&[(1, 1, 15), (1, 2, 3)]) { + assert!( + got.contains(&expected), + "missing {:?} in {:?}", + expected, + got + ); + } + + Ok(()) +} + async fn planning_hints(service: Box) -> Result<(), CubeError> { service.exec_query("CREATE SCHEMA s").await?; service diff --git a/rust/cubestore/cubestore/Cargo.toml b/rust/cubestore/cubestore/Cargo.toml index 78c9cf5aa613d..e969c62d00b05 100644 --- a/rust/cubestore/cubestore/Cargo.toml +++ b/rust/cubestore/cubestore/Cargo.toml @@ -28,10 +28,10 @@ cubezetasketch = { path = "../cubezetasketch" } cubedatasketches = { path = "../cubedatasketches" } cubeshared = { path = "../../cube/cubeshared" } cuberpc = { path = "../cuberpc" } -datafusion = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cube-46.0.1", features = ["serde"] } -datafusion-datasource = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cube-46.0.1" } -datafusion-proto = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cube-46.0.1" } -datafusion-proto-common = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cube-46.0.1" } +datafusion = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cubestore-hash-aggregate-limit", features = ["serde"] } +datafusion-datasource = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cubestore-hash-aggregate-limit" } +datafusion-proto = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cubestore-hash-aggregate-limit" } +datafusion-proto-common = { git = "https://github.com/cube-js/arrow-datafusion", branch = "cubestore-hash-aggregate-limit" } csv = "1.1.3" bytes = "1.6.0" serde_json = "1.0.56" diff --git a/rust/cubestore/cubestore/src/config/mod.rs b/rust/cubestore/cubestore/src/config/mod.rs index 
8c3c78492af64..aec9565a6535a 100644 --- a/rust/cubestore/cubestore/src/config/mod.rs +++ b/rust/cubestore/cubestore/src/config/mod.rs @@ -496,6 +496,11 @@ pub trait ConfigObj: DIService { fn enable_topk(&self) -> bool; + /// Factor `f` controlling when the worker-side partial hash aggregate trims its output to the + /// top-k groups. Trimming happens only when the number of local groups exceeds `f * k`, where + /// `k = limit + offset`. `0` disables the optimization. + fn partial_hash_aggregate_topk_factor(&self) -> usize; + fn allow_decimal128(&self) -> bool; fn enable_remove_orphaned_remote_files(&self) -> bool; @@ -638,6 +643,7 @@ pub struct ConfigObjImpl { pub max_ingestion_data_frames: usize, pub upload_to_remote: bool, pub enable_topk: bool, + pub partial_hash_aggregate_topk_factor: usize, pub allow_decimal128: bool, pub enable_remove_orphaned_remote_files: bool, pub enable_startup_warmup: bool, @@ -936,6 +942,10 @@ impl ConfigObj for ConfigObjImpl { self.enable_topk } + fn partial_hash_aggregate_topk_factor(&self) -> usize { + self.partial_hash_aggregate_topk_factor + } + fn allow_decimal128(&self) -> bool { self.allow_decimal128 } @@ -1514,6 +1524,10 @@ impl Config { .unwrap_or("localhost".to_string()), upload_to_remote: !env::var("CUBESTORE_NO_UPLOAD").ok().is_some(), enable_topk: env_bool("CUBESTORE_ENABLE_TOPK", true), + partial_hash_aggregate_topk_factor: env_parse( + "CUBESTORE_PARTIAL_HASH_AGGREGATE_TOPK_FACTOR", + 2, + ), allow_decimal128: env_bool("CUBESTORE_ALLOW_DECIMAL128", false), enable_remove_orphaned_remote_files: env_bool( "CUBESTORE_ENABLE_REMOVE_ORPHANED_REMOTE_FILES", @@ -1748,6 +1762,7 @@ impl Config { server_name: "localhost".to_string(), upload_to_remote: true, enable_topk: true, + partial_hash_aggregate_topk_factor: 2, allow_decimal128: false, enable_remove_orphaned_remote_files: false, enable_startup_warmup: true, @@ -2439,6 +2454,7 @@ impl Config { .clone(), i.get_service_typed().await, i.get_service_typed().await, + 
i.get_service_typed::().await, ) }) .await; diff --git a/rust/cubestore/cubestore/src/queryplanner/mod.rs b/rust/cubestore/cubestore/src/queryplanner/mod.rs index f2a41ac7d3b7c..3be4cc5456f68 100644 --- a/rust/cubestore/cubestore/src/queryplanner/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/mod.rs @@ -17,6 +17,7 @@ pub mod query_executor; pub mod serialized_plan; mod tail_limit; mod topk; +mod topk_aggregate; pub mod trace_data_loaded; use serialized_plan::PreSerializedPlan; pub use topk::MIN_TOPK_STREAM_ROWS; diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs index e4ee5eb698b3c..5aa67c922c354 100644 --- a/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/mod.rs @@ -4,6 +4,7 @@ mod inline_aggregate_rewriter; pub mod is_not_distinct_from_join_keys; pub mod rewrite_plan; pub mod rolling_optimizer; +mod topk_aggregate_rewriter; mod trace_data_loaded; use super::serialized_plan::PreSerializedPlan; @@ -13,6 +14,7 @@ use crate::queryplanner::optimizations::distributed_partial_aggregate::{ replace_suboptimal_merge_sorts, }; use crate::queryplanner::optimizations::inline_aggregate_rewriter::replace_with_inline_aggregate; +use crate::queryplanner::optimizations::topk_aggregate_rewriter::replace_with_topk_aggregate; use crate::queryplanner::planning::CubeExtensionPlanner; use crate::queryplanner::pretty_printers::{pp_phys_plan_ext, PPOptions}; use crate::queryplanner::rolling::RollingWindowPlanner; @@ -109,11 +111,15 @@ impl QueryPlanner for CubeQueryPlanner { } #[derive(Debug)] -pub struct PreOptimizeRule {} +pub struct PreOptimizeRule { + partial_hash_aggregate_topk_factor: usize, +} impl PreOptimizeRule { - pub fn new() -> Self { - Self {} + pub fn new(partial_hash_aggregate_topk_factor: usize) -> Self { + Self { + partial_hash_aggregate_topk_factor, + } } } @@ -123,7 +129,7 @@ impl 
PhysicalOptimizerRule for PreOptimizeRule { plan: Arc, _config: &ConfigOptions, ) -> datafusion::common::Result> { - pre_optimize_physical_plan(plan) + pre_optimize_physical_plan(plan, self.partial_hash_aggregate_topk_factor) } fn name(&self) -> &str { @@ -137,6 +143,7 @@ impl PhysicalOptimizerRule for PreOptimizeRule { fn pre_optimize_physical_plan( p: Arc, + partial_hash_aggregate_topk_factor: usize, ) -> Result, DataFusionError> { let p = rewrite_physical_plan(p, &mut |p| push_aggregate_to_workers(p))?; @@ -148,6 +155,11 @@ fn pre_optimize_physical_plan( // Replace sorted AggregateExec with InlineAggregateExec for better performance let p = rewrite_physical_plan(p, &mut |p| replace_with_inline_aggregate(p))?; + // Trim the worker-side partial hash aggregate to the top-k groups when the query orders by a + // subset of group-by columns and has a limit. Runs after inline-aggregate replacement so it + // only sees the remaining (hash) partial aggregates. + let p = replace_with_topk_aggregate(p, partial_hash_aggregate_topk_factor)?; + Ok(p) } diff --git a/rust/cubestore/cubestore/src/queryplanner/optimizations/topk_aggregate_rewriter.rs b/rust/cubestore/cubestore/src/queryplanner/optimizations/topk_aggregate_rewriter.rs new file mode 100644 index 0000000000000..e8a6d7a39d92f --- /dev/null +++ b/rust/cubestore/cubestore/src/queryplanner/optimizations/topk_aggregate_rewriter.rs @@ -0,0 +1,251 @@ +use crate::queryplanner::planning::WorkerExec; +use crate::queryplanner::query_executor::ClusterSendExec; +use crate::queryplanner::topk_aggregate::TopKHashAggregateExec; +use datafusion::arrow::compute::SortOptions; +use datafusion::error::DataFusionError; +use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode}; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::expressions::Column; +use 
datafusion::physical_plan::limit::GlobalLimitExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion::physical_plan::{ExecutionPlan, InputOrderMode}; +use std::sync::Arc; + +/// Trim the worker-side partial hash aggregate to the top-k groups when the plan is +/// `LIMIT k` over `ORDER BY` (on a subset of the group-by columns) over a distributed hash aggregate. +/// +/// Correctness requires a TOTAL order over groups (`T = ORDER BY ++ remaining group-by columns`, +/// in group-by order) applied in TWO places that must agree: +/// - the worker cut: each worker keeps its local top-k by `T`; +/// - the router select: the global Sort + Limit must also order by `T`. +/// Under `T` the router's top-k equals the global top-k by `T`, and every worker that holds a +/// partial state for such a group keeps it (its local rank can only be smaller), so every needed +/// partial state reaches the router. Ordering the router by `T` instead of the bare `ORDER BY` does +/// not change the query contract: `ORDER BY` is a prefix of `T`, so the output stays validly +/// ordered and the previously-unspecified tie order just becomes deterministic. +/// +/// We only rewrite when the plan matches exactly `Sort(/Limit) -> [passthrough] -> Final aggregate +/// -> [passthrough/cluster boundary] -> Partial hash aggregate`; anything else on the path (a +/// HAVING filter, a nested aggregate, a computed projection) makes us bail, so we never trim a plan +/// where the limit does not directly govern this aggregate. +/// +/// `factor` gates trimming at runtime (only when local groups exceed `factor * k`); `0` disables. 
+pub fn replace_with_topk_aggregate( + plan: Arc, + factor: usize, +) -> Result, DataFusionError> { + if factor == 0 { + return Ok(plan); + } + let Some(target) = analyze(&plan) else { + return Ok(plan); + }; + apply(plan, &target, factor) +} + +struct Target { + /// The router `SortExec` whose ordering must be extended to the total order. + sort: Arc, + /// The worker-side partial hash `AggregateExec` to replace with a trimming exec. + partial: Arc, + /// Tail of the total order to append to the router sort (over the sort's input schema). + router_tail: Vec, + /// Full total order over the partial output schema for the worker cut. + trim_order: Vec<(usize, SortOptions)>, + /// `k = limit + offset`. + k: usize, +} + +fn analyze(root: &Arc) -> Option { + // Peel an optional top GlobalLimit (carries the offset), then require a SortExec. + let (skip, extra_fetch, sort_node) = + if let Some(gl) = root.as_any().downcast_ref::() { + (gl.skip(), gl.fetch(), child(root)?) + } else { + (0, None, root.clone()) + }; + let sort = sort_node.as_any().downcast_ref::()?; + let order: Vec = sort.expr().iter().cloned().collect(); + if order.is_empty() { + return None; + } + // The worker must keep enough groups to cover `limit + offset`. When a top GlobalLimit carries + // the offset, DataFusion already folds `skip + limit` into the sort's fetch, so prefer it; + // otherwise fall back to the GlobalLimit's own `skip + fetch`. + let k = sort + .fetch() + .or_else(|| extra_fetch.map(|fetch| skip + fetch))?; + + // Sort -> [passthrough] -> Final aggregate. + let final_agg_node = descend_to_final_aggregate(sort.input().clone())?; + let final_agg = final_agg_node.as_any().downcast_ref::()?; + + // Final aggregate -> [passthrough/boundary] -> Partial hash aggregate. 
+ let partial_node = descend_to_worker_partial(final_agg.input().clone())?; + let partial = partial_node.as_any().downcast_ref::()?; + if !partial.group_expr().is_single() + || matches!(partial.input_order_mode(), InputOrderMode::Sorted) + { + return None; + } + + let num_group_cols = partial.group_expr().output_exprs().len(); + if num_group_cols == 0 { + return None; + } + let partial_schema = partial.schema(); + let group_names: Vec = partial_schema + .fields() + .iter() + .take(num_group_cols) + .map(|f| f.name().clone()) + .collect(); + + // Map ORDER BY columns onto group-by columns (by name; robust to projections). + let mut used = vec![false; num_group_cols]; + let mut trim_order: Vec<(usize, SortOptions)> = Vec::with_capacity(num_group_cols); + for e in &order { + let column = e.expr.as_any().downcast_ref::()?; + let idx = group_names.iter().position(|n| n == column.name())?; + if used[idx] { + continue; + } + used[idx] = true; + trim_order.push((idx, e.options)); + } + if trim_order.is_empty() { + return None; + } + + // Totalize: append the remaining group-by columns in group-by order. Build the matching tail + // for the router sort over its own (Final-output) schema, resolved by name. 
+ let sort_input_schema = sort.input().schema(); + let mut router_tail: Vec = Vec::new(); + for (idx, is_used) in used.into_iter().enumerate() { + if is_used { + continue; + } + let name = &group_names[idx]; + let options = SortOptions::default(); + let sort_col_idx = sort_input_schema.index_of(name).ok()?; + router_tail.push(PhysicalSortExpr { + expr: Arc::new(Column::new(name, sort_col_idx)), + options, + }); + trim_order.push((idx, options)); + } + + Some(Target { + sort: sort_node, + partial: partial_node, + router_tail, + trim_order, + k, + }) +} + +fn apply( + node: Arc, + target: &Target, + factor: usize, +) -> Result, DataFusionError> { + let is_sort = Arc::ptr_eq(&node, &target.sort); + let is_partial = Arc::ptr_eq(&node, &target.partial); + + let new_children = node + .children() + .into_iter() + .map(|c| apply(c.clone(), target, factor)) + .collect::, _>>()?; + let node = node.with_new_children(new_children)?; + + if is_partial { + if let Some(agg) = node.as_any().downcast_ref::() { + if let Some(exec) = TopKHashAggregateExec::try_new_from_partial( + agg, + target.k, + factor, + target.trim_order.clone(), + ) { + return Ok(Arc::new(exec)); + } + } + // Leaving the full aggregate in place stays correct; the router still sorts by the total + // order, it just receives every group instead of the trimmed top-k. + return Ok(node); + } + + if is_sort { + if let Some(sort) = node.as_any().downcast_ref::() { + let mut exprs: Vec = sort.expr().iter().cloned().collect(); + exprs.extend(target.router_tail.iter().cloned()); + let new_sort = SortExec::new(LexOrdering::new(exprs), sort.input().clone()) + .with_preserve_partitioning(sort.preserve_partitioning()) + .with_fetch(sort.fetch()); + return Ok(Arc::new(new_sort)); + } + } + + Ok(node) +} + +/// Walk down single-child passthrough nodes (which preserve rows and grouping) until the first +/// `Final`/`FinalPartitioned` `AggregateExec`. Returns `None` if a non-passthrough node is hit +/// first (e.g. 
a filter or a computed projection). +fn descend_to_final_aggregate(mut node: Arc) -> Option> { + loop { + if let Some(agg) = node.as_any().downcast_ref::() { + return matches!( + agg.mode(), + AggregateMode::Final | AggregateMode::FinalPartitioned + ) + .then_some(node.clone()); + } + if is_row_passthrough(&node) { + node = child(&node)?; + } else { + return None; + } + } +} + +/// Walk down passthrough nodes from a `Final` aggregate's input to the worker-side `Partial` +/// aggregate, requiring that exactly one `ClusterSend`/`Worker` boundary is crossed. Returns `None` +/// if anything unexpected (a second aggregate, a filter, ...) is on the path. +fn descend_to_worker_partial(mut node: Arc) -> Option> { + let mut crossed_boundary = false; + loop { + if let Some(agg) = node.as_any().downcast_ref::() { + return (crossed_boundary && *agg.mode() == AggregateMode::Partial) + .then_some(node.clone()); + } + if node.as_any().is::() || node.as_any().is::() { + crossed_boundary = true; + node = child(&node)?; + } else if is_row_passthrough(&node) { + node = child(&node)?; + } else { + return None; + } + } +} + +/// Single-child nodes that pass rows through unchanged (preserving grouping), so a limit/sort above +/// them governs the aggregate below them. 
+fn is_row_passthrough(node: &Arc) -> bool { + let any = node.as_any(); + any.is::() + || any.is::() + || any.is::() + || any.is::() +} + +fn child(node: &Arc) -> Option> { + let children = node.children(); + if children.len() != 1 { + return None; + } + Some(children[0].clone()) +} diff --git a/rust/cubestore/cubestore/src/queryplanner/pretty_printers.rs b/rust/cubestore/cubestore/src/queryplanner/pretty_printers.rs index 63d386add4951..d33355d113496 100644 --- a/rust/cubestore/cubestore/src/queryplanner/pretty_printers.rs +++ b/rust/cubestore/cubestore/src/queryplanner/pretty_printers.rs @@ -43,6 +43,7 @@ use crate::queryplanner::topk::SortColumn; use crate::queryplanner::topk::{ AggregateTopKExec, ClusterAggregateTopKLower, ClusterAggregateTopKUpper, }; +use crate::queryplanner::topk_aggregate::TopKHashAggregateExec; use crate::queryplanner::{CubeTableLogical, InfoSchemaTableProvider, QueryPlan}; //use crate::streaming::topic_table_provider::TopicTableProvider; use datafusion::physical_plan::empty::EmptyExec; @@ -617,6 +618,16 @@ fn pp_phys_plan_indented(p: &dyn ExecutionPlan, indent: usize, o: &PPOptions, ou if let Some(limit) = agg.limit() { *out += &format!(", limit: {}", limit) } + } else if let Some(agg) = a.downcast_ref::() { + *out += &format!( + "TopKHashAggregate, k: {}, factor: {}, order: {:?}", + agg.k(), + agg.factor(), + agg.order() + ); + if o.show_aggregations { + *out += &format!(", aggs: {:?}", agg.aggr_expr()) + } } else if let Some(l) = a.downcast_ref::() { *out += &format!("LocalLimit, n: {}", l.fetch()); } else if let Some(l) = a.downcast_ref::() { diff --git a/rust/cubestore/cubestore/src/queryplanner/query_executor.rs b/rust/cubestore/cubestore/src/queryplanner/query_executor.rs index 5943df1be5e0e..e769a9d937590 100644 --- a/rust/cubestore/cubestore/src/queryplanner/query_executor.rs +++ b/rust/cubestore/cubestore/src/queryplanner/query_executor.rs @@ -147,6 +147,7 @@ pub struct QueryExecutorImpl { metadata_cache_factory: Arc, 
parquet_metadata_cache: Arc, memory_handler: Arc, + config: Arc, } crate::di_service!(QueryExecutorImpl, [QueryExecutor]); @@ -430,11 +431,13 @@ impl QueryExecutorImpl { metadata_cache_factory: Arc, parquet_metadata_cache: Arc, memory_handler: Arc, + config: Arc, ) -> Arc { Arc::new(QueryExecutorImpl { metadata_cache_factory, parquet_metadata_cache, memory_handler, + config, }) } @@ -483,7 +486,9 @@ impl QueryExecutorImpl { fn physical_optimizer_rules(&self) -> Vec> { vec![ // Cube rules - Arc::new(PreOptimizeRule::new()), + Arc::new(PreOptimizeRule::new( + self.config.partial_hash_aggregate_topk_factor(), + )), // DF rules without EnforceDistribution. We do need to keep EnforceSorting. Arc::new(OutputRequirements::new_add_mode()), Arc::new(AggregateStatistics::new()), diff --git a/rust/cubestore/cubestore/src/queryplanner/topk_aggregate/mod.rs b/rust/cubestore/cubestore/src/queryplanner/topk_aggregate/mod.rs new file mode 100644 index 0000000000000..dfe834fdc0d72 --- /dev/null +++ b/rust/cubestore/cubestore/src/queryplanner/topk_aggregate/mod.rs @@ -0,0 +1,203 @@ +mod topk_hash_aggregate_stream; + +use datafusion::arrow::compute::SortOptions; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::common::stats::Precision; +use datafusion::common::Statistics; +use datafusion::error::Result as DFResult; +use datafusion::execution::TaskContext; +use datafusion::physical_expr::aggregate::AggregateFunctionExpr; +use datafusion::physical_expr::{Distribution, LexRequirement}; +use datafusion::physical_plan::execution_plan::CardinalityEffect; +use datafusion::physical_plan::metrics::MetricsSet; +use datafusion::physical_plan::{aggregates::*, InputOrderMode}; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PhysicalExpr, PlanProperties, + SendableRecordBatchStream, +}; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +/// Worker-side partial hash aggregate that trims its output to the top-k groups by a total order, 
+/// so far fewer partial-state rows cross the network to the router's Final aggregate. +/// +/// This is a custom copy of DataFusion's partial hash aggregate (it reuses DF's `GroupValues` and +/// `GroupsAccumulator` building blocks but owns the consume/emit loop), so the only change required +/// in the DataFusion fork is making `new_group_values` public. The aggregation builds the whole +/// group table and, at the single final emit, keeps only the `k` smallest groups by `order` when +/// the number of groups exceeds `factor * k`; otherwise it emits all groups unchanged. +/// +/// `order` is a TOTAL order over groups (ORDER BY columns followed by the remaining group-by +/// columns), expressed as `(partial-output column index, sort options)`. A total order is required +/// for correctness: the same group key can live on multiple workers, and a consistent cut across +/// workers guarantees every partial state the router selects reaches it. +#[derive(Debug, Clone)] +pub struct TopKHashAggregateExec { + group_by: PhysicalGroupBy, + aggr_expr: Vec>, + filter_expr: Vec>>, + pub input: Arc, + /// Partial-aggregate output schema (group columns followed by accumulator state columns). + schema: SchemaRef, + input_schema: SchemaRef, + cache: PlanProperties, + /// Fetch count, `k = limit + offset`. + k: usize, + /// Only trim when the number of local groups exceeds `factor * k`. + factor: usize, + /// Total order over the partial output columns. + order: Vec<(usize, SortOptions)>, +} + +impl TopKHashAggregateExec { + /// Build a `TopKHashAggregateExec` from a partial hash `AggregateExec`, or `None` if it is not a + /// single-group-by partial aggregate (grouping sets and non-partial modes are not supported). 
+ pub fn try_new_from_partial( + aggregate: &AggregateExec, + k: usize, + factor: usize, + order: Vec<(usize, SortOptions)>, + ) -> Option { + if *aggregate.mode() != AggregateMode::Partial { + return None; + } + // Sorted-prefix aggregates are handled by InlineAggregateExec; this targets the hash path. + if matches!(aggregate.input_order_mode(), InputOrderMode::Sorted) { + return None; + } + let group_by = aggregate.group_expr().clone(); + if !group_by.is_single() { + return None; + } + Some(Self { + group_by, + aggr_expr: aggregate.aggr_expr().to_vec(), + filter_expr: aggregate.filter_expr().to_vec(), + input: aggregate.input().clone(), + schema: aggregate.schema().clone(), + input_schema: aggregate.input_schema().clone(), + cache: aggregate.cache().clone(), + k, + factor, + order, + }) + } + + pub fn k(&self) -> usize { + self.k + } + + pub fn factor(&self) -> usize { + self.factor + } + + pub fn order(&self) -> &[(usize, SortOptions)] { + &self.order + } + + pub fn aggr_expr(&self) -> &[Arc] { + &self.aggr_expr + } + + pub fn filter_expr(&self) -> &[Option>] { + &self.filter_expr + } + + pub fn input(&self) -> &Arc { + &self.input + } + + pub fn group_expr(&self) -> &PhysicalGroupBy { + &self.group_by + } +} + +impl DisplayAs for TopKHashAggregateExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "TopKHashAggregateExec: k={}, factor={}, order={:?}", + self.k, self.factor, self.order + )?; + } + } + Ok(()) + } +} + +impl ExecutionPlan for TopKHashAggregateExec { + fn name(&self) -> &'static str { + "TopKHashAggregateExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn required_input_distribution(&self) -> Vec { + vec![Distribution::UnspecifiedDistribution] + } + + fn required_input_ordering(&self) -> Vec> { + vec![None] + } + + fn 
maintains_input_order(&self) -> Vec { + vec![false] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + Ok(Arc::new(Self { + group_by: self.group_by.clone(), + aggr_expr: self.aggr_expr.clone(), + filter_expr: self.filter_expr.clone(), + input: children[0].clone(), + schema: self.schema.clone(), + input_schema: self.input_schema.clone(), + cache: self.cache.clone(), + k: self.k, + factor: self.factor, + order: self.order.clone(), + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DFResult { + let stream = + topk_hash_aggregate_stream::TopKHashAggregateStream::new(self, context, partition)?; + Ok(Box::pin(stream)) + } + + fn metrics(&self) -> Option { + None + } + + fn statistics(&self) -> DFResult { + Ok(Statistics { + num_rows: Precision::Absent, + column_statistics: Statistics::unknown_column(&self.schema), + total_byte_size: Precision::Absent, + }) + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } +} diff --git a/rust/cubestore/cubestore/src/queryplanner/topk_aggregate/topk_hash_aggregate_stream.rs b/rust/cubestore/cubestore/src/queryplanner/topk_aggregate/topk_hash_aggregate_stream.rs new file mode 100644 index 0000000000000..fe62cb2eb5886 --- /dev/null +++ b/rust/cubestore/cubestore/src/queryplanner/topk_aggregate/topk_hash_aggregate_stream.rs @@ -0,0 +1,284 @@ +use datafusion::arrow::array::{ArrayRef, AsArray, RecordBatch}; +use datafusion::arrow::compute::{lexsort_to_indices, take, SortColumn, SortOptions}; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::dfschema::internal_err; +use datafusion::error::Result as DFResult; +use datafusion::execution::{RecordBatchStream, TaskContext}; +use datafusion::logical_expr::{EmitTo, GroupsAccumulator}; +use datafusion::physical_expr::GroupsAccumulatorAdapter; +use datafusion::physical_plan::aggregates::group_values::{new_group_values, 
GroupValues};
+use datafusion::physical_plan::aggregates::order::GroupOrdering;
+use datafusion::physical_plan::aggregates::PhysicalGroupBy;
+use datafusion::physical_plan::udaf::AggregateFunctionExpr;
+use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr, SendableRecordBatchStream};
+use futures::ready;
+use futures::stream::{Stream, StreamExt};
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use super::TopKHashAggregateExec;
+
+/// Phases of the stream: consume all input, hand out the result in slices, finish.
+enum ExecutionState {
+    /// Input batches are still being folded into the group table.
+    ReadingInput,
+    /// The rows that remain to be returned to the consumer.
+    ProducingOutput(RecordBatch),
+    /// Nothing more to return.
+    Done,
+}
+
+/// Stream behind `TopKHashAggregateExec`: aggregates the whole input of one partition, then emits
+/// the (possibly trimmed) partial state in slices of at most `batch_size` rows.
+pub(crate) struct TopKHashAggregateStream {
+    /// Schema of the emitted batches.
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    /// Argument expressions, one list per aggregate.
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    /// Optional filter, one per aggregate.
+    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
+    group_by: PhysicalGroupBy,
+    /// Maximum number of rows per emitted batch.
+    batch_size: usize,
+    exec_state: ExecutionState,
+    accumulators: Vec<Box<dyn GroupsAccumulator>>,
+    group_values: Box<dyn GroupValues>,
+    /// Scratch buffer filled by `GroupValues::intern` for every input batch.
+    current_group_indices: Vec<usize>,
+    k: usize,
+    factor: usize,
+    order: Vec<(usize, SortOptions)>,
+}
+
+impl TopKHashAggregateStream {
+    /// Opens partition `partition` of the input of `agg` and prepares an empty group table and
+    /// one accumulator per aggregate expression.
+    ///
+    /// # Errors
+    ///
+    /// Propagates failures from executing the input, building the group schema, or creating the
+    /// group table and accumulators.
+    pub fn new(
+        agg: &TopKHashAggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+    ) -> DFResult<Self> {
+        let agg_schema = agg.schema();
+        let agg_group_by = agg.group_expr().clone();
+        let agg_filter_expr = agg.filter_expr().to_vec();
+
+        let batch_size = context.session_config().batch_size();
+        let input = agg.input().execute(partition, Arc::clone(&context))?;
+
+        let aggregate_arguments =
+            aggregate_expressions(agg.aggr_expr(), agg_group_by.num_group_exprs())?;
+
+        let accumulators: Vec<_> = agg
+            .aggr_expr()
+            .iter()
+            .map(create_group_accumulator)
+            .collect::<DFResult<_>>()?;
+
+        let group_schema = agg_group_by.group_schema(&agg.input().schema())?;
+        let group_values = new_group_values(group_schema, &GroupOrdering::None)?;
+
+        Ok(TopKHashAggregateStream {
+            schema: agg_schema,
+            input,
+            aggregate_arguments,
+            filter_expressions: agg_filter_expr,
+            group_by: agg_group_by,
+            batch_size,
+            exec_state: ExecutionState::ReadingInput,
+            accumulators,
+            group_values,
+            current_group_indices: Vec::with_capacity(batch_size),
+            k: agg.k(),
+            factor: agg.factor(),
+            order: agg.order().to_vec(),
+        })
+    }
+}
+
+impl Stream for TopKHashAggregateStream {
+    type Item = DFResult<RecordBatch>;
+
+    /// Drives the state machine: nothing is yielded while input is being read; afterwards the
+    /// result is returned in slices of at most `batch_size` rows.
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        loop {
+            match &self.exec_state {
+                ExecutionState::ReadingInput => match ready!(self.input.poll_next_unpin(cx)) {
+                    Some(Ok(batch)) => {
+                        if let Err(e) = self.group_aggregate_batch(batch) {
+                            return Poll::Ready(Some(Err(e)));
+                        }
+                    }
+                    Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+                    // Input exhausted: emit the whole group table at once, then trim to top-k.
+                    None => match self.emit_all_trimmed() {
+                        Ok(Some(batch)) => {
+                            self.exec_state = ExecutionState::ProducingOutput(batch)
+                        }
+                        Ok(None) => self.exec_state = ExecutionState::Done,
+                        Err(e) => return Poll::Ready(Some(Err(e))),
+                    },
+                },
+
+                ExecutionState::ProducingOutput(batch) => {
+                    // Clone to release the borrow of `self.exec_state` before it is reassigned.
+                    let batch = batch.clone();
+                    let size = self.batch_size;
+                    let (next_state, output) = if batch.num_rows() <= size {
+                        (ExecutionState::Done, batch)
+                    } else {
+                        let remaining = batch.slice(size, batch.num_rows() - size);
+                        let output = batch.slice(0, size);
+                        (ExecutionState::ProducingOutput(remaining), output)
+                    };
+                    self.exec_state = next_state;
+                    return Poll::Ready(Some(Ok(output)));
+                }
+
+                ExecutionState::Done => return Poll::Ready(None),
+            }
+        }
+    }
+}
+
+impl RecordBatchStream for TopKHashAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+impl TopKHashAggregateStream {
+    /// Folds one input batch into the group table and the accumulators.
+    ///
+    /// # Errors
+    ///
+    /// Propagates expression-evaluation, interning and accumulator errors, and returns an
+    /// internal error if the group-by evaluation does not yield exactly one set of columns.
+    fn group_aggregate_batch(&mut self, batch: RecordBatch) -> DFResult<()> {
+        let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+        let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
+        let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;
+
+        // Surface a violated invariant as an error instead of panicking inside the stream.
+        let [group_columns] = group_by_values.as_slice() else {
+            return internal_err!(
+                "Exactly 1 group value required, got {}",
+                group_by_values.len()
+            );
+        };
+        self.group_values
+            .intern(group_columns, &mut self.current_group_indices)?;
+        let group_indices = &self.current_group_indices;
+        let total_num_groups = self.group_values.len();
+
+        for ((acc, values), opt_filter) in self
+            .accumulators
+            .iter_mut()
+            .zip(input_values.iter())
+            .zip(filter_values.iter())
+        {
+            let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
+            acc.update_batch(values, group_indices, opt_filter, total_num_groups)?;
+        }
+        Ok(())
+    }
+
+    /// Build the partial-state batch for all groups, then keep only the `k` smallest by the total
+    /// order when the number of groups exceeds `factor * k`.
+    ///
+    /// Returns `Ok(None)` when no group was ever created.
+    fn emit_all_trimmed(&mut self) -> DFResult<Option<RecordBatch>> {
+        if self.group_values.is_empty() {
+            return Ok(None);
+        }
+        // Group columns first, then the state columns of every accumulator.
+        let mut columns = self.group_values.emit(EmitTo::All)?;
+        for acc in &mut self.accumulators {
+            columns.extend(acc.state(EmitTo::All)?);
+        }
+        let batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
+        Ok(Some(self.trim_top_k(batch)?))
+    }
+
+    /// Returns the first `k` rows of `batch` under `order`, or `batch` unchanged when trimming
+    /// does not apply: `k == 0`, no sort keys, or at most `factor * k` rows.
+    fn trim_top_k(&self, batch: RecordBatch) -> DFResult<RecordBatch> {
+        let g = batch.num_rows();
+        // Sorting by zero columns is an error in Arrow; passing every group through is always
+        // correct because trimming is only an optimisation.
+        if self.k == 0 || self.order.is_empty() || g <= self.factor.saturating_mul(self.k) {
+            return Ok(batch);
+        }
+        let sort_columns: Vec<SortColumn> = self
+            .order
+            .iter()
+            .map(|(idx, options)| SortColumn {
+                values: Arc::clone(batch.column(*idx)),
+                options: Some(*options),
+            })
+            .collect();
+        // The limit makes this a partial sort that yields only the first `k` row indices.
+        let indices = lexsort_to_indices(&sort_columns, Some(self.k))?;
+        let columns = batch
+            .columns()
+            .iter()
+            .map(|c| take(c.as_ref(), &indices, None))
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(RecordBatch::try_new(batch.schema(), columns)?)
+    }
+}
+
+/// Partial-aggregate argument expressions, one vec per aggregate. Mirrors DataFusion's private
+/// `aggregate_expressions` for `AggregateMode::Partial`. 
+fn aggregate_expressions(
+    aggr_expr: &[Arc<AggregateFunctionExpr>],
+    _col_idx_base: usize,
+) -> DFResult<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
+    // `_col_idx_base` is accepted for signature compatibility but is not read here.
+    Ok(aggr_expr
+        .iter()
+        .map(|agg| {
+            // Arguments of the aggregate itself, followed by its ORDER BY expressions, if any.
+            let mut result = agg.expressions();
+            if let Some(ordering_req) = agg.order_bys() {
+                result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr)));
+            }
+            result
+        })
+        .collect())
+}
+
+/// Creates the accumulator for one aggregate: its native `GroupsAccumulator` when the aggregate
+/// reports support for one, otherwise a `GroupsAccumulatorAdapter` around its row accumulator.
+fn create_group_accumulator(
+    agg_expr: &Arc<AggregateFunctionExpr>,
+) -> DFResult<Box<dyn GroupsAccumulator>> {
+    if agg_expr.groups_accumulator_supported() {
+        agg_expr.create_groups_accumulator()
+    } else {
+        // The factory owns its own handle so it can create accumulators on demand.
+        let agg_expr_captured = Arc::clone(agg_expr);
+        let factory = move || agg_expr_captured.create_accumulator();
+        Ok(Box::new(GroupsAccumulatorAdapter::new(factory)))
+    }
+}
+
+/// Evaluates one expression against `batch` and converts the result into an array with one
+/// entry per row of the batch.
+fn evaluate_to_array(expr: &Arc<dyn PhysicalExpr>, batch: &RecordBatch) -> DFResult<ArrayRef> {
+    expr.evaluate(batch)
+        .and_then(|v| v.into_array(batch.num_rows()))
+}
+
+/// Evaluates every expression against `batch`, stopping at the first error.
+fn evaluate(expr: &[Arc<dyn PhysicalExpr>], batch: &RecordBatch) -> DFResult<Vec<ArrayRef>> {
+    expr.iter()
+        .map(|expr| evaluate_to_array(expr, batch))
+        .collect()
+}
+
+/// Evaluates each list of expressions against `batch`, keeping the nesting of the input.
+fn evaluate_many(
+    expr: &[Vec<Arc<dyn PhysicalExpr>>],
+    batch: &RecordBatch,
+) -> DFResult<Vec<Vec<ArrayRef>>> {
+    expr.iter().map(|expr| evaluate(expr, batch)).collect()
+}
+
+/// Evaluates the expressions that are present against `batch`; `None` entries stay `None`.
+fn evaluate_optional(
+    expr: &[Option<Arc<dyn PhysicalExpr>>],
+    batch: &RecordBatch,
+) -> DFResult<Vec<Option<ArrayRef>>> {
+    expr.iter()
+        .map(|expr| {
+            expr.as_ref()
+                .map(|expr| evaluate_to_array(expr, batch))
+                .transpose()
+        })
+        .collect()
+}
+
+/// Evaluates the group-by expressions against `batch` and returns them as a single set of
+/// group columns (the outer vector always has exactly one element on success).
+///
+/// # Errors
+///
+/// Returns an internal error when `group_by` is not a single group-by, and propagates
+/// expression-evaluation errors.
+fn evaluate_group_by(
+    group_by: &PhysicalGroupBy,
+    batch: &RecordBatch,
+) -> DFResult<Vec<Vec<ArrayRef>>> {
+    // Reject grouping sets before spending any work on evaluating expressions.
+    if !group_by.is_single() {
+        return internal_err!("TopKHashAggregate does not support grouping sets");
+    }
+
+    let exprs: Vec<ArrayRef> = group_by
+        .expr()
+        .iter()
+        .map(|(expr, _)| evaluate_to_array(expr, batch))
+        .collect::<DFResult<Vec<_>>>()?;
+
+    Ok(vec![exprs])
+}