Skip to content

Commit 8c64e65

Browse files
committed
[FLINK-35854][table] Add LiteralAggFunction
1 parent 71ce822 commit 8c64e65

File tree

2 files changed

+191
-0
lines changed

2 files changed

+191
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.functions.aggfunctions;
20+
21+
import org.apache.flink.table.api.DataTypes;
22+
import org.apache.flink.table.expressions.Expression;
23+
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
24+
import org.apache.flink.table.functions.DeclarativeAggregateFunction;
25+
import org.apache.flink.table.types.DataType;
26+
27+
import org.apache.calcite.rex.RexLiteral;
28+
29+
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
30+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
31+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
32+
33+
/**
34+
* Built-in literal aggregate function. This function is used for internal optimizations. For more
35+
* details see <a href="https://issues.apache.org/jira/browse/CALCITE-4334">CALCITE-4334</a>.
36+
*/
37+
public abstract class LiteralAggFunction extends DeclarativeAggregateFunction {
38+
39+
private final UnresolvedReferenceExpression literalAgg = unresolvedRef("literalAgg");
40+
private final RexLiteral rexLiteral;
41+
42+
public LiteralAggFunction(RexLiteral rexLiteral) {
43+
this.rexLiteral = rexLiteral;
44+
}
45+
46+
@Override
47+
public int operandCount() {
48+
return 0;
49+
}
50+
51+
@Override
52+
public UnresolvedReferenceExpression[] aggBufferAttributes() {
53+
return new UnresolvedReferenceExpression[] {literalAgg};
54+
}
55+
56+
@Override
57+
public DataType[] getAggBufferTypes() {
58+
return new DataType[] {getResultType()};
59+
}
60+
61+
@Override
62+
public Expression[] initialValuesExpressions() {
63+
return new Expression[] {nullOf(getResultType())};
64+
}
65+
66+
@Override
67+
public Expression[] accumulateExpressions() {
68+
return new Expression[] {literal(rexLiteral.getValue(), getResultType())};
69+
}
70+
71+
@Override
72+
public Expression[] retractExpressions() {
73+
return new Expression[] {literal(rexLiteral.getValue(), getResultType())};
74+
}
75+
76+
@Override
77+
public Expression[] mergeExpressions() {
78+
return new Expression[] {literal(rexLiteral.getValue(), getResultType())};
79+
}
80+
81+
@Override
82+
public Expression getValueExpression() {
83+
return literal(rexLiteral.getValue(), getResultType());
84+
}
85+
86+
/** Built-in Boolean Literal aggregate function. */
87+
public static class BooleanLiteralAggFunction extends LiteralAggFunction {
88+
89+
public BooleanLiteralAggFunction(RexLiteral rexLiteral) {
90+
super(rexLiteral);
91+
}
92+
93+
@Override
94+
public DataType getResultType() {
95+
return DataTypes.BOOLEAN();
96+
}
97+
}
98+
99+
/** Built-in Byte Literal aggregate function. */
100+
public static class ByteLiteralAggFunction extends LiteralAggFunction {
101+
102+
public ByteLiteralAggFunction(RexLiteral rexLiteral) {
103+
super(rexLiteral);
104+
}
105+
106+
@Override
107+
public DataType getResultType() {
108+
return DataTypes.TINYINT();
109+
}
110+
}
111+
112+
/** Built-in Short Literal aggregate function. */
113+
public static class ShortLiteralAggFunction extends LiteralAggFunction {
114+
115+
public ShortLiteralAggFunction(RexLiteral rexLiteral) {
116+
super(rexLiteral);
117+
}
118+
119+
@Override
120+
public DataType getResultType() {
121+
return DataTypes.SMALLINT();
122+
}
123+
}
124+
125+
/** Built-in Long Literal aggregate function. */
126+
public static class LongLiteralAggFunction extends LiteralAggFunction {
127+
128+
public LongLiteralAggFunction(RexLiteral rexLiteral) {
129+
super(rexLiteral);
130+
}
131+
132+
@Override
133+
public DataType getResultType() {
134+
return DataTypes.BIGINT();
135+
}
136+
}
137+
138+
/** Built-in Float Literal aggregate function. */
139+
public static class FloatLiteralAggFunction extends LiteralAggFunction {
140+
141+
public FloatLiteralAggFunction(RexLiteral rexLiteral) {
142+
super(rexLiteral);
143+
}
144+
145+
@Override
146+
public DataType getResultType() {
147+
return DataTypes.FLOAT();
148+
}
149+
}
150+
151+
/** Built-in Double Literal aggregate function. */
152+
public static class DoubleLiteralAggFunction extends LiteralAggFunction {
153+
154+
public DoubleLiteralAggFunction(RexLiteral rexLiteral) {
155+
super(rexLiteral);
156+
}
157+
158+
@Override
159+
public DataType getResultType() {
160+
return DataTypes.DOUBLE();
161+
}
162+
}
163+
}

flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.plan.utils
2020
import org.apache.flink.table.api.TableException
2121
import org.apache.flink.table.functions.{BuiltInFunctionDefinitions, DeclarativeAggregateFunction, UserDefinedFunction}
2222
import org.apache.flink.table.planner.functions.aggfunctions._
23+
import org.apache.flink.table.planner.functions.aggfunctions.LiteralAggFunction.{BooleanLiteralAggFunction, ByteLiteralAggFunction, DoubleLiteralAggFunction, FloatLiteralAggFunction, ShortLiteralAggFunction}
2324
import org.apache.flink.table.planner.functions.aggfunctions.SingleValueAggFunction._
2425
import org.apache.flink.table.planner.functions.aggfunctions.SumWithRetractAggFunction._
2526
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction
@@ -31,7 +32,10 @@ import org.apache.flink.table.runtime.functions.aggregate.PercentileAggFunction.
3132
import org.apache.flink.table.types.logical._
3233
import org.apache.flink.table.types.logical.LogicalTypeRoot._
3334

35+
import org.apache.calcite.rel.`type`.RelDataType
3436
import org.apache.calcite.rel.core.AggregateCall
37+
import org.apache.calcite.rex.{RexLiteral, RexNode}
38+
import org.apache.calcite.sql.`type`.SqlTypeName
3539
import org.apache.calcite.sql.{SqlAggFunction, SqlJsonConstructorNullClause, SqlKind, SqlRankFunction}
3640
import org.apache.calcite.sql.fun._
3741

@@ -158,6 +162,9 @@ class AggFunctionFactory(
158162
val onNull = fn.asInstanceOf[SqlJsonArrayAggAggFunction].getNullClause
159163
new JsonArrayAggFunction(argTypes, onNull == SqlJsonConstructorNullClause.ABSENT_ON_NULL)
160164

165+
case a: SqlAggFunction if a.getKind == SqlKind.LITERAL_AGG =>
166+
createLiteralAggFunction(call.getType, call.rexList.get(0))
167+
161168
case udagg: AggSqlFunction =>
162169
// Can not touch the literals, Calcite make them in previous RelNode.
163170
// In here, all inputs are input refs.
@@ -278,6 +285,27 @@ class AggFunctionFactory(
278285
}
279286
}
280287

288+
private def createLiteralAggFunction(
289+
relDataType: RelDataType,
290+
rexNode: RexNode): UserDefinedFunction = {
291+
relDataType.getSqlTypeName match {
292+
case SqlTypeName.BOOLEAN =>
293+
new BooleanLiteralAggFunction(rexNode.asInstanceOf[RexLiteral])
294+
case SqlTypeName.TINYINT =>
295+
new ByteLiteralAggFunction(rexNode.asInstanceOf[RexLiteral])
296+
case SqlTypeName.SMALLINT =>
297+
new ShortLiteralAggFunction(rexNode.asInstanceOf[RexLiteral])
298+
case SqlTypeName.FLOAT =>
299+
new FloatLiteralAggFunction(rexNode.asInstanceOf[RexLiteral])
300+
case SqlTypeName.DOUBLE =>
301+
new DoubleLiteralAggFunction(rexNode.asInstanceOf[RexLiteral])
302+
case t =>
303+
throw new TableException(
304+
s"Literal aggregate function does not support type: ''$t''.\n" +
305+
s"Please re-check the data type.")
306+
}
307+
}
308+
281309
private def createMinAggFunction(
282310
argTypes: Array[LogicalType],
283311
index: Int): UserDefinedFunction = {

0 commit comments

Comments
 (0)