@@ -208,66 +208,37 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
     }
   }

-  /**
-   * Return true if all of the operator's children satisfy their output distribution requirements.
-   */
-  private def childPartitioningsSatisfyDistributionRequirements(operator: SparkPlan): Boolean = {
-    operator.children.zip(operator.requiredChildDistribution).forall {
-      case (child, distribution) => child.outputPartitioning.satisfies(distribution)
-    }
-  }
+  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+    val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
+    val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
+    var children: Seq[SparkPlan] = operator.children

-  /**
-   * Given an operator, check whether the operator requires its children to have compatible
-   * output partitionings and add Exchanges to fix any detected incompatibilities.
-   */
-  private def ensureChildPartitioningsAreCompatible(operator: SparkPlan): SparkPlan = {
-    // If an operator has multiple children and the operator requires a specific child output
-    // distribution then we need to ensure that all children have compatible output partitionings.
-    if (operator.children.length > 1
-        && operator.requiredChildDistribution.toSet != Set(UnspecifiedDistribution)) {
-      if (!Partitioning.allCompatible(operator.children.map(_.outputPartitioning))) {
-        val newChildren = operator.children.zip(operator.requiredChildDistribution).map {
-          case (child, requiredDistribution) =>
-            val targetPartitioning = canonicalPartitioning(requiredDistribution)
-            if (child.outputPartitioning.guarantees(targetPartitioning)) {
-              child
-            } else {
-              Exchange(targetPartitioning, child)
-            }
-        }
-        val newOperator = operator.withNewChildren(newChildren)
-        assert(childPartitioningsSatisfyDistributionRequirements(newOperator))
-        newOperator
+    // Ensure that the operator's children satisfy their output distribution requirements:
+    children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+      if (child.outputPartitioning.satisfies(distribution)) {
+        child
       } else {
-        operator
+        Exchange(canonicalPartitioning(distribution), child)
       }
-    } else {
-      operator
     }
-  }
-
-  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {

-    def addShuffleIfNecessary(child: SparkPlan, requiredDistribution: Distribution): SparkPlan = {
-      // A pre-condition of ensureDistributionAndOrdering is that joins' children have compatible
-      // partitionings. Thus, we only need to check whether the output partitionings satisfy
-      // the required distribution. In the case where the children are all compatible, they
-      // will either all satisfy the required distribution or will all fail to satisfy it, since
-      // A.guarantees(B) implies that A and B satisfy the same set of distributions.
-      // Therefore, if all children are compatible then either all or none of them will be shuffled
-      // to ensure that the distribution requirements are met.
-      //
-      // Note that this reasoning implicitly assumes that operators which require compatible
-      // child partitionings have equivalent required distributions for those children.
-      if (child.outputPartitioning.satisfies(requiredDistribution)) {
-        child
-      } else {
-        Exchange(canonicalPartitioning(requiredDistribution), child)
+    // If the operator has multiple children and specifies child output distributions (e.g. join),
+    // then the children's output partitionings must be compatible:
+    if (children.length > 1
+        && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
+        && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+      children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+        val targetPartitioning = canonicalPartitioning(distribution)
+        if (child.outputPartitioning.guarantees(targetPartitioning)) {
+          child
+        } else {
+          Exchange(targetPartitioning, child)
+        }
       }
     }

-    def addSortIfNecessary(child: SparkPlan, requiredOrdering: Seq[SortOrder]): SparkPlan = {
+    // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
+    children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
       if (requiredOrdering.nonEmpty) {
         // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
         val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min
@@ -281,20 +252,10 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       }
     }

-    val children = operator.children
-    val requiredChildDistribution = operator.requiredChildDistribution
-    val requiredChildOrdering = operator.requiredChildOrdering
-    assert(children.length == requiredChildDistribution.length)
-    assert(children.length == requiredChildOrdering.length)
-    val newChildren = (children, requiredChildDistribution, requiredChildOrdering).zipped.map {
-      case (child, requiredDistribution, requiredOrdering) =>
-        addSortIfNecessary(addShuffleIfNecessary(child, requiredDistribution), requiredOrdering)
-    }
-    operator.withNewChildren(newChildren)
+    operator.withNewChildren(children)
   }

   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator: SparkPlan =>
-      ensureDistributionAndOrdering(ensureChildPartitioningsAreCompatible(operator))
+    case operator: SparkPlan => ensureDistributionAndOrdering(operator)
   }
 }
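Reviewer's note: the net effect of this diff is that the old two-stage pipeline (`ensureChildPartitioningsAreCompatible` followed by `ensureDistributionAndOrdering`, with its `addShuffleIfNecessary`/`addSortIfNecessary` helpers) collapses into a single method that rewrites `children` in three passes: satisfy each child's required distribution, make multi-child partitionings mutually compatible, then add sorts. A minimal standalone sketch of that control flow follows; `Plan`, `Dist`, `HashPart`, `defaultNumPartitions`, and the `ensure` helper are illustrative stand-ins invented for this sketch, not Spark's actual `SparkPlan`, `Distribution`, `HashPartitioning`, or rule:

```scala
// A minimal, self-contained model of the refactored rule's control flow.
// NOTE: all types here are crude approximations of Spark's physical-plan
// abstractions; "satisfies" and compatibility are simplified on purpose.
object EnsureRequirementsSketch extends App {
  val defaultNumPartitions = 200 // stand-in for the SQLContext's shuffle partition count

  sealed trait Dist
  case object Unspecified extends Dist
  final case class Clustered(keys: Seq[String]) extends Dist

  final case class HashPart(keys: Seq[String], numPartitions: Int)

  final case class Plan(
      name: String,
      children: Seq[Plan] = Nil,
      partitioning: Option[HashPart] = None,
      ordering: Seq[String] = Nil) {
    // Satisfaction only checks the clustering keys, not the partition count.
    def satisfies(d: Dist): Boolean = d match {
      case Unspecified     => true
      case Clustered(keys) => partitioning.exists(_.keys == keys)
    }
  }

  // Mirrors canonicalPartitioning: one agreed-upon target per distribution.
  def canonical(d: Dist): HashPart = d match {
    case Clustered(keys) => HashPart(keys, defaultNumPartitions)
    case Unspecified     => sys.error("no canonical partitioning for Unspecified")
  }

  def ensure(op: Plan, dists: Seq[Dist], orders: Seq[Seq[String]]): Plan = {
    var children = op.children

    // Pass 1: shuffle any child whose partitioning fails its required distribution.
    children = children.zip(dists).map { case (child, dist) =>
      if (child.satisfies(dist)) child
      else {
        val target = canonical(dist)
        Plan(s"Exchange[$target]", Seq(child), Some(target))
      }
    }

    // Pass 2 (multi-child operators such as joins): every child may satisfy its
    // distribution yet the children may still be mutually incompatible, e.g.
    // hash-partitioned on the same keys but with different partition counts.
    // Re-shuffle to the canonical target, like the Partitioning.allCompatible branch.
    if (children.length > 1 &&
        dists.forall(_ != Unspecified) &&
        children.map(_.partitioning).distinct.size > 1) {
      children = children.zip(dists).map { case (child, dist) =>
        val target = canonical(dist)
        if (child.partitioning.contains(target)) child
        else Plan(s"Exchange[$target]", Seq(child), Some(target))
      }
    }

    // Pass 3: sort unless the required ordering is already a prefix of the child's.
    children = children.zip(orders).map { case (child, order) =>
      val n = math.min(order.size, child.ordering.size)
      val alreadySorted = order.isEmpty || (n > 0 && order.take(n) == child.ordering.take(n))
      if (alreadySorted) child
      else Plan(s"Sort[${order.mkString(",")}]", Seq(child), child.partitioning, order)
    }

    op.copy(children = children)
  }

  // Both sides satisfy Clustered(id) but with different partition counts,
  // so pass 2 re-shuffles both, and pass 3 then sorts the shuffled outputs.
  val left  = Plan("scanA", partitioning = Some(HashPart(Seq("id"), 100)))
  val right = Plan("scanB", partitioning = Some(HashPart(Seq("id"), 50)))
  println(ensure(
    Plan("join", Seq(left, right)),
    dists  = Seq(Clustered(Seq("id")), Clustered(Seq("id"))),
    orders = Seq(Seq("id"), Seq("id"))))
}
```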
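If I'm reading the diff right, the design payoff is that the old split forced `addShuffleIfNecessary` to rely on a fragile pre-condition established by a separate rule, which is exactly what the long deleted comment about `A.guarantees(B)` had to justify. Folding everything into one method that threads a single `children` sequence through the passes makes each pass's input state explicit, so that comment can simply go.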