Skip to content

Commit 6b61c49

Browse files
author
shuo.cs
committed
fix test
1 parent 0573b65 commit 6b61c49

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/Sum0AggFunction.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.apache.flink.table.api.DataTypes;
2222
import org.apache.flink.table.expressions.Expression;
23+
import org.apache.flink.table.expressions.UnresolvedCallExpression;
2324
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
2425
import org.apache.flink.table.types.DataType;
2526
import org.apache.flink.table.types.logical.DecimalType;
@@ -28,11 +29,13 @@
2829
import java.math.BigDecimal;
2930

3031
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
32+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
3133
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
3234
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
3335
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
3436
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus;
3537
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
38+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
3639

3740
/** built-in sum0 aggregate function. */
3841
public abstract class Sum0AggFunction extends DeclarativeAggregateFunction {
@@ -56,20 +59,25 @@ public DataType[] getAggBufferTypes() {
5659
@Override
5760
public Expression[] accumulateExpressions() {
5861
return new Expression[] {
59-
/* sum0 = */ ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0)))
62+
/* sum0 = */ adjustSumType(ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0))))
6063
};
6164
}
6265

6366
@Override
6467
public Expression[] retractExpressions() {
6568
return new Expression[] {
66-
/* sum0 = */ ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0)))
69+
/* sum0 = */ adjustSumType(
70+
ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0))))
6771
};
6872
}
6973

7074
@Override
7175
public Expression[] mergeExpressions() {
72-
return new Expression[] {/* sum0 = */ plus(sum0, mergeOperand(sum0))};
76+
return new Expression[] {/* sum0 = */ adjustSumType(plus(sum0, mergeOperand(sum0)))};
77+
}
78+
79+
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
80+
return cast(sumExpr, typeLiteral(getResultType()));
7381
}
7482

7583
@Override

flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,21 @@
1818

1919
package org.apache.flink.table.planner.runtime.stream.sql
2020

21+
import org.apache.flink.api.java.typeutils.RowTypeInfo
2122
import org.apache.flink.api.scala._
2223
import org.apache.flink.table.api._
2324
import org.apache.flink.table.api.bridge.scala._
2425
import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
2526
import org.apache.flink.table.planner.runtime.utils.TimeTestUtil.EventTimeProcessOperator
2627
import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils.{CountNullNonNull, CountPairs, LargerThanCount}
2728
import org.apache.flink.table.planner.runtime.utils.{StreamingWithStateTestBase, TestData, TestingAppendSink}
29+
import org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo
2830
import org.apache.flink.types.Row
2931

3032
import org.junit.Assert._
3133
import org.junit._
3234
import org.junit.runner.RunWith
3335
import org.junit.runners.Parameterized
34-
3536
import scala.collection.mutable
3637

3738
@RunWith(classOf[Parameterized])
@@ -1131,4 +1132,33 @@ class OverAggregateITCase(mode: StateBackendMode) extends StreamingWithStateTest
11311132
"B,Hello World,10,7")
11321133
assertEquals(expected, sink.getAppendResults)
11331134
}
1135+
1136+
@Test
1137+
def testDecimalSum0(): Unit = {
1138+
val data = new mutable.MutableList[Row]
1139+
data.+=(Row.of(BigDecimal(1.11).bigDecimal))
1140+
data.+=(Row.of(BigDecimal(2.22).bigDecimal))
1141+
data.+=(Row.of(BigDecimal(3.33).bigDecimal))
1142+
data.+=(Row.of(BigDecimal(4.44).bigDecimal))
1143+
1144+
env.setParallelism(1)
1145+
val rowType = new RowTypeInfo(BigDecimalTypeInfo.of(38, 18))
1146+
val t = failingDataSource(data)(rowType).toTable(tEnv, 'd, 'proctime.proctime)
1147+
tEnv.registerTable("T", t)
1148+
1149+
val sqlQuery = "select sum(d) over (ORDER BY proctime rows between unbounded preceding " +
1150+
"and current row) from T"
1151+
1152+
val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
1153+
val sink = new TestingAppendSink
1154+
result.addSink(sink)
1155+
env.execute()
1156+
1157+
val expected = List(
1158+
"1.110000000000000000",
1159+
"3.330000000000000000",
1160+
"6.660000000000000000",
1161+
"11.100000000000000000")
1162+
assertEquals(expected, sink.getAppendResults)
1163+
}
11341164
}

0 commit comments

Comments
 (0)