@@ -75,7 +75,7 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
+  val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
   val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
   val taskIdToExecutorId = new HashMap[Long, String]
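
The renamed field keys TaskSetManagers first by stage id and then by stage attempt id. A minimal standalone sketch of that two-level layout, using a String stand-in where the real scheduler stores a TaskSetManager (the values here are illustrative, not part of the patch):

import scala.collection.mutable.HashMap

// Hypothetical stand-in: stage id -> (stage attempt id -> manager label).
val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, String]]

// Register attempt 0 of stage 3, mirroring how submitTasks populates the map.
val attemptsForStage3 = taskSetsByStageIdAndAttempt.getOrElseUpdate(3, new HashMap[Int, String])
attemptsForStage3(0) = "manager for stage 3, attempt 0"

// Retrieve a specific (stage, attempt) pair; None if either level is missing.
val manager: Option[String] = taskSetsByStageIdAndAttempt.get(3).flatMap(_.get(0))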
@@ -163,7 +163,8 @@ private[spark] class TaskSchedulerImpl(
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
       val stage = taskSet.stageId
-      val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+      val stageTaskSets =
+        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
       stageTaskSets(taskSet.stageAttemptId) = manager
       val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
         ts.taskSet != taskSet && !ts.isZombie
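
The conflictingTaskSet check above rejects a submission when another attempt of the same stage is still active, i.e. not yet marked as a zombie. A hedged sketch of the same kind of predicate over illustrative state (the AttemptState case class and the attempt-id comparison are stand-ins, not Spark's API):

// Illustrative stand-in for the per-attempt state the check cares about.
case class AttemptState(attemptId: Int, isZombie: Boolean)

val newAttemptId = 2
val stageTaskSets = Map(
  0 -> AttemptState(0, isZombie = true),   // finished or superseded attempt
  1 -> AttemptState(1, isZombie = false))  // still running: triggers the conflict

// True if any other registered attempt of this stage is still active.
val conflictingTaskSet = stageTaskSets.exists { case (attemptId, state) =>
  attemptId != newAttemptId && !state.isZombie
}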
@@ -201,7 +202,7 @@ private[spark] class TaskSchedulerImpl(
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    taskSetsByStage.get(stageId).foreach { attempts =>
+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
       attempts.foreach { case (_, tsm) =>
         // There are two possible cases here:
         // 1. The task set manager has been created and some tasks have been scheduled.
@@ -225,10 +226,10 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
       taskSetsForStage -= manager.taskSet.stageAttemptId
       if (taskSetsForStage.isEmpty) {
-        taskSetsByStage -= manager.taskSet.stageId
+        taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
       }
     }
     manager.parent.removeSchedulable(manager)
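
When the last attempt of a stage finishes, taskSetFinished also drops the stage's entry from the outer map so it does not accumulate empty inner maps. A small sketch of that cleanup on a stand-in map (illustrative values, not the scheduler's real state):

import scala.collection.mutable.HashMap

val taskSetsByStageIdAndAttempt = HashMap(3 -> HashMap(0 -> "manager for stage 3, attempt 0"))

// Remove the finished attempt; drop the stage entry once no attempts remain.
val stageId = 3
val stageAttemptId = 0
taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
  attempts -= stageAttemptId
  if (attempts.isEmpty) {
    taskSetsByStageIdAndAttempt -= stageId
  }
}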
@@ -380,7 +381,7 @@ private[spark] class TaskSchedulerImpl(
       taskMetrics.flatMap { case (id, metrics) =>
         for {
           (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
-          attempts <- taskSetsByStage.get(stageId)
+          attempts <- taskSetsByStageIdAndAttempt.get(stageId)
           taskSetMgr <- attempts.get(stageAttemptId)
         } yield {
            (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
@@ -416,10 +417,10 @@ private[spark] class TaskSchedulerImpl(
 
   def error(message: String) {
     synchronized {
-      if (taskSetsByStage.nonEmpty) {
+      if (taskSetsByStageIdAndAttempt.nonEmpty) {
         // Have each task set throw a SparkException with the error
         for {
-          attempts <- taskSetsByStage.values
+          attempts <- taskSetsByStageIdAndAttempt.values
           manager <- attempts.values
         } {
           try {
@@ -552,7 +553,7 @@ private[spark] class TaskSchedulerImpl(
       stageId: Int,
       stageAttemptId: Int): Option[TaskSetManager] = {
     for {
-      attempts <- taskSetsByStage.get(stageId)
+      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
       manager <- attempts.get(stageAttemptId)
     } yield {
       manager
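
The lookup helper chains the two Option-producing gets in a for comprehension, so a missing stage and a missing attempt both fall through to None. A self-contained sketch of the same pattern with a String stand-in for TaskSetManager:

import scala.collection.mutable.HashMap

val taskSetsByStageIdAndAttempt = HashMap(3 -> HashMap(0 -> "manager for stage 3, attempt 0"))

def taskSetManagerForAttempt(stageId: Int, stageAttemptId: Int): Option[String] = {
  for {
    attempts <- taskSetsByStageIdAndAttempt.get(stageId)
    manager <- attempts.get(stageAttemptId)
  } yield manager
}

taskSetManagerForAttempt(3, 0)  // Some(...): both levels present
taskSetManagerForAttempt(3, 1)  // None: stage exists but that attempt does not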