Commit bc241dd

Yucai Yu authored and JkSelf committed
In BHJ, shuffle read should be local always (apache#53)
* In BHJ, shuffle read should be local
* add comments
1 parent 89406d3 commit bc241dd
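
In short: marking shuffle reads as local used to happen inside optimizeForLocalShuffleReadLessPartitions, so it only applied on that path, and the revert on bail-out only fired for inner joins. After this commit the marking happens as soon as the SortMergeJoinExec is rewritten into a BroadcastHashJoinExec, and the revert covers all join types. A condensed sketch of the new flow (not runnable against stock Spark: ShuffleQueryStageInput and revertShuffleReadChanges belong to this adaptive-execution branch, and the guard name is hypothetical):

// Condensed sketch of the flow after this commit; surrounding rule logic
// is elided and `shouldBailOut` is a hypothetical stand-in for the
// additional-shuffle check in OptimizeJoin.
val broadcastJoin = BroadcastHashJoinExec(
  leftKeys, rightKeys, joinType, buildSide, condition,
  removeSort(left), removeSort(right))

// Always read shuffle output locally under a broadcast join,
// regardless of join type.
broadcastJoin.children.foreach {
  case input: ShuffleQueryStageInput => input.isLocalShuffle = true
  case _ =>
}

if (shouldBailOut) {
  // On bail-out, undo the marking for every join type
  // (previously this was restricted to InnerLike joins).
  revertShuffleReadChanges(broadcastJoin.children)
}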

2 files changed, +48 −43 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala

Lines changed: 7 additions & 11 deletions
@@ -81,12 +81,6 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
   private def optimizeForLocalShuffleReadLessPartitions(
       broadcastSidePlan: SparkPlan,
       childrenPlans: Seq[SparkPlan]) = {
-    // All shuffle read should be local instead of remote
-    childrenPlans.foreach {
-      case input: ShuffleQueryStageInput =>
-        input.isLocalShuffle = true
-      case _ =>
-    }
     // If there's shuffle write on broadcast side, then find the partitions with 0 size and ignore
     // reading them in local shuffle read.
     broadcastSidePlan match {
@@ -138,6 +132,12 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
         condition,
         removeSort(left),
         removeSort(right))
+      // All shuffle read should be local instead of remote
+      broadcastJoin.children.foreach {
+        case input: ShuffleQueryStageInput =>
+          input.isLocalShuffle = true
+        case _ =>
+      }
 
       val newChild = queryStage.child.transformDown {
         case s: SortMergeJoinExec if s.fastEquals(smj) => broadcastJoin
@@ -177,11 +177,7 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
       } else {
         logWarning("Join optimization is not applied due to additional shuffles will be " +
           "introduced. Enable spark.sql.adaptive.allowAdditionalShuffle to allow it.")
-        joinType match {
-          case _: InnerLike =>
-            revertShuffleReadChanges(broadcastJoin.children)
-          case _ =>
-        }
+        revertShuffleReadChanges(broadcastJoin.children)
         smj
       }
     }.getOrElse(smj)
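
revertShuffleReadChanges itself is not shown in this diff; only its call site is. A hedged sketch of what it presumably does, namely clear the flag that the new block above sets:

// Assumed body, not taken from the diff: only the name and call site
// appear above. Presumably the revert just clears the local-read flag.
private def revertShuffleReadChanges(childrenPlans: Seq[SparkPlan]): Unit = {
  childrenPlans.foreach {
    case input: ShuffleQueryStageInput =>
      input.isLocalShuffle = false
    case _ =>
  }
}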

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/QueryStageSuite.scala

Lines changed: 41 additions & 32 deletions
@@ -72,6 +72,42 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
     }
   }
 
+  def checkJoin(join: DataFrame, spark: SparkSession): Unit = {
+    // Before Execution, there is one SortMergeJoin
+    val smjBeforeExecution = join.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoinExec => smj
+    }
+    assert(smjBeforeExecution.length === 1)
+
+    // Check the answer.
+    val expectedAnswer =
+      spark
+        .range(0, 1000)
+        .selectExpr("id % 500 as key", "id as value")
+        .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+    checkAnswer(
+      join,
+      expectedAnswer.collect())
+
+    // During execution, the SortMergeJoin is changed to BroadcastHashJoinExec
+    val smjAfterExecution = join.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoinExec => smj
+    }
+    assert(smjAfterExecution.length === 0)
+
+    val numBhjAfterExecution = join.queryExecution.executedPlan.collect {
+      case smj: BroadcastHashJoinExec => smj
+    }.length
+    assert(numBhjAfterExecution === 1)
+
+    // Both shuffle should be local shuffle
+    val queryStageInputs = join.queryExecution.executedPlan.collect {
+      case q: ShuffleQueryStageInput => q
+    }
+    assert(queryStageInputs.length === 2)
+    assert(queryStageInputs.forall(_.isLocalShuffle) === true)
+  }
+
   test("1 sort merge join to broadcast join") {
     withSparkSession(defaultSparkSession) { spark: SparkSession =>
       val df1 =
75111
test("1 sort merge join to broadcast join") {
76112
withSparkSession(defaultSparkSession) { spark: SparkSession =>
77113
val df1 =
@@ -83,39 +119,12 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
83119
.range(0, 1000, 1, numInputPartitions)
84120
.selectExpr("id % 500 as key2", "id as value2")
85121

86-
val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
87-
88-
// Before Execution, there is one SortMergeJoin
89-
val smjBeforeExecution = join.queryExecution.executedPlan.collect {
90-
case smj: SortMergeJoinExec => smj
91-
}
92-
assert(smjBeforeExecution.length === 1)
93-
94-
// Check the answer.
95-
val expectedAnswer =
96-
spark
97-
.range(0, 1000)
98-
.selectExpr("id % 500 as key", "id as value")
99-
.union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
100-
checkAnswer(
101-
join,
102-
expectedAnswer.collect())
103-
104-
// During execution, the SortMergeJoin is changed to BroadcastHashJoinExec
105-
val smjAfterExecution = join.queryExecution.executedPlan.collect {
106-
case smj: SortMergeJoinExec => smj
107-
}
108-
assert(smjAfterExecution.length === 0)
122+
val innerJoin = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
123+
checkJoin(innerJoin, spark)
109124

110-
val numBhjAfterExecution = join.queryExecution.executedPlan.collect {
111-
case smj: BroadcastHashJoinExec => smj
112-
}.length
113-
assert(numBhjAfterExecution === 1)
114-
115-
val queryStageInputs = join.queryExecution.executedPlan.collect {
116-
case q: QueryStageInput => q
117-
}
118-
assert(queryStageInputs.length === 2)
125+
val leftJoin =
126+
df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value1"))
127+
checkJoin(leftJoin, spark)
119128
}
120129
}
121130
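
With the assertions factored into checkJoin, covering further join types is a one-liner per case. A hypothetical extension, not part of this commit, that would sit alongside the inner and left cases above (it assumes a right outer join also converts to a broadcast join with both sides read locally, which the commit does not itself assert):

// Hypothetical follow-up (not in this commit): reuse checkJoin for a right
// outer join, inside the same withSparkSession block where df1/df2 are
// defined. The conversion-and-local-read behavior is an assumption here.
val rightJoin =
  df1.join(df2, col("key1") === col("key2"), "right").select(col("key2"), col("value2"))
checkJoin(rightJoin, spark)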
