Skip to content

Commit 0c8dc78

Browse files
timsaucerclaude
andcommitted
Add missing aggregate functions: grouping, percentile_cont, var_population
Expose upstream DataFusion aggregate functions that were not yet available in the Python API. Closes #1454. - grouping: returns grouping set membership indicator (rewritten by the ResolveGroupingFunction analyzer rule before physical planning) - percentile_cont: computes exact percentile using continuous interpolation (unlike approx_percentile_cont which uses t-digest) - var_population: alias for var_pop Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 73a9d53 commit 0c8dc78

File tree

3 files changed

+142
-4
lines changed

3 files changed

+142
-4
lines changed

crates/core/src/functions.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -696,9 +696,10 @@ aggregate_function!(var_pop);
696696
aggregate_function!(approx_distinct);
697697
aggregate_function!(approx_median);
698698

699-
// Code is commented out since grouping is not yet implemented
700-
// https://github.com/apache/datafusion-python/issues/861
701-
// aggregate_function!(grouping);
699+
// The grouping function's physical plan is not implemented, but the
700+
// ResolveGroupingFunction analyzer rule rewrites it before the physical
701+
// planner sees it, so it works correctly at runtime.
702+
aggregate_function!(grouping);
702703

703704
#[pyfunction]
704705
#[pyo3(signature = (sort_expression, percentile, num_centroids=None, filter=None))]
@@ -736,6 +737,19 @@ pub fn approx_percentile_cont_with_weight(
736737
add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
737738
}
738739

740+
#[pyfunction]
741+
#[pyo3(signature = (sort_expression, percentile, filter=None))]
742+
pub fn percentile_cont(
743+
sort_expression: PySortExpr,
744+
percentile: f64,
745+
filter: Option<PyExpr>,
746+
) -> PyDataFusionResult<PyExpr> {
747+
let agg_fn =
748+
functions_aggregate::expr_fn::percentile_cont(sort_expression.sort, lit(percentile));
749+
750+
add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
751+
}
752+
739753
// We handle last_value explicitly because the signature expects an order_by
740754
// https://github.com/apache/datafusion/issues/12376
741755
#[pyfunction]
@@ -936,6 +950,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
936950
m.add_wrapped(wrap_pyfunction!(approx_median))?;
937951
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?;
938952
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?;
953+
m.add_wrapped(wrap_pyfunction!(percentile_cont))?;
939954
m.add_wrapped(wrap_pyfunction!(range))?;
940955
m.add_wrapped(wrap_pyfunction!(array_agg))?;
941956
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
@@ -981,7 +996,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
981996
m.add_wrapped(wrap_pyfunction!(floor))?;
982997
m.add_wrapped(wrap_pyfunction!(from_unixtime))?;
983998
m.add_wrapped(wrap_pyfunction!(gcd))?;
984-
// m.add_wrapped(wrap_pyfunction!(grouping))?;
999+
m.add_wrapped(wrap_pyfunction!(grouping))?;
9851000
m.add_wrapped(wrap_pyfunction!(in_list))?;
9861001
m.add_wrapped(wrap_pyfunction!(initcap))?;
9871002
m.add_wrapped(wrap_pyfunction!(isnan))?;

python/datafusion/functions.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
"floor",
150150
"from_unixtime",
151151
"gcd",
152+
"grouping",
152153
"in_list",
153154
"initcap",
154155
"isnan",
@@ -216,6 +217,7 @@
216217
"order_by",
217218
"overlay",
218219
"percent_rank",
220+
"percentile_cont",
219221
"pi",
220222
"pow",
221223
"power",
@@ -286,6 +288,7 @@
286288
"uuid",
287289
"var",
288290
"var_pop",
291+
"var_population",
289292
"var_samp",
290293
"var_sample",
291294
"when",
@@ -3523,6 +3526,47 @@ def approx_percentile_cont_with_weight(
35233526
)
35243527

35253528

