Skip to content

Commit 80bef0d

Browse files
committed
add more agg exec
1 parent 607eb08 commit 80bef0d

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
2525
import org.apache.spark.sql.catalyst.plans._
2626
import org.apache.spark.sql.catalyst.rules.Rule
2727
import org.apache.spark.sql.execution._
28-
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
28+
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
2929
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
3030
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
3131
import org.apache.spark.sql.internal.SQLConf
@@ -133,13 +133,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
133133

134134
private def canSplitLeftSide(joinType: JoinType, plan: SparkPlan) = {
135135
(joinType == Inner || joinType == Cross || joinType == LeftSemi ||
136-
joinType == LeftAnti || joinType == LeftOuter) &&
137-
plan.find(_.isInstanceOf[HashAggregateExec]).isEmpty
136+
joinType == LeftAnti || joinType == LeftOuter) && !containsAggregateExec(plan)
138137
}
139138

140139
private def canSplitRightSide(joinType: JoinType, plan: SparkPlan) = {
141-
(joinType == Inner || joinType == Cross || joinType == RightOuter) &&
142-
plan.find(_.isInstanceOf[HashAggregateExec]).isEmpty
140+
(joinType == Inner || joinType == Cross ||
141+
joinType == RightOuter) && !containsAggregateExec(plan)
142+
}
143+
144+
private def containsAggregateExec(plan: SparkPlan) = {
145+
plan.find {
146+
case _: HashAggregateExec => true
147+
case _: SortAggregateExec => true
148+
case _: ObjectHashAggregateExec => true
149+
case _ => false
150+
}.isDefined
143151
}
144152

145153
private def getSizeInfo(medianSize: Long, sizes: Seq[Long]): String = {

0 commit comments

Comments
 (0)