Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ private static void sampleOfTheGeneratedWindowedAggregate() {
hasRows, frameRowCount, partitionRowCount,
jDecl, inputPhysTypeFinal);

final RelDataType inputRowType = inputPhysType.getRowType();
final Function<AggImpState, List<RexNode>> rexArguments = agg -> {
List<Integer> argList = agg.call.getArgList();
List<RelDataType> inputTypes =
Expand All @@ -464,7 +465,7 @@ private static void sampleOfTheGeneratedWindowedAggregate() {
return args;
};

implementAdd(aggs, builder7, resultContextBuilder, rexArguments, jDecl);
implementAdd(aggs, builder7, resultContextBuilder, rexArguments, jDecl, inputRowType);
BlockStatement forBlock = builder7.toBlock();

// Don't run the aggregate function if current row is excluded
Expand Down Expand Up @@ -866,7 +867,8 @@ private static void implementAdd(List<AggImpState> aggs,
final BlockBuilder builder7,
final Function<BlockBuilder, WinAggFrameResultContext> frame,
final Function<AggImpState, List<RexNode>> rexArguments,
final DeclarationStatement jDecl) {
final DeclarationStatement jDecl,
final RelDataType inputRowType) {
for (final AggImpState agg : aggs) {
final WinAggAddContext addContext =
new WinAggAddContextImpl(builder7, requireNonNull(agg.state, "agg.state"), frame) {
Expand All @@ -879,7 +881,9 @@ private static void implementAdd(List<AggImpState> aggs,
}

@Override public @Nullable RexNode rexFilterArgument() {
return null; // REVIEW
return agg.call.filterArg < 0
? null
: RexInputRef.of(agg.call.filterArg, inputRowType);
}
};
agg.implementor.implementAdd(requireNonNull(agg.context, "agg.context"), addContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.EVERY;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.EXP;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.EXTRACT;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.FILTER;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.FIRST_VALUE;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.FLOOR;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.FUSION;
Expand Down Expand Up @@ -1250,6 +1251,7 @@ void populate2() {
NotJsonImplementor.of(
new MethodImplementor(BuiltInMethod.IS_JSON_SCALAR.method,
NullPolicy.NONE, false)));
define(FILTER, new FilterImplementor());
}

/** Third step of population. */
Expand Down Expand Up @@ -5111,4 +5113,18 @@ private static class ReplaceImplementor extends AbstractRexCallImplementor {
operand0, operand1, operand2, Expressions.constant(isCaseSensitive));
}
}

/** Implementor for the FILTER operator. */
private static class FilterImplementor extends AbstractRexCallImplementor {
FilterImplementor() {
super("filter", NullPolicy.NONE, false);
}

@Override Expression implementSafe(RexToLixTranslator translator, RexCall call,
List<Expression> argValueList) {
final Expression value = argValueList.get(0);
final Expression condition = argValueList.get(1);
return Expressions.condition(condition, value, NULL_EXPR);
}
}
}
10 changes: 9 additions & 1 deletion core/src/main/java/org/apache/calcite/sql/SqlOverOperator.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public SqlOverOperator() {
switch (aggCall.getKind()) {
case RESPECT_NULLS:
case IGNORE_NULLS:
case FILTER:
validator.validateCall(aggCall, scope);
aggCall = aggCall.operand(0);
break;
Expand Down Expand Up @@ -102,7 +103,14 @@ public SqlOverOperator() {
SqlNode window = call.operand(1);
SqlWindow w = validator.resolveWindow(window, scope);

final SqlCall aggCall = (SqlCall) agg;
SqlCall aggCall = (SqlCall) agg;
// Unwrap FILTER, RESPECT_NULLS, or IGNORE_NULLS to get the actual aggregate call
while (aggCall != null
&& (aggCall.getKind() == SqlKind.FILTER
|| aggCall.getKind() == SqlKind.RESPECT_NULLS
|| aggCall.getKind() == SqlKind.IGNORE_NULLS)) {
aggCall = aggCall.operand(0);
}

SqlCallBinding opBinding = new SqlCallBinding(validator, scope, aggCall) {
@Override public boolean hasEmptyGroup() {
Expand Down
26 changes: 23 additions & 3 deletions core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3513,11 +3513,31 @@ void testWinPartClause() {
* Validator rejects FILTER in OVER windows</a>. */
@Test void testOverFilter() {
winSql("SELECT deptno,\n"
+ " ^COUNT(DISTINCT deptno) FILTER (WHERE deptno > 10)^\n"
+ " COUNT(DISTINCT deptno) FILTER (WHERE deptno > 10)\n"
+ "OVER win AS agg\n"
+ "FROM emp\n"
+ "WINDOW win AS (PARTITION BY empno)")
.fails("OVER must be applied to aggregate function");
+ "WINDOW win AS (PARTITION BY empno)")
.ok();
}

/** Test case for <a href="https://issues.apache.org/jira/browse/CALCITE-7595">[CALCITE-7595]
* Support FILTER clause with window functions</a>. */
@Test void testFilterWithOver() {
winSql("SELECT SUM(sal) FILTER (WHERE sal > 100) OVER (PARTITION BY deptno) FROM emp")
.ok();
}

@Test void testFilterWithOverAndDistinct() {
winSql("SELECT SUM(DISTINCT sal) FILTER (WHERE sal > 100) OVER (ORDER BY deptno) FROM emp")
.ok();
}

@Test void testMultipleFiltersWithOver() {
winSql("SELECT "
+ "COUNT(*) FILTER (WHERE empno > 100) OVER (PARTITION BY deptno), "
+ "SUM(sal) FILTER (WHERE sal > 0) OVER (PARTITION BY deptno) "
+ "FROM emp")
.ok();
}

@Test void testOverInOrderBy() {
Expand Down
139 changes: 139 additions & 0 deletions core/src/test/resources/sql/winagg.iq
Original file line number Diff line number Diff line change
Expand Up @@ -1173,4 +1173,143 @@ order by 1;
(14 rows)

!ok

# [CALCITE-6442] Support FILTER clause with window functions

# Test 1: FILTER with OVER on COUNT
select empno, deptno,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are these validated?

count(*) filter (where sal > 1500) over (partition by deptno) as filtered_count
from emp
order by empno;
+-------+--------+----------------+
| EMPNO | DEPTNO | FILTERED_COUNT |
+-------+--------+----------------+
| 7369 | 20 | 0 |
| 7566 | 20 | 5 |
| 7788 | 20 | 5 |
| 7876 | 20 | 0 |
| 7902 | 20 | 5 |
| 7782 | 10 | 3 |
| 7839 | 10 | 3 |
| 7934 | 10 | 0 |
| 7499 | 30 | 6 |
| 7521 | 30 | 0 |
| 7654 | 30 | 0 |
| 7698 | 30 | 6 |
| 7844 | 30 | 0 |
| 7900 | 30 | 0 |
+-------+--------+----------------+
(14 rows)

!ok

# Test 2: FILTER with OVER on SUM
select empno, deptno,
sum(sal) filter (where comm is not null) over (partition by deptno) as filtered_sum
from emp
order by empno;
+-------+--------+--------------+
| EMPNO | DEPTNO | FILTERED_SUM |
+-------+--------+--------------+
| 7369 | 20 | |
| 7566 | 20 | |
| 7788 | 20 | |
| 7876 | 20 | |
| 7902 | 20 | |
| 7782 | 10 | |
| 7839 | 10 | |
| 7934 | 10 | |
| 7499 | 30 | 9400.00 |
| 7521 | 30 | 9400.00 |
| 7654 | 30 | 9400.00 |
| 7698 | 30 | |
| 7844 | 30 | 9400.00 |
| 7900 | 30 | |
+-------+--------+--------------+
(14 rows)

!ok

# Test 3: FILTER with OVER and DISTINCT
select empno, deptno,
count(distinct sal) filter (where sal > 1000) over (partition by deptno) as filtered_count_distinct
from emp
order by empno;
+-------+--------+-------------------------+
| EMPNO | DEPTNO | FILTERED_COUNT_DISTINCT |
+-------+--------+-------------------------+
| 7369 | 20 | 0 |
| 7566 | 20 | 5 |
| 7788 | 20 | 5 |
| 7876 | 20 | 5 |
| 7902 | 20 | 5 |
| 7782 | 10 | 3 |
| 7839 | 10 | 3 |
| 7934 | 10 | 3 |
| 7499 | 30 | 6 |
| 7521 | 30 | 6 |
| 7654 | 30 | 6 |
| 7698 | 30 | 6 |
| 7844 | 30 | 6 |
| 7900 | 30 | 0 |
+-------+--------+-------------------------+
(14 rows)

!ok

# Test 4: Multiple FILTER with OVER on different aggregates
select empno, deptno,
count(*) filter (where sal > 1500) over (partition by deptno) as high_sal_count,
sum(sal) filter (where sal <= 1500) over (partition by deptno) as low_sal_sum
from emp
order by empno;
+-------+--------+----------------+-------------+
| EMPNO | DEPTNO | HIGH_SAL_COUNT | LOW_SAL_SUM |
+-------+--------+----------------+-------------+
| 7369 | 20 | 0 | 10875.00 |
| 7566 | 20 | 5 | |
| 7788 | 20 | 5 | |
| 7876 | 20 | 0 | 10875.00 |
| 7902 | 20 | 5 | |
| 7782 | 10 | 3 | |
| 7839 | 10 | 3 | |
| 7934 | 10 | 0 | 8750.00 |
| 7499 | 30 | 6 | |
| 7521 | 30 | 0 | 9400.00 |
| 7654 | 30 | 0 | 9400.00 |
| 7698 | 30 | 6 | |
| 7844 | 30 | 0 | 9400.00 |
| 7900 | 30 | 0 | 9400.00 |
+-------+--------+----------------+-------------+
(14 rows)

!ok

# Test 5: FILTER with OVER and ORDER BY (running window)
select empno, deptno, sal,
sum(sal) filter (where sal > 1000) over (partition by deptno order by empno rows between unbounded preceding and current row) as running_sum
from emp
order by empno;
+-------+--------+---------+-------------+
| EMPNO | DEPTNO | SAL | RUNNING_SUM |
+-------+--------+---------+-------------+
| 7369 | 20 | 800.00 | |
| 7566 | 20 | 2975.00 | 3775.00 |
| 7788 | 20 | 3000.00 | 6775.00 |
| 7876 | 20 | 1100.00 | 7875.00 |
| 7902 | 20 | 3000.00 | 10875.00 |
| 7782 | 10 | 2450.00 | 2450.00 |
| 7839 | 10 | 5000.00 | 7450.00 |
| 7934 | 10 | 1300.00 | 8750.00 |
| 7499 | 30 | 1600.00 | 1600.00 |
| 7521 | 30 | 1250.00 | 2850.00 |
| 7654 | 30 | 1250.00 | 4100.00 |
| 7698 | 30 | 2850.00 | 6950.00 |
| 7844 | 30 | 1500.00 | 8450.00 |
| 7900 | 30 | 950.00 | |
+-------+--------+---------+-------------+
(14 rows)

!ok

# End winagg.iq
Loading