diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableWindow.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableWindow.java index 6197420b05ea..78ecce8821d7 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableWindow.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableWindow.java @@ -451,6 +451,7 @@ private static void sampleOfTheGeneratedWindowedAggregate() { hasRows, frameRowCount, partitionRowCount, jDecl, inputPhysTypeFinal); + final RelDataType inputRowType = inputPhysType.getRowType(); final Function> rexArguments = agg -> { List argList = agg.call.getArgList(); List inputTypes = @@ -464,7 +465,7 @@ private static void sampleOfTheGeneratedWindowedAggregate() { return args; }; - implementAdd(aggs, builder7, resultContextBuilder, rexArguments, jDecl); + implementAdd(aggs, builder7, resultContextBuilder, rexArguments, jDecl, inputRowType); BlockStatement forBlock = builder7.toBlock(); // Don't run the aggregate function if current row is excluded @@ -866,7 +867,8 @@ private static void implementAdd(List aggs, final BlockBuilder builder7, final Function frame, final Function> rexArguments, - final DeclarationStatement jDecl) { + final DeclarationStatement jDecl, + final RelDataType inputRowType) { for (final AggImpState agg : aggs) { final WinAggAddContext addContext = new WinAggAddContextImpl(builder7, requireNonNull(agg.state, "agg.state"), frame) { @@ -879,7 +881,9 @@ private static void implementAdd(List aggs, } @Override public @Nullable RexNode rexFilterArgument() { - return null; // REVIEW + return agg.call.filterArg < 0 + ? null + : RexInputRef.of(agg.call.filterArg, inputRowType); } }; agg.implementor.implementAdd(requireNonNull(agg.context, "agg.context"), addContext); diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java index 549bddbf724d..a71fcf36d0a1 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java @@ -417,6 +417,7 @@ import static org.apache.calcite.sql.fun.SqlStdOperatorTable.EVERY; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.EXP; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.EXTRACT; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.FILTER; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.FIRST_VALUE; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.FLOOR; import static org.apache.calcite.sql.fun.SqlStdOperatorTable.FUSION; @@ -1250,6 +1251,7 @@ void populate2() { NotJsonImplementor.of( new MethodImplementor(BuiltInMethod.IS_JSON_SCALAR.method, NullPolicy.NONE, false))); + define(FILTER, new FilterImplementor()); } /** Third step of population. */ @@ -5111,4 +5113,18 @@ private static class ReplaceImplementor extends AbstractRexCallImplementor { operand0, operand1, operand2, Expressions.constant(isCaseSensitive)); } } + + /** Implementor for the FILTER operator. */ + private static class FilterImplementor extends AbstractRexCallImplementor { + FilterImplementor() { + super("filter", NullPolicy.NONE, false); + } + + @Override Expression implementSafe(RexToLixTranslator translator, RexCall call, + List argValueList) { + final Expression value = argValueList.get(0); + final Expression condition = argValueList.get(1); + return Expressions.condition(condition, value, NULL_EXPR); + } + } } diff --git a/core/src/main/java/org/apache/calcite/sql/SqlOverOperator.java b/core/src/main/java/org/apache/calcite/sql/SqlOverOperator.java index cc42a9ad0c41..b1ed13808dac 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlOverOperator.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlOverOperator.java @@ -64,9 +64,14 @@ public SqlOverOperator() { assert call.getOperator() == this; assert call.operandCount() == 2; SqlCall aggCall = call.operand(0); + boolean hasFilter = false; switch (aggCall.getKind()) { case RESPECT_NULLS: case IGNORE_NULLS: + case FILTER: + if (aggCall.getKind() == SqlKind.FILTER) { + hasFilter = true; + } validator.validateCall(aggCall, scope); aggCall = aggCall.operand(0); break; @@ -76,6 +81,11 @@ public SqlOverOperator() { if (!aggCall.getOperator().isAggregator()) { throw validator.newValidationError(aggCall, RESOURCE.overNonAggregate()); } + // COUNT(DISTINCT) is not allowed in window functions with FILTER + if (hasFilter && aggCall.getKind() == SqlKind.COUNT + && aggCall.getFunctionQuantifier() != null) { + throw validator.newValidationError(aggCall, RESOURCE.overNonAggregate()); + } final SqlNode window = call.operand(1); validator.validateWindow(window, scope, aggCall); } @@ -102,7 +112,14 @@ public SqlOverOperator() { SqlNode window = call.operand(1); SqlWindow w = validator.resolveWindow(window, scope); - final SqlCall aggCall = (SqlCall) agg; + SqlCall aggCall = (SqlCall) agg; + // Unwrap FILTER, RESPECT_NULLS, or IGNORE_NULLS to get the actual aggregate call + while (aggCall != null + && (aggCall.getKind() == SqlKind.FILTER + || aggCall.getKind() == SqlKind.RESPECT_NULLS + || aggCall.getKind() == SqlKind.IGNORE_NULLS)) { + aggCall = aggCall.operand(0); + } SqlCallBinding opBinding = new SqlCallBinding(validator, scope, aggCall) { @Override public boolean hasEmptyGroup() { diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java index 5198762c106b..df4f0e70d80b 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java @@ -3513,13 +3513,33 @@ void testWinPartClause() { * Validator rejects FILTER in OVER windows. */ @Test void testOverFilter() { winSql("SELECT deptno,\n" - + " ^COUNT(DISTINCT deptno) FILTER (WHERE deptno > 10)^\n" + + " ^COUNT(DISTINCT deptno)^ FILTER (WHERE deptno > 10)\n" + "OVER win AS agg\n" + "FROM emp\n" - + "WINDOW win AS (PARTITION BY empno)") + + "WINDOW win AS (PARTITION BY empno)") .fails("OVER must be applied to aggregate function"); } + /** Test case for [CALCITE-7595] + * Support FILTER clause with window functions. */ + @Test void testFilterWithOver() { + winSql("SELECT SUM(sal) FILTER (WHERE sal > 100) OVER (PARTITION BY deptno) FROM emp") + .ok(); + } + + @Test void testFilterWithOverAndDistinct() { + winSql("SELECT SUM(DISTINCT sal) FILTER (WHERE sal > 100) OVER (ORDER BY deptno) FROM emp") + .ok(); + } + + @Test void testMultipleFiltersWithOver() { + winSql("SELECT " + + "COUNT(*) FILTER (WHERE empno > 100) OVER (PARTITION BY deptno), " + + "SUM(sal) FILTER (WHERE sal > 0) OVER (PARTITION BY deptno) " + + "FROM emp") + .ok(); + } + @Test void testOverInOrderBy() { winSql("select sum(deptno) over ^(order by sum(deptno)\n" + "over(order by deptno))^ from emp") diff --git a/core/src/test/resources/sql/winagg.iq b/core/src/test/resources/sql/winagg.iq index 6a32b3b3f7b0..a277ff7b993c 100644 --- a/core/src/test/resources/sql/winagg.iq +++ b/core/src/test/resources/sql/winagg.iq @@ -1173,4 +1173,116 @@ order by 1; (14 rows) !ok + +# [CALCITE-6442] Support FILTER clause with window functions + +# Test 1: FILTER with OVER on COUNT +select empno, deptno, + count(*) filter (where sal > 1500) over (partition by deptno) as filtered_count +from emp +order by empno; ++-------+--------+----------------+ +| EMPNO | DEPTNO | FILTERED_COUNT | ++-------+--------+----------------+ +| 7369 | 20 | 0 | +| 7566 | 20 | 5 | +| 7788 | 20 | 5 | +| 7876 | 20 | 0 | +| 7902 | 20 | 5 | +| 7782 | 10 | 3 | +| 7839 | 10 | 3 | +| 7934 | 10 | 0 | +| 7499 | 30 | 6 | +| 7521 | 30 | 0 | +| 7654 | 30 | 0 | +| 7698 | 30 | 6 | +| 7844 | 30 | 0 | +| 7900 | 30 | 0 | ++-------+--------+----------------+ +(14 rows) + +!ok + +# Test 2: FILTER with OVER on SUM +select empno, deptno, + sum(sal) filter (where comm is not null) over (partition by deptno) as filtered_sum +from emp +order by empno; ++-------+--------+--------------+ +| EMPNO | DEPTNO | FILTERED_SUM | ++-------+--------+--------------+ +| 7369 | 20 | | +| 7566 | 20 | | +| 7788 | 20 | | +| 7876 | 20 | | +| 7902 | 20 | | +| 7782 | 10 | | +| 7839 | 10 | | +| 7934 | 10 | | +| 7499 | 30 | 9400.00 | +| 7521 | 30 | 9400.00 | +| 7654 | 30 | 9400.00 | +| 7698 | 30 | | +| 7844 | 30 | 9400.00 | +| 7900 | 30 | | ++-------+--------+--------------+ +(14 rows) + +!ok + +# Test 3: Multiple FILTER with OVER on different aggregates +select empno, deptno, + count(*) filter (where sal > 1500) over (partition by deptno) as high_sal_count, + sum(sal) filter (where sal <= 1500) over (partition by deptno) as low_sal_sum +from emp +order by empno; ++-------+--------+----------------+-------------+ +| EMPNO | DEPTNO | HIGH_SAL_COUNT | LOW_SAL_SUM | ++-------+--------+----------------+-------------+ +| 7369 | 20 | 0 | 10875.00 | +| 7566 | 20 | 5 | | +| 7788 | 20 | 5 | | +| 7876 | 20 | 0 | 10875.00 | +| 7902 | 20 | 5 | | +| 7782 | 10 | 3 | | +| 7839 | 10 | 3 | | +| 7934 | 10 | 0 | 8750.00 | +| 7499 | 30 | 6 | | +| 7521 | 30 | 0 | 9400.00 | +| 7654 | 30 | 0 | 9400.00 | +| 7698 | 30 | 6 | | +| 7844 | 30 | 0 | 9400.00 | +| 7900 | 30 | 0 | 9400.00 | ++-------+--------+----------------+-------------+ +(14 rows) + +!ok + +# Test 4: FILTER with OVER and ORDER BY (running window) +select empno, deptno, sal, + sum(sal) filter (where sal > 1000) over (partition by deptno order by empno rows between unbounded preceding and current row) as running_sum +from emp +order by empno; ++-------+--------+---------+-------------+ +| EMPNO | DEPTNO | SAL | RUNNING_SUM | ++-------+--------+---------+-------------+ +| 7369 | 20 | 800.00 | | +| 7566 | 20 | 2975.00 | 3775.00 | +| 7788 | 20 | 3000.00 | 6775.00 | +| 7876 | 20 | 1100.00 | 7875.00 | +| 7902 | 20 | 3000.00 | 10875.00 | +| 7782 | 10 | 2450.00 | 2450.00 | +| 7839 | 10 | 5000.00 | 7450.00 | +| 7934 | 10 | 1300.00 | 8750.00 | +| 7499 | 30 | 1600.00 | 1600.00 | +| 7521 | 30 | 1250.00 | 2850.00 | +| 7654 | 30 | 1250.00 | 4100.00 | +| 7698 | 30 | 2850.00 | 6950.00 | +| 7844 | 30 | 1500.00 | 8450.00 | +| 7900 | 30 | 950.00 | | ++-------+--------+---------+-------------+ +(14 rows) + +!ok + # End winagg.iq diff --git a/site/_docs/reference.md b/site/_docs/reference.md index bdf4ada19998..bbcfa45b35af 100644 --- a/site/_docs/reference.md +++ b/site/_docs/reference.md @@ -2108,9 +2108,11 @@ Syntax: windowedAggregateCall: agg '(' [ ALL | DISTINCT ] value [, value ]* ')' [ RESPECT NULLS | IGNORE NULLS ] + [ FILTER '(' WHERE condition ')' ] [ WITHIN GROUP '(' ORDER BY orderItem [, orderItem ]* ')' ] OVER window | agg '(' '*' ')' + [ FILTER '(' WHERE condition ')' ] OVER window {% endhighlight %}