 
 package org.apache.spark.sql.execution.adaptive
 
-import java.util.concurrent.atomic.AtomicBoolean
-import java.util.concurrent.{LinkedBlockingDeque, BlockingQueue}
 import java.util.{HashMap => JHashMap, Map => JMap}
+import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
+import java.util.concurrent.atomic.AtomicBoolean
 
-import org.apache.spark.sql.catalyst.rules.Rule
-
-import scala.concurrent.ExecutionContext.Implicits.global
 import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext.Implicits.global
 
-import org.apache.spark.{MapOutputStatistics, SimpleFutureAction, ShuffleDependency}
+import org.apache.spark.{MapOutputStatistics, ShuffleDependency, SimpleFutureAction}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.{CollapseCodegenStages, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
@@ -150,7 +149,7 @@ trait QueryFragment extends SparkPlan {
     }
 
     case agg @ TungstenAggregate(_, _, _, _, _, _, input @ FragmentInput(_))
-      if (!input.isOptimized())=> {
+      if (!input.isOptimized()) => {
       logInfo("Begin optimize agg, operator =\n" + agg.toString)
       optimizeAggregate(agg, input)
     }
@@ -188,7 +187,7 @@ trait QueryFragment extends SparkPlan {
       Utils.estimatePartitionStartIndices(aggStatistics.toArray, minNumPostShufflePartitions,
         advisoryTargetPostShuffleInputSize)
     }
-    val shuffledRowRdd = childFragments(0).getExchange().preparePostShuffleRDD(
+    val shuffledRowRdd = childFragments(0).getExchange().preparePostShuffleRDD(
       shuffleDependencies(fragmentsIndex.get(childFragments(0))), partitionStartIndices)
     childFragments(0).getFragmentInput().setShuffleRdd(shuffledRowRdd)
     childFragments(0).getFragmentInput().setOptimized()
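Aside (illustrative, not part of the patch): the hunk above passes aggStatistics, minNumPostShufflePartitions and advisoryTargetPostShuffleInputSize to Utils.estimatePartitionStartIndices before preparing the post-shuffle RDD. That method's body is not shown here; the sketch below only illustrates the general size-based coalescing idea, assuming the per-reducer byte sizes have already been summed into one Array[Long], and the helper name is made up for the example.

// Sketch only: pack consecutive post-shuffle partitions until adding the next one
// would exceed the advisory target size, recording where each merged partition starts.
def sketchPartitionStartIndices(sizes: Array[Long], targetSize: Long): Array[Int] = {
  val starts = scala.collection.mutable.ArrayBuffer(0)
  var current = 0L
  for (i <- sizes.indices) {
    if (i > 0 && current + sizes(i) > targetSize) {
      starts += i       // begin a new post-shuffle partition at reducer index i
      current = 0L
    }
    current += sizes(i)
  }
  starts.toArray
}
// e.g. sketchPartitionStartIndices(Array(30L, 40L, 50L, 10L, 5L), targetSize = 64L)
// returns Array(0, 1, 2, 4): reducers {0}, {1}, {2, 3}, {4} become four post-shuffle partitions.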
@@ -223,7 +222,7 @@ trait QueryFragment extends SparkPlan {
 
     val leftFragment = childFragments(0)
     val rightFragment = childFragments(1)
-    val leftShuffledRowRdd = leftFragment.getExchange().preparePostShuffleRDD(
+    val leftShuffledRowRdd = leftFragment.getExchange().preparePostShuffleRDD(
       shuffleDependencies(fragmentsIndex.get(leftFragment)), partitionStartIndices)
     val rightShuffledRowRdd = rightFragment.getExchange().preparePostShuffleRDD(
       shuffleDependencies(fragmentsIndex.get(rightFragment)), partitionStartIndices)
@@ -238,7 +237,17 @@ trait QueryFragment extends SparkPlan {
     if (sqlContext.conf.autoBroadcastJoinThreshold > 0) {
       val leftSizeInBytes = childSizeInBytes(0)
       val rightSizeInBytes = childSizeInBytes(1)
-      if (leftSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) {
+      val joinType = joinPlan.joinType
+      def canBuildLeft(joinType: JoinType): Boolean = joinType match {
+        case Inner | RightOuter => true
+        case _ => false
+      }
+      def canBuildRight(joinType: JoinType): Boolean = joinType match {
+        case Inner | LeftOuter | LeftSemi | LeftAnti => true
+        case j: ExistenceJoin => true
+        case _ => false
+      }
+      if (leftSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold && canBuildLeft(joinType)) {
         val keys = Utils.rewriteKeyExpr(joinPlan.leftKeys).map(
           BindReferences.bindReference(_, left.child.output))
         newOperator = BroadcastHashJoinExec(
@@ -249,7 +258,8 @@ trait QueryFragment extends SparkPlan {
           joinPlan.condition,
           BroadcastExchangeExec(HashedRelationBroadcastMode(keys), left.child),
           right.child)
-      } else if (rightSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) {
+      } else if (rightSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold
+          && canBuildRight(joinType)) {
         val keys = Utils.rewriteKeyExpr(joinPlan.rightKeys).map(
           BindReferences.bindReference(_, right.child.output))
         newOperator = BroadcastHashJoinExec(
@@ -260,13 +270,13 @@ trait QueryFragment extends SparkPlan {
           joinPlan.condition,
           left.child,
           BroadcastExchangeExec(HashedRelationBroadcastMode(keys), right.child))
-      }
+      }
     }
     newOperator
   }
 
   /** Returns a string representation of the nodes in this tree */
-  override def treeString: String =
+  override def treeString: String =
     executedPlan.generateTreeString(0, Nil, new StringBuilder).toString
 
   override def simpleString: String = "QueryFragment"
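Aside (illustrative, not part of the patch): the gating added in the join hunks above only broadcasts a side when the join type permits building the hash relation on that side and that side's estimated size is under spark.sql.autoBroadcastJoinThreshold. Below is a self-contained sketch of the same decision, using plain byte counts in place of the fragment statistics and a made-up pickBroadcastSide helper.

import org.apache.spark.sql.catalyst.plans._

object BroadcastSideSketch {
  // Mirrors the predicates introduced in the patch.
  def canBuildLeft(joinType: JoinType): Boolean = joinType match {
    case Inner | RightOuter => true
    case _ => false
  }
  def canBuildRight(joinType: JoinType): Boolean = joinType match {
    case Inner | LeftOuter | LeftSemi | LeftAnti => true
    case _: ExistenceJoin => true
    case _ => false
  }

  // Hypothetical helper: Some("left") or Some("right") when a broadcast hash join is
  // eligible, None when both sides are too large or the join type forbids building there.
  def pickBroadcastSide(joinType: JoinType, leftBytes: Long, rightBytes: Long,
      threshold: Long): Option[String] = {
    if (leftBytes <= threshold && canBuildLeft(joinType)) Some("left")
    else if (rightBytes <= threshold && canBuildRight(joinType)) Some("right")
    else None
  }
}
// e.g. BroadcastSideSketch.pickBroadcastSide(LeftOuter, 1L << 20, 1L << 30, 10L << 20)
// returns None: the small side is the left one, but a LeftOuter join cannot build left.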