Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ object OptimizeLocalShuffleReader {
def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
case s: ShuffleQueryStageExec =>
s.shuffle.canChangeNumPartitions
// This CustomShuffleReaderExec used in skew side, its numPartitions increased.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It means the rule of OptimizeLocalShuffleReader is disabled when enable the rule of OptimizedSkwedJoin rule ?

Copy link
Contributor Author

@LantaoJin LantaoJin Jul 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly. In this more general skew join handling, we can match more patterns. For example, we can handle skew join like https://user-images.githubusercontent.com/1853780/87743215-01e9e780-c81b-11ea-97d9-f274b379912e.png. The number partitions of CustomShuffleReader in the the BCJ (changed from SMJ by AE) after OptimizeLocalShuffleReader is not equals to the anther side. So simply, I disable createLocalReader.

case CustomShuffleReaderExec(_, partitionSpecs)
if partitionSpecs.exists(_.isInstanceOf[PartialReducerPartitionSpec]) => false
// This CustomShuffleReaderExec used in non-skew side, its numPartitions equals to
// the skew side CustomShuffleReaderExec.
case CustomShuffleReaderExec(_, partitionSpecs) if partitionSpecs.size > 1 &&
partitionSpecs.forall(_.isInstanceOf[CoalescedPartitionSpec]) &&
partitionSpecs.toSet.size != partitionSpecs.size => false
case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs) =>
s.shuffle.canChangeNumPartitions && partitionSpecs.nonEmpty
case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils

import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
Expand Down Expand Up @@ -130,20 +131,45 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
}
}

