Skip to content

Commit e9a25d9

Browse files
committed
test fixes
1 parent 3736d19 commit e9a25d9

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/CollapseAggregates.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ object CollapseAggregates extends Rule[SparkPlan] {
4343
if checkIfAggregatesCanBeCollapsed(parent, child) =>
4444
val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete))
4545
HashAggregateExec(
46-
requiredChildDistributionExpressions = None,
46+
requiredChildDistributionExpressions = parent.requiredChildDistributionExpressions,
4747
groupingExpressions = child.groupingExpressions,
4848
aggregateExpressions = completeAggregateExpressions,
4949
aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute),
@@ -55,7 +55,7 @@ object CollapseAggregates extends Rule[SparkPlan] {
5555
if checkIfAggregatesCanBeCollapsed(parent, child) =>
5656
val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete))
5757
SortAggregateExec(
58-
requiredChildDistributionExpressions = None,
58+
requiredChildDistributionExpressions = parent.requiredChildDistributionExpressions,
5959
groupingExpressions = child.groupingExpressions,
6060
aggregateExpressions = completeAggregateExpressions,
6161
aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute),
@@ -67,7 +67,7 @@ object CollapseAggregates extends Rule[SparkPlan] {
6767
if checkIfAggregatesCanBeCollapsed(parent, child) =>
6868
val completeAggregateExpressions = child.aggregateExpressions.map(_.copy(mode = Complete))
6969
ObjectHashAggregateExec(
70-
requiredChildDistributionExpressions = None,
70+
requiredChildDistributionExpressions = parent.requiredChildDistributionExpressions,
7171
groupingExpressions = child.groupingExpressions,
7272
aggregateExpressions = completeAggregateExpressions,
7373
aggregateAttributes = completeAggregateExpressions.map(_.resultAttribute),

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ class HiveUDAFSuite extends QueryTest
9898

9999
test("SPARK-24935: customized Hive UDAF with two aggregation buffers") {
100100
withTempView("v") {
101-
spark.range(100).createTempView("v")
101+
// Setting numPartitions > 1 explicitly so that we get two Physical aggregation nodes
102+
spark.range(0, 100, 1, 10).createTempView("v")
102103
val df = sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2")
103104

104105
val aggs = collect(df.queryExecution.executedPlan) {

0 commit comments

Comments
 (0)