From 0c8dc78b1b96595b52250ade8d79feb3d0f06d02 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 31 Mar 2026 12:35:59 -0400 Subject: [PATCH] Add missing aggregate functions: grouping, percentile_cont, var_population Expose upstream DataFusion aggregate functions that were not yet available in the Python API. Closes #1454. - grouping: returns grouping set membership indicator (rewritten by the ResolveGroupingFunction analyzer rule before physical planning) - percentile_cont: computes exact percentile using continuous interpolation (unlike approx_percentile_cont which uses t-digest) - var_population: alias for var_pop Co-Authored-By: Claude Opus 4.6 (1M context) --- crates/core/src/functions.rs | 23 ++++++++-- python/datafusion/functions.py | 77 ++++++++++++++++++++++++++++++++++ python/tests/test_functions.py | 46 ++++++++++++++++++++ 3 files changed, 142 insertions(+), 4 deletions(-) diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index c32134054..908c1a8dd 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -696,9 +696,10 @@ aggregate_function!(var_pop); aggregate_function!(approx_distinct); aggregate_function!(approx_median); -// Code is commented out since grouping is not yet implemented -// https://github.com/apache/datafusion-python/issues/861 -// aggregate_function!(grouping); +// The grouping function's physical plan is not implemented, but the +// ResolveGroupingFunction analyzer rule rewrites it before the physical +// planner sees it, so it works correctly at runtime. 
+aggregate_function!(grouping); #[pyfunction] #[pyo3(signature = (sort_expression, percentile, num_centroids=None, filter=None))] @@ -736,6 +737,19 @@ pub fn approx_percentile_cont_with_weight( add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) } +#[pyfunction] +#[pyo3(signature = (sort_expression, percentile, filter=None))] +pub fn percentile_cont( + sort_expression: PySortExpr, + percentile: f64, + filter: Option<PyExpr>, +) -> PyDataFusionResult<PyExpr> { + let agg_fn = + functions_aggregate::expr_fn::percentile_cont(sort_expression.sort, lit(percentile)); + + add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) +} + // We handle last_value explicitly because the signature expects an order_by // https://github.com/apache/datafusion/issues/12376 #[pyfunction] @@ -936,6 +950,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(approx_median))?; m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?; m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?; + m.add_wrapped(wrap_pyfunction!(percentile_cont))?; m.add_wrapped(wrap_pyfunction!(range))?; m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; @@ -981,7 +996,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(floor))?; m.add_wrapped(wrap_pyfunction!(from_unixtime))?; m.add_wrapped(wrap_pyfunction!(gcd))?; - // m.add_wrapped(wrap_pyfunction!(grouping))?; + m.add_wrapped(wrap_pyfunction!(grouping))?; m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; m.add_wrapped(wrap_pyfunction!(isnan))?; diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index f062cbfce..66fa0776d 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -149,6 +149,7 @@ "floor", "from_unixtime", "gcd", + "grouping", "in_list", "initcap", "isnan", @@ -216,6 +217,7 @@ "order_by", "overlay", 
"percent_rank", + "percentile_cont", "pi", "pow", "power", @@ -286,6 +288,7 @@ "uuid", "var", "var_pop", + "var_population", "var_samp", "var_sample", "when", @@ -3523,6 +3526,47 @@ def approx_percentile_cont_with_weight( ) +def percentile_cont( + sort_expression: Expr | SortExpr, + percentile: float, + filter: Expr | None = None, +) -> Expr: + """Computes the exact percentile of input values using continuous interpolation. + + Unlike :py:func:`approx_percentile_cont`, this function computes the exact + percentile value rather than an approximation. + + If using the builder functions described in ref:`_aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + sort_expression: Values for which to find the percentile + percentile: This must be between 0.0 and 1.0, inclusive + filter: If provided, only compute against rows for which the filter is True + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + >>> result = df.aggregate( + ... [], [dfn.functions.percentile_cont( + ... dfn.col("a"), 0.5 + ... ).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 3.0 + + >>> result = df.aggregate( + ... [], [dfn.functions.percentile_cont( + ... dfn.col("a"), 0.5, + ... filter=dfn.col("a") > dfn.lit(1.0), + ... ).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 3.5 + """ + sort_expr_raw = sort_or_default(sort_expression) + filter_raw = filter.expr if filter is not None else None + return Expr(f.percentile_cont(sort_expr_raw, percentile, filter=filter_raw)) + + def array_agg( expression: Expr, distinct: bool = False, @@ -3581,6 +3625,30 @@ def array_agg( ) +def grouping( + expression: Expr, + distinct: bool | None = None, + filter: Expr | None = None, +) -> Expr: + """Returns 1 if the data is aggregated across the specified column, or 0 otherwise. 
+ + This function is used with ``GROUPING SETS``, ``CUBE``, or ``ROLLUP`` to + distinguish between aggregated and non-aggregated rows. In a regular + ``GROUP BY`` without grouping sets, it always returns 0. + + Note: The ``grouping`` aggregate function is rewritten by the query + optimizer before execution, so it works correctly even though its + physical plan is not directly implemented. + + Args: + expression: The column to check grouping status for + distinct: If True, compute on distinct values only + filter: If provided, only compute against rows for which the filter is True + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.grouping(expression.expr, distinct=distinct, filter=filter_raw)) + + def avg( expression: Expr, filter: Expr | None = None, @@ -4052,6 +4120,15 @@ def var_pop(expression: Expr, filter: Expr | None = None) -> Expr: return Expr(f.var_pop(expression.expr, filter=filter_raw)) +def var_population(expression: Expr, filter: Expr | None = None) -> Expr: + """Computes the population variance of the argument. + + See Also: + This is an alias for :py:func:`var_pop`. + """ + return var_pop(expression, filter) + + def var_samp(expression: Expr, filter: Expr | None = None) -> Expr: """Computes the sample variance of the argument. 
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 37d349c58..3e153a219 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1435,3 +1435,49 @@ def test_coalesce(df): assert result.column(0) == pa.array( ["Hello", "fallback", "!"], type=pa.string_view() ) + + +def test_percentile_cont(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + result = df.aggregate( + [], [f.percentile_cont(column("a"), 0.5).alias("v")] + ).collect()[0] + assert result.column(0)[0].as_py() == 3.0 + + +def test_percentile_cont_with_filter(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + result = df.aggregate( + [], + [ + f.percentile_cont( + column("a"), 0.5, filter=column("a") > literal(1.0) + ).alias("v") + ], + ).collect()[0] + assert result.column(0)[0].as_py() == 3.5 + + +def test_grouping(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + # In a simple GROUP BY (no grouping sets), grouping() returns 0 for all rows. + # Note: grouping() must not be aliased directly in the aggregate expression list + # due to an upstream DataFusion analyzer limitation (the ResolveGroupingFunction + # rule doesn't unwrap Alias nodes). Apply aliases via a follow-up select instead. + result = df.aggregate( + [column("a")], [f.grouping(column("a")), f.sum(column("b")).alias("s")] + ).collect() + grouping_col = pa.concat_arrays([batch.column(1) for batch in result]).to_pylist() + assert all(v == 0 for v in grouping_col) + + +def test_var_population(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [-1.0, 0.0, 2.0]}) + result = df.aggregate([], [f.var_population(column("a")).alias("v")]).collect()[0] + # var_population is an alias for var_pop + expected = df.aggregate([], [f.var_pop(column("a")).alias("v")]).collect()[0] + assert abs(result.column(0)[0].as_py() - expected.column(0)[0].as_py()) < 1e-10