
Commit ac17a7c

Merge pull request #3 from cloud-fan/help
create partial shuffle reader
2 parents cee1c8c + 4abad37 commit ac17a7c

7 files changed: +93 -63 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ case class AdaptiveSparkPlanExec(
     // Here the 'OptimizeSkewedPartitions' rule should be executed
     // before 'ReduceNumShufflePartitions', as the skewed partition handled
     // in 'OptimizeSkewedPartitions' rule, should be omitted in 'ReduceNumShufflePartitions'.
-    OptimizeSkewedPartitions(conf),
+    OptimizeSkewedJoin(conf),
     ReduceNumShufflePartitions(conf),
     // The rule of 'OptimizeLocalShuffleReader' need to make use of the 'partitionStartIndices'
     // in 'ReduceNumShufflePartitions' rule. So it must be after 'ReduceNumShufflePartitions' rule.
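
The context lines above capture the intent: OptimizeSkewedJoin must run before ReduceNumShufflePartitions so that the partitions it splits are excluded from coalescing. As a minimal standalone sketch (plain REPL-style Scala, not Spark's actual rule executor; the Plan alias and the string "plan" are made up), sequential rule application is just a fold, which is why the ordering in the list matters:

type Plan = String
val queryStageOptimizerRules: Seq[Plan => Plan] = Seq(
  p => p + " -> OptimizeSkewedJoin",           // split/exclude skewed partitions first
  p => p + " -> ReduceNumShufflePartitions",   // then coalesce only the remaining partitions
  p => p + " -> OptimizeLocalShuffleReader")
val optimized = queryStageOptimizerRules.foldLeft("shuffle query stage": Plan)((p, rule) => rule(p))
// optimized == "shuffle query stage -> OptimizeSkewedJoin -> ReduceNumShufflePartitions -> OptimizeLocalShuffleReader"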

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala

Lines changed: 2 additions & 2 deletions
@@ -165,9 +165,9 @@ case class LocalShuffleReaderExec(
     // before shuffle.
     if (partitionStartIndicesPerMapper.forall(_.length == 1)) {
       child match {
-        case ShuffleQueryStageExec(_, s: ShuffleExchangeExec, _) =>
+        case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
           s.child.outputPartitioning
-        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec), _) =>
+        case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
           s.child.outputPartitioning match {
             case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
             case other => other

Lines changed: 54 additions & 21 deletions

@@ -28,10 +28,11 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 
-case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
+case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
 
   private val supportedJoinTypes =
     Inner :: Cross :: LeftSemi :: LeftAnti :: LeftOuter :: RightOuter :: Nil

@@ -115,8 +116,8 @@ case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
 
   def handleSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
     case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
-      SortExec(_, _, left: ShuffleQueryStageExec, _),
-      SortExec(_, _, right: ShuffleQueryStageExec, _))
+      s1 @ SortExec(_, _, left: ShuffleQueryStageExec, _),
+      s2 @ SortExec(_, _, right: ShuffleQueryStageExec, _))
         if supportedJoinTypes.contains(joinType) =>
       val leftStats = getStatistics(left)
       val rightStats = getStatistics(right)

@@ -166,26 +167,20 @@ case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
           }
           // TODO: we may can optimize the sort merge join to broad cast join after
           // obtaining the raw data size of per partition,
-          val leftSkewedReader = SkewedShufflePartitionReader(
+          val leftSkewedReader = SkewedPartitionReaderExec(
             left, partitionId, leftMapIdStartIndices(i), leftEndMapId)
-          val leftSort = smj.left.asInstanceOf[SortExec].copy(child = leftSkewedReader)
-
-          val rightSkewedReader = SkewedShufflePartitionReader(right, partitionId,
-            rightMapIdStartIndices(j), rightEndMapId)
-          val rightSort = smj.right.asInstanceOf[SortExec].copy(child = rightSkewedReader)
-          subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
-            leftSort, rightSort)
+          val rightSkewedReader = SkewedPartitionReaderExec(right, partitionId,
+            rightMapIdStartIndices(j), rightEndMapId)
+          subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
+            s1.copy(child = leftSkewedReader), s2.copy(child = rightSkewedReader))
         }
       }
     }
     logDebug(s"number of skewed partitions is ${skewedPartitions.size}")
     if (skewedPartitions.nonEmpty) {
       val optimizedSmj = smj.transformDown {
         case sort @ SortExec(_, _, shuffleStage: ShuffleQueryStageExec, _) =>
-          val newStage = shuffleStage.copy(
-            excludedPartitions = skewedPartitions.toSet)
-          newStage.resultOption = shuffleStage.resultOption
-          sort.copy(child = newStage)
+          sort.copy(child = PartialShuffleReaderExec(shuffleStage, skewedPartitions.toSet))
       }
       subJoins += optimizedSmj
       UnionExec(subJoins)
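
The mapper start indices used above (`leftMapIdStartIndices`, `rightMapIdStartIndices`) are computed elsewhere in the rule and are not part of this diff. A hypothetical standalone helper illustrating the idea, with an assumed name, signature, and target size, could group consecutive map outputs of a skewed partition so that each SkewedPartitionReaderExec reads at most targetSize bytes:

def splitMapIndices(mapPartitionSizes: Array[Long], targetSize: Long): Array[Int] = {
  val startIndices = scala.collection.mutable.ArrayBuffer(0)
  var currentSize = 0L
  mapPartitionSizes.zipWithIndex.foreach { case (size, i) =>
    if (i > 0 && currentSize + size > targetSize) {
      startIndices += i        // start a new mapper group at map index i
      currentSize = size
    } else {
      currentSize += size
    }
  }
  startIndices.toArray
}

// Example: map outputs of 40, 50, 30, 60, 20 bytes and a 100-byte target give
// splitMapIndices(Array(40L, 50L, 30L, 60L, 20L), 100L) == Array(0, 2, 4),
// i.e. mapper ranges [0, 2), [2, 4) and [4, 5) for one skewed partition.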
@@ -221,15 +216,15 @@ case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
 /**
  * A wrapper of shuffle query stage, which submits one reduce task to read a single
  * shuffle partition 'partitionIndex' produced by the mappers in range [startMapIndex, endMapIndex).
- * This is used to handle the skewed partitions.
+ * This is used to increase the parallelism when reading skewed partitions.
  *
  * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange
  *              node during canonicalization.
  * @param partitionIndex The pre shuffle partition index.
  * @param startMapIndex The start map index.
  * @param endMapIndex The end map index.
  */
-case class SkewedShufflePartitionReader(
+case class SkewedPartitionReaderExec(
     child: QueryStageExec,
     partitionIndex: Int,
     startMapIndex: Int,

@@ -242,10 +237,6 @@ case class SkewedShufflePartitionReader(
   }
   private var cachedSkewedShuffleRDD: SkewedShuffledRowRDD = null
 
-  override def nodeName: String = s"SkewedShuffleReader SkewedShuffleQueryStage: ${child}" +
-    s" SkewedPartition: ${partitionIndex} startMapIndex: ${startMapIndex}" +
-    s" endMapIndex: ${endMapIndex}"
-
   override def doExecute(): RDD[InternalRow] = {
     if (cachedSkewedShuffleRDD == null) {
       cachedSkewedShuffleRDD = child match {

@@ -258,3 +249,45 @@ case class SkewedShufflePartitionReader(
     cachedSkewedShuffleRDD
   }
 }
+
+/**
+ * A wrapper of shuffle query stage, which skips some partitions when reading the shuffle blocks.
+ *
+ * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange node during
+ *              canonicalization.
+ * @param excludedPartitions The partitions to skip when reading.
+ */
+case class PartialShuffleReaderExec(
+    child: QueryStageExec,
+    excludedPartitions: Set[Int]) extends UnaryExecNode {
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = {
+    UnknownPartitioning(1)
+  }
+
+  private def shuffleExchange(): ShuffleExchangeExec = child match {
+    case stage: ShuffleQueryStageExec => stage.shuffle
+    case _ =>
+      throw new IllegalStateException("operating on canonicalization plan")
+  }
+
+  private def getPartitionIndexRanges(): Array[(Int, Int)] = {
+    val length = shuffleExchange().shuffleDependency.partitioner.numPartitions
+    (0 until length).filterNot(excludedPartitions.contains).map(i => (i, i + 1)).toArray
+  }
+
+  private var cachedShuffleRDD: RDD[InternalRow] = null
+
+  override def doExecute(): RDD[InternalRow] = {
+    if (cachedShuffleRDD == null) {
+      cachedShuffleRDD = if (excludedPartitions.isEmpty) {
+        child.execute()
+      } else {
+        shuffleExchange().createShuffledRDD(Some(getPartitionIndexRanges()))
+      }
+    }
+    cachedShuffleRDD
+  }
+}
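
To make the new reader's getPartitionIndexRanges() concrete, here is the same expression evaluated standalone with made-up values (six shuffle partitions, skewed partitions 1 and 4 excluded): the partial reader requests one single-partition range per remaining index, leaving the excluded ones to the SkewedPartitionReaderExec sub-plans created above.

val numPartitions = 6
val excludedPartitions = Set(1, 4)
val ranges = (0 until numPartitions)
  .filterNot(excludedPartitions.contains)
  .map(i => (i, i + 1))
  .toArray
// ranges == Array((0,1), (2,3), (3,4), (5,6))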

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala

Lines changed: 1 addition & 23 deletions
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.Future
 
 import org.apache.spark.{FutureAction, MapOutputStatistics}

@@ -135,8 +134,7 @@ abstract class QueryStageExec extends LeafExecNode {
  */
 case class ShuffleQueryStageExec(
     override val id: Int,
-    override val plan: SparkPlan,
-    val excludedPartitions: Set[Int] = Set.empty) extends QueryStageExec {
+    override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val shuffle = plan match {
     case s: ShuffleExchangeExec => s

@@ -163,26 +161,6 @@ case class ShuffleQueryStageExec(
       case _ =>
     }
   }
-
-  private def getPartitionIndexRanges(): Array[(Int, Int)] = {
-    val length = shuffle.shuffleDependency.partitioner.numPartitions
-    (0 until length).filterNot(excludedPartitions.contains).map(i => (i, i + 1)).toArray
-  }
-
-  private var cachedShuffleRDD: RDD[InternalRow] = null
-
-  override def doExecute(): RDD[InternalRow] = {
-    if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = excludedPartitions match {
-        case e if e.isEmpty =>
-          plan.execute()
-        case _ =>
-          shuffle.createShuffledRDD(
-            Some(getPartitionIndexRanges()))
-      }
-    }
-    cachedShuffleRDD
-  }
 }
 
 /**

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala

Lines changed: 32 additions & 13 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashSet}
 
 import org.apache.spark.MapOutputStatistics
 import org.apache.spark.rdd.RDD

@@ -54,22 +54,28 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
     if (!conf.reducePostShufflePartitionsEnabled) {
       return plan
     }
-    // we need skip the leaf node of 'SkewedShufflePartitionReader'
-    val leafNodes = plan.collectLeaves().filter(!_.isInstanceOf[SkewedShufflePartitionReader])
+    // 'SkewedShufflePartitionReader' is added by us, so it's safe to ignore it when changing
+    // number of reducers.
+    val leafNodes = plan.collectLeaves().filter(!_.isInstanceOf[SkewedPartitionReaderExec])
    if (!leafNodes.forall(_.isInstanceOf[QueryStageExec])) {
      // If not all leaf nodes are query stages, it's not safe to reduce the number of
      // shuffle partitions, because we may break the assumption that all children of a spark plan
      // have same number of output partitions.
      return plan
    }
 
-    def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
+    def collectShuffles(plan: SparkPlan): Seq[SparkPlan] = plan match {
       case _: LocalShuffleReaderExec => Nil
+      case p: PartialShuffleReaderExec => Seq(p)
       case stage: ShuffleQueryStageExec => Seq(stage)
-      case _ => plan.children.flatMap(collectShuffleStages)
+      case _ => plan.children.flatMap(collectShuffles)
     }
 
-    val shuffleStages = collectShuffleStages(plan)
+    val shuffles = collectShuffles(plan)
+    val shuffleStages = shuffles.map {
+      case PartialShuffleReaderExec(s: ShuffleQueryStageExec, _) => s
+      case s: ShuffleQueryStageExec => s
+    }
     // ShuffleExchanges introduced by repartition do not support changing the number of partitions.
     // We change the number of partitions in the stage only if all the ShuffleExchanges support it.
     if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {

@@ -88,18 +94,31 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
     // partition) and a result of a SortMergeJoin (multiple partitions).
     val distinctNumPreShufflePartitions =
       validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
-    val distinctExcludedPartitions = shuffleStages.map(_.excludedPartitions).distinct
+    val distinctExcludedPartitions = shuffles.map {
+      case PartialShuffleReaderExec(_, excludedPartitions) => excludedPartitions
+      case _: ShuffleQueryStageExec => Set.empty[Int]
+    }.distinct
     if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1
       && distinctExcludedPartitions.length == 1) {
-      val excludedPartitions = shuffleStages.head.excludedPartitions
+      val excludedPartitions = distinctExcludedPartitions.head
       val partitionIndices = estimatePartitionStartAndEndIndices(
         validMetrics.toArray, excludedPartitions)
       // This transformation adds new nodes, so we must use `transformUp` here.
-      plan.transformUp {
-        // even for shuffle exchange whose input RDD has 0 partition, we should still update its
-        // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
-        // number of output partitions.
-        case stage: ShuffleQueryStageExec =>
+      // Even for shuffle exchange whose input RDD has 0 partition, we should still update its
+      // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
+      // number of output partitions.
+      val visitedStages = HashSet.empty[Int]
+      plan.transformDown {
+        // Replace `PartialShuffleReaderExec` with `CoalescedShuffleReaderExec`, which keeps the
+        // "excludedPartition" requirement and also merges some partitions.
+        case PartialShuffleReaderExec(stage: ShuffleQueryStageExec, _) =>
+          visitedStages.add(stage.id)
+          CoalescedShuffleReaderExec(stage, partitionIndices)
+
+        // We are doing `transformDown`, so the `ShuffleQueryStageExec` may already be optimized
+        // and wrapped by `CoalescedShuffleReaderExec`.
+        case stage: ShuffleQueryStageExec if !visitedStages.contains(stage.id) =>
+          visitedStages.add(stage.id)
          CoalescedShuffleReaderExec(stage, partitionIndices)
      }
    } else {
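
For reference, the interaction between coalescing and the excluded partitions can be sketched standalone (an assumed simplification, not the actual estimatePartitionStartAndEndIndices implementation): adjacent non-excluded partitions are merged until an advisory target size is reached, and an excluded (skewed) partition always breaks the current range, so the resulting CoalescedShuffleReaderExec never reads it.

def coalesceIndices(
    bytesByPartition: Array[Long],
    excluded: Set[Int],
    targetSize: Long): Array[(Int, Int)] = {
  val ranges = scala.collection.mutable.ArrayBuffer.empty[(Int, Int)]
  var start = -1                 // start of the range being built, -1 if none
  var size = 0L                  // bytes accumulated in the current range
  def close(end: Int): Unit =
    if (start >= 0) { ranges += ((start, end)); start = -1; size = 0L }
  bytesByPartition.indices.foreach { i =>
    if (excluded.contains(i)) {
      close(i)                   // a skewed partition ends the current range
    } else {
      if (start < 0) start = i
      else if (size + bytesByPartition(i) > targetSize) { close(i); start = i }
      size += bytesByPartition(i)
    }
  }
  close(bytesByPartition.length)
  ranges.toArray
}

// Example: sizes (10, 10, 100, 10) with partition 2 excluded and a 30-byte target:
// coalesceIndices(Array(10L, 10L, 100L, 10L), Set(2), 30L) == Array((0,2), (3,4))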

sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -533,7 +533,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
     val finalPlan = resultDf.queryExecution.executedPlan
       .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
     assert(finalPlan.collect {
-      case ShuffleQueryStageExec(_, r: ReusedExchangeExec, _) => r
+      case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r
     }.length == 2)
     assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p }.length == 3)

@@ -566,7 +566,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
 
     val reusedStages = level1Stages.flatMap { stage =>
       stage.plan.collect {
-        case ShuffleQueryStageExec(_, r: ReusedExchangeExec, _) => r
+        case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r
       }
     }
     assert(reusedStages.length == 1)

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ class AdaptiveQueryExecSuite
 
   private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
     collectInPlanAndSubqueries(plan) {
-      case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e
+      case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
       case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e
     }
   }
