From 62a6649044199c4188763791c71d13d8954cfa00 Mon Sep 17 00:00:00 2001 From: chloe-zh Date: Wed, 3 Jun 2026 23:24:21 -0700 Subject: [PATCH] Fix REGR_COUNT arguments being dropped during SqlToRel conversion REGR_COUNT(y, x) was incorrectly converted to REGR_COUNT(*) because SqlRegrCountAggFunction extends SqlCountAggFunction, which: 1. Hardcodes SqlKind.COUNT in its constructor, causing REGR_COUNT to report the wrong SqlKind. 2. Causes RexBuilder.addAggCall() to match REGR_COUNT via SqlKind.COUNT and strip non-nullable arguments (an optimization valid only for COUNT). 3. Masked a missing case handler in AggregateReduceFunctionsRule, since SqlKind.COUNT is not in functionsToReduce. Fix: - Add a protected constructor to SqlCountAggFunction that accepts SqlKind, allowing subclasses to specify their own kind. - Update SqlRegrCountAggFunction to pass SqlKind.REGR_COUNT. - Guard the nullable-args optimization in RexBuilder.addAggCall() to check SqlKind == COUNT instead of instanceof. - Add case REGR_COUNT in AggregateReduceFunctionsRule to preserve it as-is, since it is irreducible. --- .../rel/rules/AggregateReduceFunctionsRule.java | 3 ++- .../main/java/org/apache/calcite/rex/RexBuilder.java | 3 +-- .../apache/calcite/sql/fun/SqlCountAggFunction.java | 11 ++++++++--- .../calcite/sql/fun/SqlRegrCountAggFunction.java | 2 +- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java index ac40cc9101a4..91cc550b3f4a 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java @@ -968,7 +968,8 @@ public interface Config extends RelRule.Config { .addAll(SqlKind.AVG_AGG_FUNCTIONS) .addAll(SqlKind.COVAR_AVG_AGG_FUNCTIONS) .add(SqlKind.SUM) - .build(); + .build().stream().filter(k -> k != SqlKind.REGR_COUNT) + .collect(ImmutableSet.toImmutableSet()); @Override default AggregateReduceFunctionsRule toRule() { return new AggregateReduceFunctionsRule(this); diff --git a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java index 95030b5b9053..25e76ea740dd 100644 --- a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java +++ b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java @@ -39,7 +39,6 @@ import org.apache.calcite.sql.SqlTimeLiteral; import org.apache.calcite.sql.SqlTimestampLiteral; import org.apache.calcite.sql.SqlUtil; -import org.apache.calcite.sql.fun.SqlCountAggFunction; import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; @@ -388,7 +387,7 @@ public RexNode addAggCall(AggregateCall aggCall, int groupCount, List aggCalls, Map aggCallMapping, IntPredicate isNullable) { - if (aggCall.getAggregation() instanceof SqlCountAggFunction + if (aggCall.getAggregation().getKind() == SqlKind.COUNT && !aggCall.isDistinct()) { final List args = aggCall.getArgList(); final List nullableArgs = nullableArgs(args, isNullable); diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java index 3007ad346593..fe5c41f3a4d0 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java @@ -55,9 +55,14 @@ public SqlCountAggFunction(String name) { public SqlCountAggFunction(String name, SqlOperandTypeChecker sqlOperandTypeChecker) { - super(name, null, SqlKind.COUNT, ReturnTypes.BIGINT, null, - sqlOperandTypeChecker, SqlFunctionCategory.NUMERIC, false, false, - Optionality.FORBIDDEN); + this(name, sqlOperandTypeChecker, SqlKind.COUNT); + } + + public SqlCountAggFunction(String name, + SqlOperandTypeChecker sqlOperandTypeChecker, + SqlKind sqlKind) { + super(name, null, sqlKind, ReturnTypes.BIGINT, null, sqlOperandTypeChecker, + SqlFunctionCategory.NUMERIC, false, false, Optionality.FORBIDDEN); } //~ Methods ---------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java index 28023ed83d29..976e08596d07 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java @@ -29,7 +29,7 @@ */ public class SqlRegrCountAggFunction extends SqlCountAggFunction { public SqlRegrCountAggFunction(SqlKind kind) { - super("REGR_COUNT", OperandTypes.NUMERIC_NUMERIC); + super("REGR_COUNT", OperandTypes.NUMERIC_NUMERIC, kind); checkArgument(SqlKind.REGR_COUNT == kind, "unsupported sql kind: " + kind); } }