
Commit 38006e7

Rewrite EnsureRequirements _yet again_ to make things even simpler

1 parent 0983f75 commit 38006e7

File tree: 1 file changed, +25 -64 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 25 additions & 64 deletions
@@ -208,66 +208,37 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
     }
   }
 
-  /**
-   * Return true if all of the operator's children satisfy their output distribution requirements.
-   */
-  private def childPartitioningsSatisfyDistributionRequirements(operator: SparkPlan): Boolean = {
-    operator.children.zip(operator.requiredChildDistribution).forall {
-      case (child, distribution) => child.outputPartitioning.satisfies(distribution)
-    }
-  }
+  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+    val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
+    val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
+    var children: Seq[SparkPlan] = operator.children
 
-  /**
-   * Given an operator, check whether the operator requires its children to have compatible
-   * output partitionings and add Exchanges to fix any detected incompatibilities.
-   */
-  private def ensureChildPartitioningsAreCompatible(operator: SparkPlan): SparkPlan = {
-    // If an operator has multiple children and the operator requires a specific child output
-    // distribution then we need to ensure that all children have compatible output partitionings.
-    if (operator.children.length > 1
-        && operator.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) {
-      if (!Partitioning.allCompatible(operator.children.map(_.outputPartitioning))) {
-        val newChildren = operator.children.zip(operator.requiredChildDistribution).map {
-          case (child, requiredDistribution) =>
-            val targetPartitioning = canonicalPartitioning(requiredDistribution)
-            if (child.outputPartitioning.guarantees(targetPartitioning)) {
-              child
-            } else {
-              Exchange(targetPartitioning, child)
-            }
-        }
-        val newOperator = operator.withNewChildren(newChildren)
-        assert(childPartitioningsSatisfyDistributionRequirements(newOperator))
-        newOperator
+    // Ensure that the operator's children satisfy their output distribution requirements:
+    children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+      if (child.outputPartitioning.satisfies(distribution)) {
+        child
       } else {
-        operator
+        Exchange(canonicalPartitioning(distribution), child)
       }
-    } else {
-      operator
     }
-  }
-
-  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
 
-    def addShuffleIfNecessary(child: SparkPlan, requiredDistribution: Distribution): SparkPlan = {
-      // A pre-condition of ensureDistributionAndOrdering is that joins' children have compatible
-      // partitionings. Thus, we only need to check whether the output partitionings satisfy
-      // the required distribution. In the case where the children are all compatible, then they
-      // will either all satisfy the required distribution or will all fail to satisfy it, since
-      // A.guarantees(B) implies that A and B satisfy the same set of distributions.
-      // Therefore, if all children are compatible then either all or none of them will shuffled to
-      // ensure that the distribution requirements are met.
-      //
-      // Note that this reasoning implicitly assumes that operators which require compatible
-      // child partitionings have equivalent required distributions for those children.
-      if (child.outputPartitioning.satisfies(requiredDistribution)) {
-        child
-      } else {
-        Exchange(canonicalPartitioning(requiredDistribution), child)
+    // If the operator has multiple children and specifies child output distributions (e.g. join),
+    // then the children's output partitionings must be compatible:
+    if (children.length > 1
+        && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
+        && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+      children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+        val targetPartitioning = canonicalPartitioning(distribution)
+        if (child.outputPartitioning.guarantees(targetPartitioning)) {
+          child
+        } else {
+          Exchange(targetPartitioning, child)
+        }
       }
     }
 
-    def addSortIfNecessary(child: SparkPlan, requiredOrdering: Seq[SortOrder]): SparkPlan = {
+    // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
+    children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
       if (requiredOrdering.nonEmpty) {
         // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
         val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min
@@ -281,20 +252,10 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       }
     }
 
-    val children = operator.children
-    val requiredChildDistribution = operator.requiredChildDistribution
-    val requiredChildOrdering = operator.requiredChildOrdering
-    assert(children.length == requiredChildDistribution.length)
-    assert(children.length == requiredChildOrdering.length)
-    val newChildren = (children, requiredChildDistribution, requiredChildOrdering).zipped.map {
-      case (child, requiredDistribution, requiredOrdering) =>
-        addSortIfNecessary(addShuffleIfNecessary(child, requiredDistribution), requiredOrdering)
-    }
-    operator.withNewChildren(newChildren)
+    operator.withNewChildren(children)
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator: SparkPlan =>
-      ensureDistributionAndOrdering(ensureChildPartitioningsAreCompatible(operator))
+    case operator: SparkPlan => ensureDistributionAndOrdering(operator)
  }
 }
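
Note (illustration added here, not part of the commit): the rewritten rule is a single bottom-up pass that applies three steps to each operator. The sketch below reimplements that shape over hypothetical toy types: Plan, Scan, Join, a HashPartitioning with the canonical target fixed at 200 partitions (a stand-in for the partition count Spark takes from spark.sql.shuffle.partitions), and a strict prefix check in place of the diff's minSize comparison. None of these are Spark's actual classes or signatures.

sealed trait Distribution
case object UnspecifiedDistribution extends Distribution
final case class ClusteredDistribution(keys: Seq[String]) extends Distribution

final case class HashPartitioning(keys: Seq[String], numPartitions: Int) {
  def satisfies(d: Distribution): Boolean = d match {
    case UnspecifiedDistribution   => true
    case ClusteredDistribution(ks) => keys.nonEmpty && keys.forall(ks.contains)
  }
  // In this toy model a partitioning only guarantees an identical one.
  def guarantees(other: HashPartitioning): Boolean = this == other
}

trait Plan {
  def children: Seq[Plan]
  def outputPartitioning: HashPartitioning
  def outputOrdering: Seq[String] = Nil
  def requiredChildDistribution: Seq[Distribution] =
    children.map(_ => UnspecifiedDistribution)
  def requiredChildOrdering: Seq[Seq[String]] = children.map(_ => Nil)
  def withNewChildren(newChildren: Seq[Plan]): Plan
}

final case class Exchange(target: HashPartitioning, child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  def outputPartitioning: HashPartitioning = target
  def withNewChildren(c: Seq[Plan]): Plan = copy(child = c.head)
}

final case class Sort(order: Seq[String], child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  def outputPartitioning: HashPartitioning = child.outputPartitioning
  override def outputOrdering: Seq[String] = order
  def withNewChildren(c: Seq[Plan]): Plan = copy(child = c.head)
}

object EnsureRequirementsSketch {
  // Stand-in for canonicalPartitioning; the partition count is fixed at 200 here.
  private def canonicalPartitioning(d: Distribution): HashPartitioning = d match {
    case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions = 200)
    case UnspecifiedDistribution =>
      sys.error("no canonical partitioning for UnspecifiedDistribution")
  }

  // Stand-in for Partitioning.allCompatible: adjacent pairs must guarantee each other.
  private def allCompatible(ps: Seq[HashPartitioning]): Boolean =
    ps.sliding(2).forall {
      case Seq(a, b) => a.guarantees(b) && b.guarantees(a)
      case _         => true
    }

  def ensureDistributionAndOrdering(operator: Plan): Plan = {
    val requiredDists  = operator.requiredChildDistribution
    val requiredOrders = operator.requiredChildOrdering
    var children: Seq[Plan] = operator.children

    // Step 1: shuffle any child whose output does not satisfy its required distribution.
    children = children.zip(requiredDists).map { case (child, dist) =>
      if (child.outputPartitioning.satisfies(dist)) child
      else Exchange(canonicalPartitioning(dist), child)
    }

    // Step 2: a multi-child operator with real distribution requirements (e.g. a join)
    // also needs co-partitioned children; shuffle the stragglers to the canonical target.
    if (children.length > 1 &&
        requiredDists.toSet != Set[Distribution](UnspecifiedDistribution) &&
        !allCompatible(children.map(_.outputPartitioning))) {
      children = children.zip(requiredDists).map { case (child, dist) =>
        val target = canonicalPartitioning(dist)
        if (child.outputPartitioning.guarantees(target)) child
        else Exchange(target, child)
      }
    }

    // Step 3: once shuffles are settled, sort any child whose output ordering does not
    // already begin with the required ordering (stricter than the diff's minSize check).
    children = children.zip(requiredOrders).map { case (child, order) =>
      if (order.isEmpty || child.outputOrdering.startsWith(order)) child
      else Sort(order, child)
    }

    operator.withNewChildren(children)
  }
}

A minimal usage example under the same assumptions: a join that requires both sides clustered and sorted on the join keys, fed by two scans hash-partitioned on the same key but with different partition counts, so only the non-canonical side is re-shuffled and both sides are then sorted.

final case class Scan(name: String, part: HashPartitioning) extends Plan {
  def children: Seq[Plan] = Nil
  def outputPartitioning: HashPartitioning = part
  def withNewChildren(c: Seq[Plan]): Plan = this
}

final case class Join(left: Plan, right: Plan, keys: Seq[String]) extends Plan {
  def children: Seq[Plan] = Seq(left, right)
  def outputPartitioning: HashPartitioning = left.outputPartitioning
  override def requiredChildDistribution: Seq[Distribution] =
    Seq(ClusteredDistribution(keys), ClusteredDistribution(keys))
  override def requiredChildOrdering: Seq[Seq[String]] = Seq(keys, keys)
  def withNewChildren(c: Seq[Plan]): Plan = copy(left = c(0), right = c(1))
}

object Demo {
  def main(args: Array[String]): Unit = {
    val plan = Join(
      Scan("t1", HashPartitioning(Seq("k"), 100)),
      Scan("t2", HashPartitioning(Seq("k"), 200)),
      keys = Seq("k"))
    // t1 is exchanged to the 200-partition canonical target, t2 is left alone,
    // and both sides then get a Sort on "k".
    println(EnsureRequirementsSketch.ensureDistributionAndOrdering(plan))
  }
}

The commit's simplification shows up in the second step: compatibility is checked inline, right after the per-child distribution pass, so the separate ensureChildPartitioningsAreCompatible rule, its pre-condition comment, and the defensive assert can all be dropped.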

0 commit comments