private def canSplitLeftSide(joinType: JoinType) = {
joinType == Inner || joinType == Cross || joinType == LeftSemi ||
joinType == LeftAnti || joinType == LeftOuter
private def canSplitLeftSide(joinType: JoinType, plan: SparkPlan) = {
(joinType == Inner || joinType == Cross || joinType == LeftSemi ||
joinType == LeftAnti || joinType == LeftOuter) && allUnspecifiedDistribution(plan)
}

private def canSplitRightSide(joinType: JoinType) = {
joinType == Inner || joinType == Cross || joinType == RightOuter
private def canSplitRightSide(joinType: JoinType, plan: SparkPlan) = {
(joinType == Inner || joinType == Cross ||
joinType == RightOuter) && allUnspecifiedDistribution(plan)
}

// Check if there is a node in the tree that the requiredChildDistribution is specified,
// other than UnspecifiedDistribution.
private def allUnspecifiedDistribution(plan: SparkPlan): Boolean = plan.find { p =>
p.requiredChildDistribution.exists {
case UnspecifiedDistribution => false
case _ => true
}
}.isEmpty

private def getSizeInfo(medianSize: Long, sizes: Seq[Long]): String = {
s"median size: $medianSize, max size: ${sizes.max}, min size: ${sizes.min}, avg size: " +
sizes.sum / sizes.length
}

private def findShuffleStage(plan: SparkPlan): Option[ShuffleStageInfo] = {
plan collectFirst {
case _ @ ShuffleStage(shuffleStageInfo) =>
shuffleStageInfo
}
}

private def replaceSkewedShufleReader(
smj: SparkPlan, newCtm: CustomShuffleReaderExec): SparkPlan = {
smj transformUp {
case _ @ CustomShuffleReaderExec(child, _) if child.sameResult(newCtm.child) =>
newCtm
}
}

/*
* This method aim to optimize the skewed join with the following steps:
* 1. Check whether the shuffle partition is skewed based on the median size
Expand All @@ -157,96 +183,106 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
* 3 tasks separately.
*/
def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
case smj @ SortMergeJoinExec(_, _, joinType, _,
s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
case smj @ SortMergeJoinExec(_, _, joinType, _, s1: SortExec, s2: SortExec, _)
if supportedJoinTypes.contains(joinType) =>
assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
val numPartitions = left.partitionsWithSizes.length
// We use the median size of the original shuffle partitions to detect skewed partitions.
val leftMedSize = medianSize(left.mapStats)
val rightMedSize = medianSize(right.mapStats)
logDebug(
s"""
|Optimizing skewed join.
|Left side partitions size info:
|${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
|Right side partitions size info:
|${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
""".stripMargin)
val canSplitLeft = canSplitLeftSide(joinType)
val canSplitRight = canSplitRightSide(joinType)
// We use the actual partition sizes (may be coalesced) to calculate target size, so that
// the final data distribution is even (coalesced partitions + split partitions).
val leftActualSizes = left.partitionsWithSizes.map(_._2)
val rightActualSizes = right.partitionsWithSizes.map(_._2)
val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
val rightTargetSize = targetSize(rightActualSizes, rightMedSize)

val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
var numSkewedLeft = 0
var numSkewedRight = 0
for (partitionIndex <- 0 until numPartitions) {
val leftActualSize = leftActualSizes(partitionIndex)
val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex

val rightActualSize = rightActualSizes(partitionIndex)
val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex

// A skewed partition should never be coalesced, but skip it here just to be safe.
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
val reducerId = leftPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Left side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedLeft += 1
// find the shuffleStage from the plan tree
val leftOpt = findShuffleStage(s1)
val rightOpt = findShuffleStage(s2)
if (leftOpt.isEmpty || rightOpt.isEmpty) {
smj
} else {
val left = leftOpt.get
val right = rightOpt.get
assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
val numPartitions = left.partitionsWithSizes.length
// We use the median size of the original shuffle partitions to detect skewed partitions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is very hard to reason about. We need to clearly define:

  1. what nodes can appear between the shuffle stage and SMJ. As we discussed before, Agg can't appear at the skew side.
  2. how to estimate the size? Since there are nodes in the middle, the stats of the shuffle stage may not be accurate for the final join child. (e.g. Filter in the middle)

Copy link
Contributor Author

@LantaoJin LantaoJin Jul 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. what nodes can appear between the shuffle stage and SMJ. As we discussed before, Agg can't appear at the skew side.

In the canSplitLeftSide and canSplitRightSide, I added a allUnspecifiedDistribution(plan) check. Current we only support the nodes with UnspecifiedDistribution.

  1. how to estimate the size? Since there are nodes in the middle, the stats of the shuffle stage may not be accurate for the final join child. (e.g. Filter in the middle)

Filter should be pushdown to leaf, I didn't see this user case. Project may be a command case in the middle? Yes. the input size of shuffle stage may not be accurate. But the disadvantage is launching more tasks. I think the benefit from handling the skewing is more important than the disadvantage.

val leftMedSize = medianSize(left.mapStats)
val rightMedSize = medianSize(right.mapStats)
logDebug(
s"""
|Optimizing skewed join.
|Left side partitions size info:
|${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}

|Right side partitio

|${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
""".stripMargin)
val canSplitLeft = canSplitLeftSide(joinType, s1)
val canSplitRight = canSplitRightSide(joinType, s2)
// We use the actual partition sizes (may be coalesced) to calculate target size, so that
// the final data distribution is even (coalesced partitions + split partitions).
val leftActualSizes = left.partitionsWithSizes.map(_._2)
val rightActualSizes = right.partitionsWithSizes.map(_._2)
val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
val rightTargetSize = targetSize(rightActualSizes, rightMedSize)

val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
var numSkewedLeft = 0
var numSkewedRight = 0
for (partitionIndex <- 0 until numPartitions) {
val leftActualSize = leftActualSizes(partitionIndex)
val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex

val rightActualSize = rightActualSizes(partitionIndex)
val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex

// A skewed partition should never be coalesced, but skip it here just to be safe.
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
val reducerId = leftPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Left side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedLeft += 1
}
skewSpecs.getOrElse(Seq(leftPartSpec))
} else {
Seq(leftPartSpec)
}
skewSpecs.getOrElse(Seq(leftPartSpec))
} else {
Seq(leftPartSpec)
}

// A skewed partition should never be coalesced, but skip it here just to be safe.
val rightParts = if (isRightSkew && !isRightCoalesced) {
val reducerId = rightPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Right side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedRight += 1
// A skewed partition should never be coalesced, but skip it here just to be safe.
val rightParts = if (isRightSkew && !isRightCoalesced) {
val reducerId = rightPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Right side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedRight += 1
}
skewSpecs.getOrElse(Seq(rightPartSpec))
} else {
Seq(rightPartSpec)
}
skewSpecs.getOrElse(Seq(rightPartSpec))
} else {
Seq(rightPartSpec)
}

for {
leftSidePartition <- leftParts
rightSidePartition <- rightParts
} {
leftSidePartitions += leftSidePartition
rightSidePartitions += rightSidePartition
for {
leftSidePartition <- leftParts
rightSidePartition <- rightParts
} {
leftSidePartitions += leftSidePartition
rightSidePartitions += rightSidePartition
}
}
}

logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
if (numSkewedLeft > 0 || numSkewedRight > 0) {
val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions.toSeq)
val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions.toSeq)
smj.copy(
left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
} else {
smj
logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
if (numSkewedLeft > 0 || numSkewedRight > 0) {
val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions.toSeq)
val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions.toSeq)
val newSmj = replaceSkewedShufleReader(
replaceSkewedShufleReader(smj, newLeft), newRight).asInstanceOf[SortMergeJoinExec]
newSmj.copy(isSkewJoin = true)
} else {
smj
}
}
}

