Skip to content

Commit 038691b

Browse files
viiryaclaude
andcommitted
feat: support SQL aggregate FILTER (WHERE ...) clause in native execution
Previously, Comet fell back to Spark for any aggregation containing a FILTER (WHERE ...) clause (e.g. SUM(x) FILTER (WHERE y > 0)). The native SumInt/SumDecimal accumulators already received opt_filter support in the previous commit. This commit wires the full pipeline: Proto (expr.proto): - Add optional Expr filter = 89 to AggExpr message Scala serialization (QueryPlanSerde.scala): - In aggExprToProto, serialize aggExpr.filter into the proto when aggExpr.mode == Partial (filters are only meaningful in partial mode) - If the filter expression cannot be serialized, fall back gracefully Native planner (planner.rs): - Build per-aggregate filter PhysicalExpr from agg_expr.filter - Pass to AggregateExec::try_new instead of vec![None; num_agg] Comet planner (operators.scala): - Remove the blanket fallback guard for aggregate expressions with filter Tests (aggregate_filter.sql): - Update queries from expect_fallback to plain query mode now that native execution is supported; tests verify results match Spark Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3e44050 commit 038691b

5 files changed

Lines changed: 42 additions & 17 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -975,12 +975,27 @@ impl PhysicalPlanner {
975975

976976
let num_agg = agg.agg_exprs.len();
977977
let aggr_expr = agg_exprs?.into_iter().map(Arc::new).collect();
978+
979+
// Build per-aggregate filter expressions from the FILTER (WHERE ...) clause.
980+
// Filters are only present in Partial mode; Final/PartialMerge always get None.
981+
let filter_exprs: Result<Vec<Option<Arc<dyn PhysicalExpr>>>, ExecutionError> = agg
982+
.agg_exprs
983+
.iter()
984+
.map(|expr| {
985+
if let Some(f) = expr.filter.as_deref() {
986+
self.create_expr(f, Arc::clone(&schema)).map(Some)
987+
} else {
988+
Ok(None)
989+
}
990+
})
991+
.collect();
992+
978993
let aggregate: Arc<dyn ExecutionPlan> = Arc::new(
979994
datafusion::physical_plan::aggregates::AggregateExec::try_new(
980995
mode,
981996
group_by,
982997
aggr_expr,
983-
vec![None; num_agg], // no filter expressions
998+
filter_exprs?,
984999
Arc::clone(&child.native_plan),
9851000
Arc::clone(&schema),
9861001
)?,

native/proto/src/proto/expr.proto

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ message AggExpr {
141141
BloomFilterAgg bloomFilterAgg = 16;
142142
}
143143

144+
// Optional filter expression for SQL FILTER (WHERE ...) clause.
145+
// Only set in Partial aggregation mode; absent in Final/PartialMerge.
146+
optional Expr filter = 89;
147+
144148
// Optional QueryContext for error reporting (contains SQL text and position)
145149
optional QueryContext query_context = 90;
146150

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,13 +517,25 @@ object QueryPlanSerde extends Logging with CometExprShim {
517517
}
518518

519519
// Attach QueryContext and expr_id to the aggregate expression
520-
protoAggExprOpt.map { protoAggExpr =>
520+
protoAggExprOpt.flatMap { protoAggExpr =>
521521
val builder = protoAggExpr.toBuilder
522522
builder.setExprId(nextExprId())
523+
524+
// Serialize FILTER (WHERE ...) clause if present.
525+
// The filter is only meaningful in Partial mode; Final/PartialMerge never set it.
526+
if (aggExpr.filter.isDefined && aggExpr.mode == Partial) {
527+
val filterProto = exprToProto(aggExpr.filter.get, inputs, binding)
528+
if (filterProto.isEmpty) {
529+
withInfo(aggExpr, aggExpr.filter.get)
530+
return None
531+
}
532+
builder.setFilter(filterProto.get)
533+
}
534+
523535
extractQueryContext(fn).foreach { ctx =>
524536
builder.setQueryContext(ctx)
525537
}
526-
builder.build()
538+
Some(builder.build())
527539
}
528540
}
529541

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,12 +1362,6 @@ trait CometBaseAggregate {
13621362
return None
13631363
}
13641364

1365-
// Aggregate expressions with filter are not supported yet.
1366-
if (aggregateExpressions.exists(_.filter.isDefined)) {
1367-
withInfo(aggregate, "Aggregate expression with filter is not supported")
1368-
return None
1369-
}
1370-
13711365
if (groupingExpressions.exists(expr =>
13721366
expr.dataType match {
13731367
case _: MapType => true

spark/src/test/resources/sql-tests/expressions/aggregate/aggregate_filter.sql

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,33 +37,33 @@ INSERT INTO test_agg_filter VALUES
3737
('b', NULL, NULL, NULL, true)
3838

3939
-- Basic FILTER on SUM(int)
40-
query expect_fallback(Aggregate expression with filter is not supported)
40+
query
4141
SELECT SUM(i) FILTER (WHERE flag = true) FROM test_agg_filter
4242

4343
-- FILTER on SUM with GROUP BY
44-
query expect_fallback(Aggregate expression with filter is not supported)
44+
query
4545
SELECT grp, SUM(i) FILTER (WHERE flag = true) FROM test_agg_filter GROUP BY grp ORDER BY grp
4646

4747
-- FILTER on SUM(long)
48-
query expect_fallback(Aggregate expression with filter is not supported)
48+
query
4949
SELECT SUM(l) FILTER (WHERE flag = true) FROM test_agg_filter
5050

5151
-- FILTER on SUM(decimal)
52-
query expect_fallback(Aggregate expression with filter is not supported)
52+
query
5353
SELECT SUM(d) FILTER (WHERE flag = true) FROM test_agg_filter
5454

5555
-- Multiple aggregates: one with filter, one without
56-
query expect_fallback(Aggregate expression with filter is not supported)
56+
query
5757
SELECT SUM(i), SUM(i) FILTER (WHERE flag = true) FROM test_agg_filter
5858

5959
-- FILTER with NULL rows: NULLs should not be included even when filter passes
60-
query expect_fallback(Aggregate expression with filter is not supported)
60+
query
6161
SELECT grp, SUM(i) FILTER (WHERE flag = true) FROM test_agg_filter GROUP BY grp ORDER BY grp
6262

6363
-- FILTER with COUNT
64-
query expect_fallback(Aggregate expression with filter is not supported)
64+
query
6565
SELECT COUNT(*) FILTER (WHERE flag = true) FROM test_agg_filter
6666

6767
-- FILTER with COUNT GROUP BY
68-
query expect_fallback(Aggregate expression with filter is not supported)
68+
query
6969
SELECT grp, COUNT(*) FILTER (WHERE flag = true) FROM test_agg_filter GROUP BY grp ORDER BY grp

0 commit comments

Comments
 (0)