Skip to content

Commit 2e0f33a

Browse files
committed
Write a more generic test for EnsureRequirements.
1 parent 752b8de commit 2e0f33a

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@
1818
package org.apache.spark.sql.execution
1919

2020
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.rdd.RDD
2122
import org.apache.spark.sql.TestData._
23+
import org.apache.spark.sql.catalyst.InternalRow
24+
import org.apache.spark.sql.catalyst.expressions .{Ascending, Literal, Attribute, SortOrder}
2225
import org.apache.spark.sql.catalyst.plans._
2326
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
24-
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin}
27+
import org.apache.spark.sql.catalyst.plans.physical._
28+
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
2529
import org.apache.spark.sql.functions._
2630
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
2731
import org.apache.spark.sql.test.TestSQLContext._
@@ -203,13 +207,38 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
203207
}
204208
}
205209

206-
test("EnsureRequirements shouldn't add exchange to SMJ inputs if both are SinglePartition") {
207-
val df = (1 to 10).map(Tuple1.apply).toDF("a").repartition(1)
208-
val keys = Seq(df.col("a").expr)
209-
val smj = SortMergeJoin(keys, keys, df.queryExecution.sparkPlan, df.queryExecution.sparkPlan)
210-
val afterEnsureRequirements = EnsureRequirements(df.sqlContext).apply(smj)
211-
if (afterEnsureRequirements.collect { case Exchange(_, _) => true }.nonEmpty) {
212-
fail(s"No Exchanges should have been added:\n$afterEnsureRequirements")
210+
// --- Unit tests of EnsureRequirements ---------------------------------------------------------
211+
212+
test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") {
213+
val outputOrdering = Seq(SortOrder(Literal(1), Ascending))
214+
val distribution = ClusteredDistribution(Literal(1) :: Nil)
215+
val inputPlan = DummyPlan(
216+
children = Seq(
217+
DummyPlan(outputPartitioning = SinglePartition),
218+
DummyPlan(outputPartitioning = SinglePartition)
219+
),
220+
requiresChildrenToProduceSameNumberOfPartitions = true,
221+
requiredChildDistribution = Seq(distribution, distribution),
222+
requiredChildOrdering = Seq(outputOrdering, outputOrdering)
223+
)
224+
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
225+
if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) {
226+
fail(s"No Exchanges should have been added:\n$outputPlan")
213227
}
214228
}
229+
230+
// ---------------------------------------------------------------------------------------------
231+
}
232+
233+
// Used for unit-testing EnsureRequirements
234+
private case class DummyPlan(
235+
override val children: Seq[SparkPlan] = Nil,
236+
override val outputOrdering: Seq[SortOrder] = Nil,
237+
override val outputPartitioning: Partitioning = UnknownPartitioning(0),
238+
override val requiresChildrenToProduceSameNumberOfPartitions: Boolean = false,
239+
override val requiredChildDistribution: Seq[Distribution] = Nil,
240+
override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil
241+
) extends SparkPlan {
242+
override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError
243+
override def output: Seq[Attribute] = Seq.empty
215244
}

0 commit comments

Comments
 (0)