@@ -25,7 +25,7 @@ import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
2525import  org .apache .spark .sql .catalyst .plans ._ 
2626import  org .apache .spark .sql .catalyst .rules .Rule 
2727import  org .apache .spark .sql .execution ._ 
28- import  org .apache .spark .sql .execution .aggregate .HashAggregateExec 
28+ import  org .apache .spark .sql .execution .aggregate .{ HashAggregateExec ,  ObjectHashAggregateExec ,  SortAggregateExec } 
2929import  org .apache .spark .sql .execution .exchange .{EnsureRequirements , ShuffleExchangeExec }
3030import  org .apache .spark .sql .execution .joins .SortMergeJoinExec 
3131import  org .apache .spark .sql .internal .SQLConf 
@@ -133,13 +133,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
133133
134134  private  def  canSplitLeftSide (joinType : JoinType , plan : SparkPlan ) =  {
135135    (joinType ==  Inner  ||  joinType ==  Cross  ||  joinType ==  LeftSemi  || 
136-       joinType ==  LeftAnti  ||  joinType ==  LeftOuter ) && 
137-       plan.find(_.isInstanceOf [HashAggregateExec ]).isEmpty
136+       joinType ==  LeftAnti  ||  joinType ==  LeftOuter ) &&  ! containsAggregateExec(plan)
138137  }
139138
140139  private  def  canSplitRightSide (joinType : JoinType , plan : SparkPlan ) =  {
141-     (joinType ==  Inner  ||  joinType ==  Cross  ||  joinType ==  RightOuter ) && 
142-       plan.find(_.isInstanceOf [HashAggregateExec ]).isEmpty
140+     (joinType ==  Inner  ||  joinType ==  Cross  || 
141+       joinType ==  RightOuter ) &&  ! containsAggregateExec(plan)
142+   }
143+ 
144+   private  def  containsAggregateExec (plan : SparkPlan ) =  {
145+     plan.find {
146+       case  _ : HashAggregateExec  =>  true 
147+       case  _ : SortAggregateExec  =>  true 
148+       case  _ : ObjectHashAggregateExec  =>  true 
149+       case  _ =>  false 
150+     }.isDefined
143151  }
144152
145153  private  def  getSizeInfo (medianSize : Long , sizes : Seq [Long ]):  String  =  {
0 commit comments