|
18 | 18 | package org.apache.spark.sql.execution |
19 | 19 |
|
20 | 20 | import org.apache.spark.SparkFunSuite |
| 21 | +import org.apache.spark.rdd.RDD |
21 | 22 | import org.apache.spark.sql.TestData._ |
| 23 | +import org.apache.spark.sql.catalyst.InternalRow |
| 24 | +import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, Attribute, SortOrder}
22 | 25 | import org.apache.spark.sql.catalyst.plans._ |
23 | 26 | import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan |
24 | | -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin} |
| 27 | +import org.apache.spark.sql.catalyst.plans.physical._ |
| 28 | +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} |
25 | 29 | import org.apache.spark.sql.functions._ |
26 | 30 | import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} |
27 | 31 | import org.apache.spark.sql.test.TestSQLContext._ |
@@ -203,13 +207,38 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils { |
203 | 207 | } |
204 | 208 | } |
205 | 209 |
|
206 | | - test("EnsureRequirements shouldn't add exchange to SMJ inputs if both are SinglePartition") { |
207 | | - val df = (1 to 10).map(Tuple1.apply).toDF("a").repartition(1) |
208 | | - val keys = Seq(df.col("a").expr) |
209 | | - val smj = SortMergeJoin(keys, keys, df.queryExecution.sparkPlan, df.queryExecution.sparkPlan) |
210 | | - val afterEnsureRequirements = EnsureRequirements(df.sqlContext).apply(smj) |
211 | | - if (afterEnsureRequirements.collect { case Exchange(_, _) => true }.nonEmpty) { |
212 | | - fail(s"No Exchanges should have been added:\n$afterEnsureRequirements") |
| 210 | + // --- Unit tests of EnsureRequirements --------------------------------------------------------- |
| 211 | + |
| 212 | + test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") { |
| 213 | + val outputOrdering = Seq(SortOrder(Literal(1), Ascending)) |
| 214 | + val distribution = ClusteredDistribution(Literal(1) :: Nil) |
| 215 | + val inputPlan = DummyPlan( |
| 216 | + children = Seq( |
| 217 | + DummyPlan(outputPartitioning = SinglePartition), |
| 218 | + DummyPlan(outputPartitioning = SinglePartition) |
| 219 | + ), |
| 220 | + requiresChildrenToProduceSameNumberOfPartitions = true, |
| 221 | + requiredChildDistribution = Seq(distribution, distribution), |
| 222 | + requiredChildOrdering = Seq(outputOrdering, outputOrdering) |
| 223 | + ) |
| 224 | + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) |
| 225 | + if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) { |
| 226 | + fail(s"No Exchanges should have been added:\n$outputPlan") |
213 | 227 | } |
214 | 228 | } |
| 229 | + |
| 230 | + // --------------------------------------------------------------------------------------------- |
| 231 | +} |
| 232 | + |
| 233 | +// Used for unit-testing EnsureRequirements |
| 234 | +private case class DummyPlan( |
| 235 | + override val children: Seq[SparkPlan] = Nil, |
| 236 | + override val outputOrdering: Seq[SortOrder] = Nil, |
| 237 | + override val outputPartitioning: Partitioning = UnknownPartitioning(0), |
| 238 | + override val requiresChildrenToProduceSameNumberOfPartitions: Boolean = false, |
| 239 | + override val requiredChildDistribution: Seq[Distribution] = Nil, |
| 240 | + override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil |
| 241 | + ) extends SparkPlan { |
| 242 | + override protected def doExecute(): RDD[InternalRow] = throw new NotImplementedError |
| 243 | + override def output: Seq[Attribute] = Seq.empty |
215 | 244 | } |
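
For reference, a hypothetical companion test (not part of this patch) could exercise the opposite case with the same DummyPlan helper: when a child's output partitioning cannot satisfy the required distribution, EnsureRequirements should insert an Exchange. The test name, and the use of UnknownPartitioning as a partitioning that does not satisfy ClusteredDistribution, are illustrative assumptions sketched against the plan-construction pattern shown above.

```scala
// Sketch only: a possible counterpart to the test above, to be placed inside PlannerSuite.
// Assumes UnknownPartitioning does not satisfy ClusteredDistribution, so an Exchange is expected.
test("EnsureRequirements adds an exchange when the required distribution is unsatisfied") {
  val distribution = ClusteredDistribution(Literal(1) :: Nil)
  val inputPlan = DummyPlan(
    children = Seq(DummyPlan(outputPartitioning = UnknownPartitioning(1))),
    requiredChildDistribution = Seq(distribution),
    requiredChildOrdering = Seq(Seq.empty)
  )
  val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
  if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) {
    fail(s"An Exchange should have been added:\n$outputPlan")
  }
}
```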