Expand All @@ -263,18 +299,31 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val shuffleStages = collectShuffleStages(plan)

if (shuffleStages.length == 2) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not we break this limitation first?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this PR is not to address the case which has multiple SMJ. We have another PR to change this limitation:

  1. optimizeSingleStageSkewJoin. This is the case one table is a bucket table and the SMJ is bucketing join with one side shuffle and skewing
  2. optimizeThreeShuffleStageSkewJoin. This is to address three tables SMJ (Two SMJs in one stage and no one can be changed to BCJ in AQE).

// When multi table join, there will be too many complex combination to consider.
// Currently we only handle 2 table join like following use case.
// SPARK-32201. Skew join supports below pattern, ".." may contain any number of nodes,
// includes such as BroadcastHashJoinExec. So it can handle more than two tables join.
// SMJ
// Sort
// Shuffle
// ..
// Shuffle
// Sort
// Shuffle
// ..
// Shuffle
val optimizePlan = optimizeSkewJoin(plan)
val numShuffles = ensureRequirements.apply(optimizePlan).collect {
case e: ShuffleExchangeExec => e
}.length

def countAdditionalShuffleInAncestorsOfSkewJoin(optimizePlan: SparkPlan): Int = {
val newPlan = ensureRequirements.apply(optimizePlan)
val totalAdditionalShuffles = newPlan.collect { case e: ShuffleExchangeExec => e }.size
val numShufflesFromDescendants =
newPlan.collectFirst { case j: SortMergeJoinExec if j.isSkewJoin => j }.map { smj =>
smj.collect { case e: ShuffleExchangeExec => e }.size
}.getOrElse(0)
totalAdditionalShuffles - numShufflesFromDescendants
}

// Check if we introduced new shuffles in the ancestors of the skewed join operator.
// And we don't care if new shuffles are introduced in the descendants of the join operator,
// since they will not actually be executed in the current adaptive execution framework.
val numShuffles = countAdditionalShuffleInAncestorsOfSkewJoin(optimizePlan)
if (numShuffles > 0) {
logDebug("OptimizeSkewedJoin rule is not applied due" +
" to additional shuffles will be introduced.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ case class SortMergeJoinExec(
}
}

override def outputPartitioning: Partitioning = {
if (isSkewJoin) {
UnknownPartitioning(0)
} else {
super.outputPartitioning
}
}

override def outputOrdering: Seq[SortOrder] = joinType match {
// For inner join, orders of both sides keys should be kept.
case _: InnerLike =>
Expand Down
Loading