@@ -101,9 +101,15 @@ class DAGSchedulerSuite
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
   val sparkListener = new SparkListener() {
+    val submittedStageInfos = new HashSet[StageInfo]
     val successfulStages = new HashSet[Int]
     val failedStages = new ArrayBuffer[Int]
     val stageByOrderOfExecution = new ArrayBuffer[Int]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+      submittedStageInfos += stageSubmitted.stageInfo
+    }
+
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
       val stageInfo = stageCompleted.stageInfo
       stageByOrderOfExecution += stageInfo.stageId
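
The hunk above is the hook the new tests rely on: every stage submission is recorded in submittedStageInfos, so a test can count how many attempts of a given stage id have been submitted. As a hedged aside, the same pattern could be packaged as a standalone listener outside the suite; the class and method names below are hypothetical and not part of this patch.

    import scala.collection.mutable
    import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

    // Hypothetical listener (not part of this patch): records each stage submission so a
    // caller can ask how many attempts of a given stage have been submitted so far.
    // Assumes events are delivered by a single listener-bus thread, as in the suite above.
    class StageSubmissionCounter extends SparkListener {
      private val submittedStageIds = mutable.ArrayBuffer[Int]()

      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
        submittedStageIds += stageSubmitted.stageInfo.stageId
      }

      /** Number of attempts of `stageId` submitted so far. */
      def attemptCount(stageId: Int): Int = submittedStageIds.count(_ == stageId)
    }

Registered with sc.addSparkListener, such a counter would report, for example, attemptCount(0) === 2 once the map stage in the first new test below has been resubmitted.
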
@@ -150,6 +156,7 @@ class DAGSchedulerSuite
     // Enable local execution for this test
     val conf = new SparkConf().set("spark.localExecution.enabled", "true")
     sc = new SparkContext("local", "DAGSchedulerSuite", conf)
+    sparkListener.submittedStageInfos.clear()
     sparkListener.successfulStages.clear()
     sparkListener.failedStages.clear()
     failure = null
@@ -547,6 +554,133 @@ class DAGSchedulerSuite
     assert(sparkListener.failedStages.size == 1)
   }
 
+  /** This tests the case where another FetchFailed comes in while the map stage is getting
+    * re-run. */
+  test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    val mapStageId = 0
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
+    }
+
+    // The map stage should have been submitted.
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
+    // The MapOutputTracker should know about both map output locations.
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
+      Array("hostA", "hostB"))
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(sparkListener.failedStages.contains(1))
+
+    // Trigger resubmission of the failed map stage.
+    runEvent(ResubmitFailedStages)
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+    // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Another ResubmitFailedStages event should not result in another attempt for the map
+    // stage being run concurrently.
+    runEvent(ResubmitFailedStages)
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedMapStageAttempts() === 2)
+
+    // NOTE: the actual ResubmitFailedStages event may get posted at any time during this test;
+    // that shouldn't affect anything. Posting it here explicitly just makes sure it gets
+    // processed between the desired event and our check.
+  }
+
+  /** This tests the case where a late FetchFailed comes in after the map stage has finished
+    * getting retried and a new reduce stage starts running.
+    */
+  test("extremely late fetch failures don't cause multiple concurrent attempts for the same stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    def countSubmittedReduceStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 1)
+    }
+    def countSubmittedMapStageAttempts(): Int = {
+      sparkListener.submittedStageInfos.count(_.stageId == 0)
+    }
+
+    // The map stage should have been submitted.
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedMapStageAttempts() === 1)
+
+    // Complete the map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
+
+    // The reduce stage should have been submitted.
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedReduceStageAttempts() === 1)
+
+    // The first result task fails, with a fetch failure for the output from the first mapper.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Trigger resubmission of the failed map stage and finish the re-started map task.
+    runEvent(ResubmitFailedStages)
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+
+    // Because the map stage finished, another attempt for the reduce stage should have been
+    // submitted, resulting in 2 total attempts for each of the map and reduce stages.
+    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+    assert(countSubmittedMapStageAttempts() === 2)
+    assert(countSubmittedReduceStageAttempts() === 2)
+
+    // A late FetchFailed arrives from the second task in the original reduce stage.
+    runEvent(CompletionEvent(
+      taskSets(1).tasks(1),
+      FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+      null,
+      Map[Long, Any](),
+      createFakeTaskInfo(),
+      null))
+
+    // Trigger another round of stage resubmission.
+    runEvent(ResubmitFailedStages)
+
+    // The FetchFailed from the original reduce stage should be ignored, with no new map attempt.
+    assert(countSubmittedMapStageAttempts() === 2)
+  }
+
   test("ignore late map task completions") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
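
A note on the positional arguments passed to FetchFailed in the two new tests: the fields are the block manager address of the lost output, the shuffle id, the map task index, the reduce partition index, and a free-form message. A hedged restatement of the first injected failure, with the role of each argument spelled out in comments (the tests never inspect the message):

    // The FetchFailed used by the new tests, annotated argument by argument.
    FetchFailed(
      makeBlockManagerId("hostA"), // executor whose shuffle output could not be fetched
      shuffleId,                   // id of the shuffle written by the map stage
      0,                           // index of the map task whose output was reported missing
      0,                           // index of the reduce partition that observed the failure
      "ignored")                   // failure message; not inspected by these tests
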