Commit 5de46b0

Commit message: add ut
Parent: 6f1105c
File tree: 3 files changed, +31 −19 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 1 addition & 1 deletion
@@ -21,13 +21,13 @@ import java.nio.charset.StandardCharsets
 import java.sql.Timestamp
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.execution.adaptive.QueryFragmentTransformer
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.adaptive.QueryFragmentTransformer
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
 import org.apache.spark.sql.internal.SQLConf

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragment.scala

Lines changed: 23 additions & 13 deletions
@@ -17,19 +17,18 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import java.util.concurrent.atomic.AtomicBoolean
-import java.util.concurrent.{LinkedBlockingDeque, BlockingQueue}
 import java.util.{HashMap => JHashMap, Map => JMap}
+import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}
+import java.util.concurrent.atomic.AtomicBoolean
 
-import org.apache.spark.sql.catalyst.rules.Rule
-
-import scala.concurrent.ExecutionContext.Implicits.global
 import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext.Implicits.global
 
-import org.apache.spark.{MapOutputStatistics, SimpleFutureAction, ShuffleDependency}
+import org.apache.spark.{MapOutputStatistics, ShuffleDependency, SimpleFutureAction}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.{CollapseCodegenStages, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
@@ -150,7 +149,7 @@ trait QueryFragment extends SparkPlan {
       }
 
       case agg @ TungstenAggregate(_, _, _, _, _, _, input @ FragmentInput(_))
-        if (!input.isOptimized())=> {
+        if (!input.isOptimized()) => {
         logInfo("Begin optimize agg, operator =\n" + agg.toString)
         optimizeAggregate(agg, input)
       }
@@ -188,7 +187,7 @@
       Utils.estimatePartitionStartIndices(aggStatistics.toArray, minNumPostShufflePartitions,
         advisoryTargetPostShuffleInputSize)
     }
-    val shuffledRowRdd= childFragments(0).getExchange().preparePostShuffleRDD(
+    val shuffledRowRdd = childFragments(0).getExchange().preparePostShuffleRDD(
       shuffleDependencies(fragmentsIndex.get(childFragments(0))), partitionStartIndices)
     childFragments(0).getFragmentInput().setShuffleRdd(shuffledRowRdd)
    childFragments(0).getFragmentInput().setOptimized()
@@ -223,7 +222,7 @@
 
     val leftFragment = childFragments(0)
     val rightFragment = childFragments(1)
-    val leftShuffledRowRdd= leftFragment.getExchange().preparePostShuffleRDD(
+    val leftShuffledRowRdd = leftFragment.getExchange().preparePostShuffleRDD(
       shuffleDependencies(fragmentsIndex.get(leftFragment)), partitionStartIndices)
     val rightShuffledRowRdd = rightFragment.getExchange().preparePostShuffleRDD(
       shuffleDependencies(fragmentsIndex.get(rightFragment)), partitionStartIndices)
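Aside: the hunk above feeds the same partitionStartIndices into preparePostShuffleRDD for both join inputs, so the two sides stay co-partitioned after coalescing. For reference, here is a minimal self-contained sketch of the start-index estimation idea — not this repo's Utils.estimatePartitionStartIndices, whose real signature lives in utils.scala below; the names estimateStartIndices and bytesByPartition are illustrative only. A new post-shuffle partition opens whenever the running byte count would pass the target size:

import scala.collection.mutable.ArrayBuffer

object CoalesceSketch {
  // Toy model: bytesByPartition(i) is the total number of bytes all map
  // tasks wrote into reduce partition i (the kind of per-partition figure
  // MapOutputStatistics reports). Returns the first pre-shuffle partition
  // index of each coalesced post-shuffle partition.
  def estimateStartIndices(bytesByPartition: Array[Long], targetSize: Long): Array[Int] = {
    val starts = ArrayBuffer(0)
    var current = 0L
    for (i <- bytesByPartition.indices) {
      // Close the current post-shuffle partition once adding one more
      // pre-shuffle partition would exceed the target size.
      if (current > 0 && current + bytesByPartition(i) > targetSize) {
        starts += i
        current = 0L
      }
      current += bytesByPartition(i)
    }
    starts.toArray
  }

  def main(args: Array[String]): Unit = {
    // Five 40-byte partitions against a 100-byte target coalesce into
    // three post-shuffle partitions starting at indices 0, 2 and 4.
    println(estimateStartIndices(Array(40L, 40L, 40L, 40L, 40L), 100L).mkString(", "))
  }
}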
@@ -238,7 +237,17 @@
     if (sqlContext.conf.autoBroadcastJoinThreshold > 0) {
       val leftSizeInBytes = childSizeInBytes(0)
       val rightSizeInBytes = childSizeInBytes(1)
-      if (leftSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) {
+      val joinType = joinPlan.joinType
+      def canBuildLeft(joinType: JoinType): Boolean = joinType match {
+        case Inner | RightOuter => true
+        case _ => false
+      }
+      def canBuildRight(joinType: JoinType): Boolean = joinType match {
+        case Inner | LeftOuter | LeftSemi | LeftAnti => true
+        case j: ExistenceJoin => true
+        case _ => false
+      }
+      if (leftSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold && canBuildLeft(joinType)) {
         val keys = Utils.rewriteKeyExpr(joinPlan.leftKeys).map(
           BindReferences.bindReference(_, left.child.output))
         newOperator = BroadcastHashJoinExec(
@@ -249,7 +258,8 @@
           joinPlan.condition,
           BroadcastExchangeExec(HashedRelationBroadcastMode(keys), left.child),
           right.child)
-      } else if (rightSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) {
+      } else if (rightSizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold
+          && canBuildRight(joinType)) {
         val keys = Utils.rewriteKeyExpr(joinPlan.rightKeys).map(
           BindReferences.bindReference(_, right.child.output))
         newOperator = BroadcastHashJoinExec(
@@ -260,13 +270,13 @@
           joinPlan.condition,
           left.child,
           BroadcastExchangeExec(HashedRelationBroadcastMode(keys), right.child))
-       }
+      }
     }
     newOperator
   }
 
   /** Returns a string representation of the nodes in this tree */
-  override def treeString: String = 
+  override def treeString: String =
     executedPlan.generateTreeString(0, Nil, new StringBuilder).toString
 
   override def simpleString: String = "QueryFragment"
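The substantive change in this file is the pair of canBuildLeft/canBuildRight guards above: a side may only be collected into a broadcast hash table when the join type does not require streaming that side in full (a left outer join must emit every left row, so only the right side is a legal build side). Below is a self-contained sketch of that decision, with a toy JoinType ADT standing in for org.apache.spark.sql.catalyst.plans.JoinType; ExistenceJoin is omitted from the toy hierarchy, and chooseBroadcastSide is a hypothetical helper, not part of the patch:

// Toy stand-ins for the catalyst JoinType hierarchy; illustration only.
sealed trait JoinType
case object Inner extends JoinType
case object LeftOuter extends JoinType
case object RightOuter extends JoinType
case object LeftSemi extends JoinType
case object LeftAnti extends JoinType
case object FullOuter extends JoinType

object BroadcastSideSketch {
  // The preserved (outer) side of a join must be streamed in full, so only
  // the opposite side may be built into a broadcast hash table.
  def canBuildLeft(jt: JoinType): Boolean = jt match {
    case Inner | RightOuter => true
    case _ => false
  }

  def canBuildRight(jt: JoinType): Boolean = jt match {
    case Inner | LeftOuter | LeftSemi | LeftAnti => true
    case _ => false
  }

  // Hypothetical helper combining the size and join-type checks from the
  // hunk above; sizes and threshold are in bytes.
  def chooseBroadcastSide(jt: JoinType, leftBytes: Long, rightBytes: Long,
      threshold: Long): Option[String] = {
    if (leftBytes <= threshold && canBuildLeft(jt)) Some("left")
    else if (rightBytes <= threshold && canBuildRight(jt)) Some("right")
    else None
  }

  def main(args: Array[String]): Unit = {
    val threshold = 10L * 1024 * 1024
    // Even a tiny left side cannot be broadcast under LeftOuter, because
    // every left row must be preserved; the right side is chosen instead.
    println(chooseBroadcastSide(LeftOuter, 1024, 1024, threshold)) // Some(right)
    println(chooseBroadcastSide(FullOuter, 1024, 1024, threshold)) // None
  }
}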

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/utils.scala

Lines changed: 7 additions & 5 deletions
@@ -17,13 +17,13 @@
 
 package org.apache.spark.sql.execution.adaptive
 
+import scala.collection.mutable.{ArrayBuffer, Queue}
+
+import org.apache.spark.MapOutputStatistics
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{IntegerType, LongType, IntegralType}
-import org.apache.spark.MapOutputStatistics
 import org.apache.spark.sql.execution.SparkPlan
-
-import scala.collection.mutable.{Queue, ArrayBuffer}
+import org.apache.spark.sql.types.{IntegralType, LongType}
 
 /**
  * Utility functions used by the query fragment.
@@ -49,7 +49,9 @@ private[sql] object Utils extends Logging {
   private[sql] def findLeafFragment(root: QueryFragment): Seq[QueryFragment] = {
     val result = new ArrayBuffer[QueryFragment]
     val queue = new Queue[QueryFragment]
-    queue.enqueue(root)
+    if (!root.children.isEmpty) {
+      root.children.foreach(c => queue.enqueue(c))
+    }
     while (queue.nonEmpty) {
       val current = queue.dequeue()
       if (current.children.isEmpty) {
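The new guard seeds the queue with the root's children instead of the root itself, so a childless root is no longer reported as its own leaf. Here is a self-contained sketch of the same breadth-first traversal, with a toy Fragment class standing in for QueryFragment and the non-leaf branch (not shown in the hunk) assumed to enqueue children:

import scala.collection.mutable.{ArrayBuffer, Queue}

// Toy stand-in for QueryFragment; only the children relation matters here.
case class Fragment(name: String, children: Seq[Fragment] = Nil)

object FindLeavesSketch {
  // Breadth-first search that, like the patched findLeafFragment, starts
  // from the root's children, so the root can never appear in the result.
  def findLeaves(root: Fragment): Seq[Fragment] = {
    val result = ArrayBuffer.empty[Fragment]
    val queue = Queue.empty[Fragment]
    root.children.foreach(c => queue.enqueue(c))
    while (queue.nonEmpty) {
      val current = queue.dequeue()
      if (current.children.isEmpty) {
        result += current
      } else {
        current.children.foreach(c => queue.enqueue(c))
      }
    }
    result.toSeq
  }

  def main(args: Array[String]): Unit = {
    val tree = Fragment("root", Seq(
      Fragment("a", Seq(Fragment("a1"), Fragment("a2"))),
      Fragment("b")))
    println(findLeaves(tree).map(_.name).mkString(", ")) // b, a1, a2
  }
}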
