
Commit 6e37231

JoshRosen authored and CodingCat committed
[SPARK-9703] [SQL] Refactor EnsureRequirements to avoid certain unnecessary shuffles
This pull request refactors the `EnsureRequirements` planning rule in order to avoid adding certain unnecessary shuffles. As an example of how an unnecessary shuffle can occur, consider SortMergeJoin, which requires a clustered distribution and sorted ordering of its children's input rows. Say that both of SMJ's children produce unsorted output but each is already SinglePartition. In this case we will need to inject sort operators, but we should not need to inject Exchanges. Unfortunately, the old EnsureRequirements rule unnecessarily repartitions both children using a hash partitioning. This patch solves the problem by refactoring `EnsureRequirements` to properly implement the `compatibleWith` checks that were broken in earlier implementations. See the extensive inline comments for a more precise description of how this works. The majority of this PR is new comments and test cases, with few actual changes to the code.

Author: Josh Rosen <[email protected]>

Closes apache#7988 from JoshRosen/exchange-fixes and squashes the following commits:

38006e7 [Josh Rosen] Rewrite EnsureRequirements _yet again_ to make things even simpler
0983f75 [Josh Rosen] More guarantees vs. compatibleWith cleanup; delete BroadcastPartitioning.
8784bd9 [Josh Rosen] Giant comment explaining compatibleWith vs. guarantees
1307c50 [Josh Rosen] Update conditions for requiring child compatibility.
18cddeb [Josh Rosen] Rename DummyPlan to DummySparkPlan.
2c7e126 [Josh Rosen] Merge remote-tracking branch 'origin/master' into exchange-fixes
fee65c4 [Josh Rosen] Further refinement to comments / reasoning
642b0bb [Josh Rosen] Further expand comment / reasoning
06aba0c [Josh Rosen] Add more comments
8dbc845 [Josh Rosen] Add even more tests.
4f08278 [Josh Rosen] Fix the test by adding the compatibility check to EnsureRequirements
a1c12b9 [Josh Rosen] Add failing test to demonstrate allCompatible bug
0725a34 [Josh Rosen] Small assertion cleanup.
5172ac5 [Josh Rosen] Add test for requiresChildrenToProduceSameNumberOfPartitions.
2e0f33a [Josh Rosen] Write a more generic test for EnsureRequirements.
752b8de [Josh Rosen] style fix
c628daf [Josh Rosen] Revert accidental ExchangeSuite change.
c9fb231 [Josh Rosen] Rewrite exchange to better handle this case.
adcc742 [Josh Rosen] Move test to PlannerSuite.
0675956 [Josh Rosen] Preserving ordering and partitioning in row format converters also does not help.
cc5669c [Josh Rosen] Adding outputPartitioning to Repartition does not fix the test.
2dfc648 [Josh Rosen] Add failing test illustrating bad exchange planning.
1 parent b51b2dd commit 6e37231
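
To make the failure mode concrete, here is a minimal repro sketch in the 1.5-era DataFrame API (an illustration, not code from this patch; the table names t1 and t2 are hypothetical, and sort-merge join is assumed to be enabled):

// Both inputs are coalesced to one partition, so each child of the join
// reports SinglePartition (see the Repartition change below) and is merely
// unsorted, not mis-partitioned.
val left = sqlContext.table("t1").coalesce(1)
val right = sqlContext.table("t2").coalesce(1)
val joined = left.join(right, left("key") === right("key"))

// Before this patch, the executed plan hash-repartitioned both sides under
// the sort-merge join; with this patch, only the Sort operators demanded by
// SortMergeJoin's requiredChildOrdering are inserted.
joined.queryExecution.executedPlan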

File tree

5 files changed (+328, -65 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 112 additions & 16 deletions
@@ -75,6 +75,37 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
   def clustering: Set[Expression] = ordering.map(_.child).toSet
 }
 
+/**
+ * Describes how an operator's output is split across partitions. The `compatibleWith`,
+ * `guarantees`, and `satisfies` methods describe relationships between child partitionings,
+ * target partitionings, and [[Distribution]]s. These relations are described more precisely in
+ * their individual method docs, but at a high level:
+ *
+ *  - `satisfies` is a relationship between partitionings and distributions.
+ *  - `compatibleWith` is a relationship between an operator's child output partitionings.
+ *  - `guarantees` is a relationship between a child's existing output partitioning and a target
+ *    output partitioning.
+ *
+ * Diagrammatically:
+ *
+ *            +--------------+
+ *            | Distribution |
+ *            +--------------+
+ *                    ^
+ *                    |
+ *                satisfies
+ *                    |
+ *            +--------------+                  +--------------+
+ *            |    Child     |                  |    Target    |
+ *       +----| Partitioning |----guarantees--->| Partitioning |
+ *       |    +--------------+                  +--------------+
+ *       |            ^
+ *       |            |
+ *       |     compatibleWith
+ *       |            |
+ *       +------------+
+ *
+ */
 sealed trait Partitioning {
   /** Returns the number of partitions that the data is split across */
   val numPartitions: Int
@@ -90,9 +121,66 @@ sealed trait Partitioning {
   /**
    * Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
    * guarantees the same partitioning scheme described by `other`.
+   *
+   * Compatibility of partitionings is only checked for operators that have multiple children
+   * and that require a specific child output [[Distribution]], such as joins.
+   *
+   * Intuitively, partitionings are compatible if they route the same partitioning key to the same
+   * partition. For instance, two hash partitionings are only compatible if they produce the same
+   * number of output partitions and hash records according to the same hash function and
+   * same partitioning key schema.
+   *
+   * Put another way, two partitionings are compatible with each other if they satisfy all of the
+   * same distribution guarantees.
    */
-  // TODO: Add an example once we have the `nullSafe` concept.
-  def guarantees(other: Partitioning): Boolean
+  def compatibleWith(other: Partitioning): Boolean
+
+  /**
+   * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees
+   * the same partitioning scheme described by `other`. If `A.guarantees(B)`, then repartitioning
+   * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance
+   * optimization to allow the exchange planner to avoid redundant repartitionings. By default,
+   * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number
+   * of partitions, same strategy (range or hash), etc).
+   *
+   * In order to enable more aggressive optimization, this strict equality check can be relaxed.
+   * For example, say that the planner needs to repartition all of an operator's children so that
+   * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children
+   * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens
+   * to be hash-partitioned with a single partition then we do not need to re-shuffle this child;
+   * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees`
+   * [[SinglePartition]].
+   *
+   * The SinglePartition example given above is not particularly interesting; guarantees' real
+   * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion
+   * of null-safe partitionings, under which partitionings can specify whether rows whose
+   * partitioning keys contain null values will be grouped into the same partition or whether they
+   * will have an unknown / random distribution. If a partitioning does not require nulls to be
+   * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered
+   * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot
+   * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a
+   * symmetric relation.
+   *
+   * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows
+   * produced by `A` could have also been produced by `B`.
+   */
+  def guarantees(other: Partitioning): Boolean = this == other
+}
+
+object Partitioning {
+  def allCompatible(partitionings: Seq[Partitioning]): Boolean = {
+    // Note: this assumes transitivity
+    partitionings.sliding(2).map {
+      case Seq(a) => true
+      case Seq(a, b) =>
+        if (a.numPartitions != b.numPartitions) {
+          assert(!a.compatibleWith(b) && !b.compatibleWith(a))
+          false
+        } else {
+          a.compatibleWith(b) && b.compatibleWith(a)
+        }
+    }.forall(_ == true)
+  }
 }
 
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -101,6 +189,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
     case _ => false
   }
 
+  override def compatibleWith(other: Partitioning): Boolean = false
+
   override def guarantees(other: Partitioning): Boolean = false
 }
 
@@ -109,21 +199,9 @@ case object SinglePartition extends Partitioning {
 
   override def satisfies(required: Distribution): Boolean = true
 
-  override def guarantees(other: Partitioning): Boolean = other match {
-    case SinglePartition => true
-    case _ => false
-  }
-}
-
-case object BroadcastPartitioning extends Partitioning {
-  val numPartitions = 1
+  override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1
 
-  override def satisfies(required: Distribution): Boolean = true
-
-  override def guarantees(other: Partitioning): Boolean = other match {
-    case BroadcastPartitioning => true
-    case _ => false
-  }
+  override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1
 }
 
 /**
@@ -147,6 +225,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     case _ => false
   }
 
+  override def compatibleWith(other: Partitioning): Boolean = other match {
+    case o: HashPartitioning =>
+      this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
+    case _ => false
+  }
+
   override def guarantees(other: Partitioning): Boolean = other match {
     case o: HashPartitioning =>
       this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
@@ -185,6 +269,11 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case _ => false
   }
 
+  override def compatibleWith(other: Partitioning): Boolean = other match {
+    case o: RangePartitioning => this == o
+    case _ => false
+  }
+
   override def guarantees(other: Partitioning): Boolean = other match {
     case o: RangePartitioning => this == o
     case _ => false
@@ -228,6 +317,13 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
   override def satisfies(required: Distribution): Boolean =
     partitionings.exists(_.satisfies(required))
 
+  /**
+   * Returns true if any `partitioning` of this collection is compatible with
+   * the given [[Partitioning]].
+   */
+  override def compatibleWith(other: Partitioning): Boolean =
+    partitionings.exists(_.compatibleWith(other))
+
   /**
    * Returns true if any `partitioning` of this collection guarantees
    * the given [[Partitioning]].
sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 55 additions & 49 deletions
@@ -190,66 +190,72 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
  * of input data meets the
  * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
  * each operator by inserting [[Exchange]] Operators where required. Also ensure that the
- * required input partition ordering requirements are met.
+ * input partition ordering requirements are met.
  */
 private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
-  def numPartitions: Int = sqlContext.conf.numShufflePartitions
+  private def numPartitions: Int = sqlContext.conf.numShufflePartitions
 
-  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator: SparkPlan =>
-      // Adds Exchange or Sort operators as required
-      def addOperatorsIfNecessary(
-          partitioning: Partitioning,
-          rowOrdering: Seq[SortOrder],
-          child: SparkPlan): SparkPlan = {
-
-        def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
-          if (!child.outputPartitioning.guarantees(partitioning)) {
-            Exchange(partitioning, child)
-          } else {
-            child
-          }
-        }
+  /**
+   * Given a required distribution, returns a partitioning that satisfies that distribution.
+   */
+  private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = {
+    requiredDistribution match {
+      case AllTuples => SinglePartition
+      case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+      case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
+      case dist => sys.error(s"Do not know how to satisfy distribution $dist")
+    }
+  }
 
-        def addSortIfNecessary(child: SparkPlan): SparkPlan = {
-
-          if (rowOrdering.nonEmpty) {
-            // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
-            val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
-            if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
-              sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
-            } else {
-              child
-            }
-          } else {
-            child
-          }
-        }
+  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+    val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
+    val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
+    var children: Seq[SparkPlan] = operator.children
 
-        addSortIfNecessary(addShuffleIfNecessary(child))
+    // Ensure that the operator's children satisfy their output distribution requirements:
+    children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+      if (child.outputPartitioning.satisfies(distribution)) {
+        child
+      } else {
+        Exchange(canonicalPartitioning(distribution), child)
       }
+    }
 
-      val requirements =
-        (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
-
-      val fixedChildren = requirements.zipped.map {
-        case (AllTuples, rowOrdering, child) =>
-          addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
-        case (ClusteredDistribution(clustering), rowOrdering, child) =>
-          addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
-        case (OrderedDistribution(ordering), rowOrdering, child) =>
-          addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
-
-        case (UnspecifiedDistribution, Seq(), child) =>
+    // If the operator has multiple children and specifies child output distributions (e.g. join),
+    // then the children's output partitionings must be compatible:
+    if (children.length > 1
+        && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
+        && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+      children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+        val targetPartitioning = canonicalPartitioning(distribution)
+        if (child.outputPartitioning.guarantees(targetPartitioning)) {
           child
-        case (UnspecifiedDistribution, rowOrdering, child) =>
-          sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+        } else {
+          Exchange(targetPartitioning, child)
+        }
+      }
+    }
 
-        case (dist, ordering, _) =>
-          sys.error(s"Don't know how to ensure $dist with ordering $ordering")
+    // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
+    children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+      if (requiredOrdering.nonEmpty) {
+        // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
+        val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min
+        if (minSize == 0 || requiredOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
+          sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child)
+        } else {
+          child
+        }
+      } else {
+        child
       }
+    }
 
-      operator.withNewChildren(fixedChildren)
+    operator.withNewChildren(children)
+  }
+
+  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    case operator: SparkPlan => ensureDistributionAndOrdering(operator)
   }
 }
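
Applied to the SortMergeJoin scenario from the commit message, the three phases of ensureDistributionAndOrdering now reason as follows (a sketch over the partitioning objects only; the attribute `key` is hypothetical):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

val key = AttributeReference("key", IntegerType)()

// Phase 1 (distribution): a single partition trivially satisfies the join's
// clustering requirement, so no Exchange is added.
SinglePartition.satisfies(ClusteredDistribution(Seq(key)))         // true

// Phase 2 (compatibility): the two SinglePartition children route rows
// identically (both have numPartitions == 1), so no re-shuffle is forced.
Partitioning.allCompatible(Seq(SinglePartition, SinglePartition))  // true

// Phase 3 (ordering): SortMergeJoin's requiredChildOrdering is non-empty and
// the children are unsorted, so only Sort operators are inserted.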

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 5 additions & 0 deletions
@@ -256,6 +256,11 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan)
   extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
+  override def outputPartitioning: Partitioning = {
+    if (numPartitions == 1) SinglePartition
+    else UnknownPartitioning(numPartitions)
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     child.execute().map(_.copy()).coalesce(numPartitions, shuffle)
   }
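
A brief sketch of why this override matters (illustrative, not from the patch; the attribute `key` is hypothetical): a single-partition Repartition now satisfies any required distribution outright, so EnsureRequirements will not stack an Exchange on top of it, while a multi-partition coalesce still reports an unknown layout:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType

val key = AttributeReference("key", IntegerType)()  // hypothetical clustering key

SinglePartition.satisfies(ClusteredDistribution(Seq(key)))         // true: no Exchange needed
UnknownPartitioning(8).satisfies(ClusteredDistribution(Seq(key)))  // false: Exchange still required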

sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
@@ -33,6 +34,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
   require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe")
 
   override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
   override def outputsUnsafeRows: Boolean = true
   override def canProcessUnsafeRows: Boolean = false
   override def canProcessSafeRows: Boolean = true
@@ -51,6 +54,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
 @DeveloperApi
 case class ConvertToSafe(child: SparkPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
   override def outputsUnsafeRows: Boolean = false
   override def canProcessUnsafeRows: Boolean = true
   override def canProcessSafeRows: Boolean = false

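These two overrides matter because SparkPlan's defaults (approximately UnknownPartitioning(0) and an empty ordering) would otherwise hide a child's layout whenever a converter is slipped between two operators, causing EnsureRequirements to plan a redundant Exchange or Sort. A minimal sketch, assuming the converters are reachable from calling code:

import org.apache.spark.sql.execution.{ConvertToSafe, SparkPlan}

// For any plan `child`, the converter now reports the child's own layout
// rather than the defaults, so EnsureRequirements sees through it:
def layoutPreserved(child: SparkPlan): Boolean = {
  val converted = ConvertToSafe(child)
  converted.outputPartitioning == child.outputPartitioning &&
    converted.outputOrdering == child.outputOrdering
}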