@@ -28,10 +28,11 @@ import org.apache.spark.sql.catalyst.plans._
2828import org .apache .spark .sql .catalyst .plans .physical .{Partitioning , UnknownPartitioning }
2929import org .apache .spark .sql .catalyst .rules .Rule
3030import org .apache .spark .sql .execution ._
31+ import org .apache .spark .sql .execution .exchange .ShuffleExchangeExec
3132import org .apache .spark .sql .execution .joins .SortMergeJoinExec
3233import org .apache .spark .sql .internal .SQLConf
3334
34- case class OptimizeSkewedPartitions (conf : SQLConf ) extends Rule [SparkPlan ] {
35+ case class OptimizeSkewedJoin (conf : SQLConf ) extends Rule [SparkPlan ] {
3536
3637 private val supportedJoinTypes =
3738 Inner :: Cross :: LeftSemi :: LeftAnti :: LeftOuter :: RightOuter :: Nil
@@ -115,8 +116,8 @@ case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
115116
116117 def handleSkewJoin (plan : SparkPlan ): SparkPlan = plan.transformUp {
117118 case smj @ SortMergeJoinExec (leftKeys, rightKeys, joinType, condition,
118- SortExec (_, _, left : ShuffleQueryStageExec , _),
119- SortExec (_, _, right : ShuffleQueryStageExec , _))
119+ s1 @ SortExec (_, _, left : ShuffleQueryStageExec , _),
120+ s2 @ SortExec (_, _, right : ShuffleQueryStageExec , _))
120121 if supportedJoinTypes.contains(joinType) =>
121122 val leftStats = getStatistics(left)
122123 val rightStats = getStatistics(right)
@@ -166,26 +167,20 @@ case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
166167 }
167168 // TODO: we may be able to optimize the sort merge join to a broadcast join after
168169 // obtaining the raw data size of each partition.
169- val leftSkewedReader = SkewedShufflePartitionReader (
170+ val leftSkewedReader = SkewedPartitionReaderExec (
170171 left, partitionId, leftMapIdStartIndices(i), leftEndMapId)
171- val leftSort = smj.left.asInstanceOf [SortExec ].copy(child = leftSkewedReader)
172-
173- val rightSkewedReader = SkewedShufflePartitionReader (right, partitionId,
174- rightMapIdStartIndices(j), rightEndMapId)
175- val rightSort = smj.right.asInstanceOf [SortExec ].copy(child = rightSkewedReader)
176- subJoins += SortMergeJoinExec (leftKeys, rightKeys, joinType, condition,
177- leftSort, rightSort)
172+ val rightSkewedReader = SkewedPartitionReaderExec (right, partitionId,
173+ rightMapIdStartIndices(j), rightEndMapId)
174+ subJoins += SortMergeJoinExec (leftKeys, rightKeys, joinType, condition,
175+ s1.copy(child = leftSkewedReader), s2.copy(child = rightSkewedReader))
178176 }
179177 }
180178 }
181179 logDebug(s " number of skewed partitions is ${skewedPartitions.size}" )
182180 if (skewedPartitions.nonEmpty) {
183181 val optimizedSmj = smj.transformDown {
184182 case sort @ SortExec (_, _, shuffleStage : ShuffleQueryStageExec , _) =>
185- val newStage = shuffleStage.copy(
186- excludedPartitions = skewedPartitions.toSet)
187- newStage.resultOption = shuffleStage.resultOption
188- sort.copy(child = newStage)
183+ sort.copy(child = PartialShuffleReaderExec (shuffleStage, skewedPartitions.toSet))
189184 }
190185 subJoins += optimizedSmj
191186 UnionExec (subJoins)
@@ -221,15 +216,15 @@ case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
221216/**
222217 * A wrapper of shuffle query stage, which submits one reduce task to read a single
223218 * shuffle partition 'partitionIndex' produced by the mappers in range [startMapIndex, endMapIndex).
224- * This is used to handle the skewed partitions.
219+ * This is used to increase the parallelism when reading skewed partitions.
225220 *
226221 * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange
227222 * node during canonicalization.
228223 * @param partitionIndex The pre shuffle partition index.
229224 * @param startMapIndex The start map index.
230225 * @param endMapIndex The end map index.
231226 */
232- case class SkewedShufflePartitionReader (
227+ case class SkewedPartitionReaderExec (
233228 child : QueryStageExec ,
234229 partitionIndex : Int ,
235230 startMapIndex : Int ,
@@ -242,10 +237,6 @@ case class SkewedShufflePartitionReader(
242237 }
243238 private var cachedSkewedShuffleRDD : SkewedShuffledRowRDD = null
244239
245- override def nodeName : String = s " SkewedShuffleReader SkewedShuffleQueryStage: ${child}" +
246- s " SkewedPartition: ${partitionIndex} startMapIndex: ${startMapIndex}" +
247- s " endMapIndex: ${endMapIndex}"
248-
249240 override def doExecute (): RDD [InternalRow ] = {
250241 if (cachedSkewedShuffleRDD == null ) {
251242 cachedSkewedShuffleRDD = child match {
@@ -258,3 +249,45 @@ case class SkewedShufflePartitionReader(
258249 cachedSkewedShuffleRDD
259250 }
260251}
252+
/**
 * Wraps a shuffle query stage and reads its shuffle blocks while leaving out a given set of
 * partitions.
 *
 * @param child Normally a `ShuffleQueryStageExec`; during canonicalization it can be the
 *              shuffle exchange node instead.
 * @param excludedPartitions Indices of the partitions that must not be read.
 */
case class PartialShuffleReaderExec(
    child: QueryStageExec,
    excludedPartitions: Set[Int]) extends UnaryExecNode {

  override def output: Seq[Attribute] = child.output

  override def outputPartitioning: Partitioning = UnknownPartitioning(1)

  // Memoized result of doExecute(); remains null until the first execution.
  private var cachedShuffleRDD: RDD[InternalRow] = null

  override def doExecute(): RDD[InternalRow] = {
    Option(cachedShuffleRDD).getOrElse {
      val rdd: RDD[InternalRow] =
        if (excludedPartitions.nonEmpty) readRemainingPartitions() else child.execute()
      cachedShuffleRDD = rdd
      rdd
    }
  }

  // Resolves the shuffle exchange behind the wrapped stage; any other child (the
  // canonicalized form, per the class documentation) cannot be executed.
  private def underlyingShuffle: ShuffleExchangeExec = child match {
    case stage: ShuffleQueryStageExec =>
      stage.shuffle
    case _ =>
      throw new IllegalStateException("operating on canonicalization plan")
  }

  // Builds one single-partition range (i, i + 1) for every partition index that is not
  // excluded, and creates a shuffled RDD restricted to those ranges.
  private def readRemainingPartitions(): RDD[InternalRow] = {
    val shuffle = underlyingShuffle
    val numPartitions = shuffle.shuffleDependency.partitioner.numPartitions
    val keptRanges = (for {
      idx <- 0 until numPartitions
      if !excludedPartitions.contains(idx)
    } yield (idx, idx + 1)).toArray
    shuffle.createShuffledRDD(Some(keptRanges))
  }
}
0 commit comments