@@ -26,35 +26,33 @@ import org.apache.spark._
 
 class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
 
-  // TODO we should run this with a matrix of configurations: different shufflers,
-  // external shuffle service, etc. But that is really pushing the question of how to run
-  // such a long test ...
-
-  ignore("no concurrent retries for stage attempts (SPARK-7308)") {
-    // see SPARK-7308 for a detailed description of the conditions this is trying to recreate.
-    // note that this is somewhat convoluted for a test case, but isn't actually very unusual
-    // under a real workload. We only fail the first attempt of stage 2, but that
-    // could be enough to cause havoc.
-
-    (0 until 100).foreach { idx =>
-      println(new Date() + "\ttrial " + idx)
+  test("no concurrent retries for stage attempts (SPARK-8103)") {
+    // make sure that if we get fetch failures after the retry has started, we ignore them,
+    // and so don't end up submitting multiple concurrent attempts for the same stage
+
+    (0 until 20).foreach { idx =>
       logInfo(new Date() + "\ttrial " + idx)
 
       val conf = new SparkConf().set("spark.executor.memory", "100m")
-      val clusterSc = new SparkContext("local-cluster[5,4,100]", "test-cluster", conf)
+      val clusterSc = new SparkContext("local-cluster[2,2,100]", "test-cluster", conf)
       val bms = ArrayBuffer[BlockManagerId]()
       val stageFailureCount = HashMap[Int, Int]()
+      val stageSubmissionCount = HashMap[Int, Int]()
       clusterSc.addSparkListener(new SparkListener {
         override def onBlockManagerAdded(bmAdded: SparkListenerBlockManagerAdded): Unit = {
           bms += bmAdded.blockManagerId
         }
 
+        override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+          val stage = stageSubmitted.stageInfo.stageId
+          stageSubmissionCount(stage) = stageSubmissionCount.getOrElse(stage, 0) + 1
+        }
+
+
         override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
           if (stageCompleted.stageInfo.failureReason.isDefined) {
             val stage = stageCompleted.stageInfo.stageId
             stageFailureCount(stage) = stageFailureCount.getOrElse(stage, 0) + 1
-            val reason = stageCompleted.stageInfo.failureReason.get
-            println("stage " + stage + " failed: " + stageFailureCount(stage))
           }
         }
       })
@@ -66,34 +64,37 @@ class DAGSchedulerFailureRecoverySuite extends SparkFunSuite with Logging {
         // to avoid broadcast failures
         val someBlockManager = bms.filter{!_.isDriver}(0)
 
-        val shuffled = rawData.groupByKey(100).mapPartitionsWithIndex { case (idx, itr) =>
+        val shuffled = rawData.groupByKey(20).mapPartitionsWithIndex { case (idx, itr) =>
           // we want one failure quickly, and more failures after stage 0 has finished its
           // second attempt
           val stageAttemptId = TaskContext.get().asInstanceOf[TaskContextImpl].stageAttemptId
           if (stageAttemptId == 0) {
             if (idx == 0) {
               throw new FetchFailedException(someBlockManager, 0, 0, idx,
                 cause = new RuntimeException("simulated fetch failure"))
-            } else if (idx > 0 && math.random < 0.2) {
-              Thread.sleep(5000)
+            } else if (idx == 1) {
+              Thread.sleep(2000)
               throw new FetchFailedException(someBlockManager, 0, 0, idx,
                 cause = new RuntimeException("simulated fetch failure"))
-            } else {
-              // want to make sure plenty of these finish after task 0 fails, and some even finish
-              // after the previous stage is retried and this stage retry is started
-              Thread.sleep((500 + math.random * 5000).toLong)
             }
+          } else {
+            // just to make sure the second attempt doesn't finish before we trigger more failures
+            // from the first attempt
+            Thread.sleep(2000)
           }
           itr.map { x => ((x._1 + 5) % 100) -> x._2 }
         }
-        val data = shuffled.mapPartitions { itr => itr.flatMap(_._2) }.collect()
+        val data = shuffled.mapPartitions { itr =>
+          itr.flatMap(_._2)
+        }.cache().collect()
         val count = data.size
         assert(count === 1e6.toInt)
         assert(data.toSet === (1 to 1e6.toInt).toSet)
 
         assert(stageFailureCount.getOrElse(1, 0) === 0)
-        assert(stageFailureCount.getOrElse(2, 0) == 1)
-        assert(stageFailureCount.getOrElse(3, 0) == 0)
+        assert(stageFailureCount.getOrElse(2, 0) === 1)
+        assert(stageSubmissionCount.getOrElse(1, 0) <= 2)
+        assert(stageSubmissionCount.getOrElse(2, 0) === 2)
       } finally {
         clusterSc.stop()
       }
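
The assertions above pin down the SPARK-8103 behavior: once a stage has been resubmitted, fetch failures still trickling in from tasks of the previous attempt must be ignored rather than trigger another, concurrent submission of the same stage. Below is a minimal standalone sketch of that dedup idea, not Spark's actual DAGScheduler code; the object, method, and field names are hypothetical.

import scala.collection.mutable

// Sketch of the stale-attempt filtering this test exercises (hypothetical names).
object StaleAttemptFilterSketch {
  // stages whose resubmission has been requested but has not yet started
  private val pendingResubmission = mutable.Set[Int]()

  // Returns true only when this fetch failure should trigger a resubmission.
  def shouldResubmit(stageId: Int, failedAttemptId: Int, currentAttemptId: Int): Boolean = {
    if (failedAttemptId < currentAttemptId) {
      // a straggler task from an old attempt failed after the retry started: ignore it
      false
    } else if (pendingResubmission.contains(stageId)) {
      // a retry is already queued for this stage: don't submit a concurrent one
      false
    } else {
      pendingResubmission += stageId
      true
    }
  }

  def main(args: Array[String]): Unit = {
    assert(shouldResubmit(stageId = 2, failedAttemptId = 0, currentAttemptId = 0))   // first failure: resubmit
    assert(!shouldResubmit(stageId = 2, failedAttemptId = 0, currentAttemptId = 0))  // duplicate while pending: ignored
    assert(!shouldResubmit(stageId = 2, failedAttemptId = 0, currentAttemptId = 1))  // stale attempt: ignored
    println("no concurrent attempts submitted for the same stage")
  }
}

Under this model each stage is submitted at most twice in the test's scenario (the original attempt plus one retry), which is exactly what the new stageSubmissionCount assertions check.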