diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java index ac40cc9101a4..91cc550b3f4a 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java @@ -968,7 +968,8 @@ public interface Config extends RelRule.Config { .addAll(SqlKind.AVG_AGG_FUNCTIONS) .addAll(SqlKind.COVAR_AVG_AGG_FUNCTIONS) .add(SqlKind.SUM) - .build(); + .build().stream().filter(k -> k != SqlKind.REGR_COUNT) + .collect(ImmutableSet.toImmutableSet()); @Override default AggregateReduceFunctionsRule toRule() { return new AggregateReduceFunctionsRule(this); diff --git a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java index 95030b5b9053..25e76ea740dd 100644 --- a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java +++ b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java @@ -39,7 +39,6 @@ import org.apache.calcite.sql.SqlTimeLiteral; import org.apache.calcite.sql.SqlTimestampLiteral; import org.apache.calcite.sql.SqlUtil; -import org.apache.calcite.sql.fun.SqlCountAggFunction; import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; @@ -388,7 +387,7 @@ public RexNode addAggCall(AggregateCall aggCall, int groupCount, List aggCalls, Map aggCallMapping, IntPredicate isNullable) { - if (aggCall.getAggregation() instanceof SqlCountAggFunction + if (aggCall.getAggregation().getKind() == SqlKind.COUNT && !aggCall.isDistinct()) { final List args = aggCall.getArgList(); final List nullableArgs = nullableArgs(args, isNullable); diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java index 3007ad346593..fe5c41f3a4d0 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java @@ -55,9 +55,14 @@ public SqlCountAggFunction(String name) { public SqlCountAggFunction(String name, SqlOperandTypeChecker sqlOperandTypeChecker) { - super(name, null, SqlKind.COUNT, ReturnTypes.BIGINT, null, - sqlOperandTypeChecker, SqlFunctionCategory.NUMERIC, false, false, - Optionality.FORBIDDEN); + this(name, sqlOperandTypeChecker, SqlKind.COUNT); + } + + public SqlCountAggFunction(String name, + SqlOperandTypeChecker sqlOperandTypeChecker, + SqlKind sqlKind) { + super(name, null, sqlKind, ReturnTypes.BIGINT, null, sqlOperandTypeChecker, + SqlFunctionCategory.NUMERIC, false, false, Optionality.FORBIDDEN); } //~ Methods ---------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java index 28023ed83d29..976e08596d07 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java @@ -29,7 +29,7 @@ */ public class SqlRegrCountAggFunction extends SqlCountAggFunction { public SqlRegrCountAggFunction(SqlKind kind) { - super("REGR_COUNT", OperandTypes.NUMERIC_NUMERIC); + super("REGR_COUNT", OperandTypes.NUMERIC_NUMERIC, kind); checkArgument(SqlKind.REGR_COUNT == kind, "unsupported sql kind: " + kind); } }