@@ -160,7 +160,8 @@ case class Exchange(
* the partitioning scheme defined in `newPartitioning`. Those partitions of
* the returned ShuffleDependency will be the input of shuffle.
*/
private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
private[sql] def prepareShuffleDependency(newPartitioning: Partitioning): ShuffleDependency[Int,
InternalRow, InternalRow] = {
val rdd = child.execute()
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
@@ -251,7 +252,7 @@ case class Exchange(
assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
shuffleRDD
case None =>
val shuffleDependency = prepareShuffleDependency()
val shuffleDependency = prepareShuffleDependency(newPartitioning)
preparePostShuffleRDD(shuffleDependency)
}
}
@@ -202,7 +202,7 @@ private[sql] class ExchangeCoordinator(
var i = 0
while (i < numExchanges) {
val exchange = exchanges(i)
val shuffleDependency = exchange.prepareShuffleDependency()
val shuffleDependency = exchange.prepareShuffleDependency(exchange.newPartitioning)
shuffleDependencies += shuffleDependency
if (shuffleDependency.rdd.partitions.length != 0) {
// submitMapStage does not accept RDD with 0 partition.
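For context, here is a minimal, self-contained sketch of the call pattern this change establishes; it uses simplified hypothetical types, not Spark's real Exchange or ExchangeCoordinator classes. The dependency-preparation method takes the target partitioning as an explicit argument, and the coordinator-style caller passes each exchange's own newPartitioning at the call site instead of relying on the method reading it implicitly.

object PrepareShuffleSketch {
  // Simplified stand-ins for Spark's partitioning types (hypothetical, for illustration only).
  sealed trait Partitioning { def numPartitions: Int }
  case class HashPartitioning(numPartitions: Int) extends Partitioning
  case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning

  case class Exchange(newPartitioning: Partitioning) {
    // After this PR the partitioning is a parameter rather than an implicit read of the
    // operator's field, so every caller states which partitioning it wants.
    def prepareShuffleDependency(partitioning: Partitioning): String =
      s"ShuffleDependency over ${partitioning.numPartitions} partitions"
  }

  def main(args: Array[String]): Unit = {
    val exchanges = Seq(Exchange(HashPartitioning(8)), Exchange(RoundRobinPartitioning(4)))
    // Coordinator-style call site: pass each exchange's own partitioning explicitly.
    exchanges.foreach(e => println(e.prepareShuffleDependency(e.newPartitioning)))
  }
}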
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution

import org.apache.spark.rdd.RDD
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
@@ -235,6 +236,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
Array(bytesByPartitionId1, bytesByPartitionId2),
expectedPartitionStartIndices)
}

}

///////////////////////////////////////////////////////////////////////////
@@ -280,6 +282,44 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
try f(sqlContext) finally sparkContext.stop()
}

/**
* SPARK-19462
*/
test("exchange resilience") {
val test: SQLContext => Unit = { sqlContext: SQLContext =>
val df1 =
sqlContext
.range(0, 1000, 1, numInputPartitions)
.selectExpr("id as number")
df1.registerTempTable("test")

val data2 = sqlContext.sql("SELECT number, count(*) cnt FROM test GROUP BY number")
case class RddTree(rdd: RDD[_]) {
val children: Seq[RddTree] = rdd.dependencies.map(x => RddTree(x.rdd))
val rdds: Seq[RDD[_]] = rdd +: children.flatMap(x => x.rdds)
}
// execute the top rdd
data2.collect
// traverse the RDD dependency tree and execute each RDD from top to bottom
// so that all dependent RDDs are re-executed and no UnknownPartitioning error
// is thrown
RddTree(data2.rdd).rdds.foreach(_.collect)

}

val sparkConf =
new SparkConf(false)
.setMaster("local[*]")
.setAppName("test")
.set("spark.ui.enabled", "false")
.set("spark.driver.allowMultipleContexts", "true")
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
val sparkContext = new SparkContext(sparkConf)
val sqlContext = new TestSQLContext(sparkContext)
try test(sqlContext) finally sparkContext.stop()
}
Seq(Some(5), None).foreach { minNumPostShufflePartitions =>
val testNameNote = minNumPostShufflePartitions match {
case Some(numPartitions) => "(minNumPostShufflePartitions: 3)"
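As a usage note, below is a small sketch of the traversal idiom the new test relies on: walk an RDD's dependency tree and re-execute every RDD from the top down, which is the access pattern that previously surfaced the UnknownPartitioning failure under adaptive execution. It assumes an existing RDD (for example someDataFrame.rdd) and a running SparkContext; the object and method names are hypothetical, not part of the PR.

import org.apache.spark.rdd.RDD

object RddDependencyWalk {
  // Collect an RDD together with every RDD it transitively depends on, parents included.
  def allRdds(rdd: RDD[_]): Seq[RDD[_]] =
    rdd +: rdd.dependencies.flatMap(dep => allRdds(dep.rdd))

  // Re-execute each RDD in the lineage; with adaptive execution enabled this should
  // succeed for every node rather than failing on a post-shuffle RDD.
  def reexecuteLineage(root: RDD[_]): Unit =
    allRdds(root).foreach(_.collect())
}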