3529+
def percentile_cont(
3530+
sort_expression: Expr | SortExpr,
3531+
percentile: float,
3532+
filter: Expr | None = None,
3533+
) -> Expr:
3534+
"""Computes the exact percentile of input values using continuous interpolation.
3535+
3536+
Unlike :py:func:`approx_percentile_cont`, this function computes the exact
3537+
percentile value rather than an approximation.
3538+
3539+
If using the builder functions described in ref:`_aggregation` this function ignores
3540+
the options ``order_by``, ``null_treatment``, and ``distinct``.
3541+
3542+
Args:
3543+
sort_expression: Values for which to find the percentile
3544+
percentile: This must be between 0.0 and 1.0, inclusive
3545+
filter: If provided, only compute against rows for which the filter is True
3546+
3547+
Examples:
3548+
>>> ctx = dfn.SessionContext()
3549+
>>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
3550+
>>> result = df.aggregate(
3551+
... [], [dfn.functions.percentile_cont(
3552+
... dfn.col("a"), 0.5
3553+
... ).alias("v")])
3554+
>>> result.collect_column("v")[0].as_py()
3555+
3.0
3556+
3557+
>>> result = df.aggregate(
3558+
... [], [dfn.functions.percentile_cont(
3559+
... dfn.col("a"), 0.5,
3560+
... filter=dfn.col("a") > dfn.lit(1.0),
3561+
... ).alias("v")])
3562+
>>> result.collect_column("v")[0].as_py()
3563+
3.5
3564+
"""
3565+
sort_expr_raw = sort_or_default(sort_expression)
3566+
filter_raw = filter.expr if filter is not None else None
3567+
return Expr(f.percentile_cont(sort_expr_raw, percentile, filter=filter_raw))
3568+
3569+
35263570
def array_agg(
35273571
expression: Expr,
35283572
distinct: bool = False,
@@ -3581,6 +3625,30 @@ def array_agg(
35813625
)
35823626

35833627

3628+
def grouping(
3629+
expression: Expr,
3630+
distinct: bool | None = None,
3631+
filter: Expr | None = None,
3632+
) -> Expr:
3633+
"""Returns 1 if the data is aggregated across the specified column, or 0 otherwise.
3634+
3635+
This function is used with ``GROUPING SETS``, ``CUBE``, or ``ROLLUP`` to
3636+
distinguish between aggregated and non-aggregated rows. In a regular
3637+
``GROUP BY`` without grouping sets, it always returns 0.
3638+
3639+
Note: The ``grouping`` aggregate function is rewritten by the query
3640+
optimizer before execution, so it works correctly even though its
3641+
physical plan is not directly implemented.
3642+
3643+
Args:
3644+
expression: The column to check grouping status for
3645+
distinct: If True, compute on distinct values only
3646+
filter: If provided, only compute against rows for which the filter is True
3647+
"""
3648+
filter_raw = filter.expr if filter is not None else None
3649+
return Expr(f.grouping(expression.expr, distinct=distinct, filter=filter_raw))
3650+
3651+
35843652
def avg(
35853653
expression: Expr,
35863654
filter: Expr | None = None,
@@ -4052,6 +4120,15 @@ def var_pop(expression: Expr, filter: Expr | None = None) -> Expr:
40524120
return Expr(f.var_pop(expression.expr, filter=filter_raw))
40534121

40544122

4123+
def var_population(expression: Expr, filter: Expr | None = None) -> Expr:
4124+
"""Computes the population variance of the argument.
4125+
4126+
See Also:
4127+
This is an alias for :py:func:`var_pop`.
4128+
"""
4129+
return var_pop(expression, filter)
4130+
4131+
40554132
def var_samp(expression: Expr, filter: Expr | None = None) -> Expr:
40564133
"""Computes the sample variance of the argument.
40574134

python/tests/test_functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,3 +1435,49 @@ def test_coalesce(df):
14351435
assert result.column(0) == pa.array(
14361436
["Hello", "fallback", "!"], type=pa.string_view()
14371437
)
1438+
1439+
1440+
def test_percentile_cont():
1441+
ctx = SessionContext()
1442+
df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
1443+
result = df.aggregate(
1444+
[], [f.percentile_cont(column("a"), 0.5).alias("v")]
1445+
).collect()[0]
1446+
assert result.column(0)[0].as_py() == 3.0
1447+
1448+
1449+
def test_percentile_cont_with_filter():
1450+
ctx = SessionContext()
1451+
df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
1452+
result = df.aggregate(
1453+
[],
1454+
[
1455+
f.percentile_cont(
1456+
column("a"), 0.5, filter=column("a") > literal(1.0)
1457+
).alias("v")
1458+
],
1459+
).collect()[0]
1460+
assert result.column(0)[0].as_py() == 3.5
1461+
1462+
1463+
def test_grouping():
1464+
ctx = SessionContext()
1465+
df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
1466+
# In a simple GROUP BY (no grouping sets), grouping() returns 0 for all rows.
1467+
# Note: grouping() must not be aliased directly in the aggregate expression list
1468+
# due to an upstream DataFusion analyzer limitation (the ResolveGroupingFunction
1469+
# rule doesn't unwrap Alias nodes). Apply aliases via a follow-up select instead.
1470+
result = df.aggregate(
1471+
[column("a")], [f.grouping(column("a")), f.sum(column("b")).alias("s")]
1472+
).collect()
1473+
grouping_col = pa.concat_arrays([batch.column(1) for batch in result]).to_pylist()
1474+
assert all(v == 0 for v in grouping_col)
1475+
1476+
1477+
def test_var_population():
1478+
ctx = SessionContext()
1479+
df = ctx.from_pydict({"a": [-1.0, 0.0, 2.0]})
1480+
result = df.aggregate([], [f.var_population(column("a")).alias("v")]).collect()[0]
1481+
# var_population is an alias for var_pop
1482+
expected = df.aggregate([], [f.var_pop(column("a")).alias("v")]).collect()[0]
1483+
assert abs(result.column(0)[0].as_py() - expected.column(0)[0].as_py()) < 1e-10

0 commit comments

Comments
 (0)