@@ -75,10 +75,9 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, TaskSetManager]
   val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
-  val taskIdToTaskSetId = new HashMap[Long, String]
+  val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
   val taskIdToExecutorId = new HashMap[Long, String]
 
   @volatile private var hasReceivedTask = false
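For readers skimming the diff: below is a minimal, self-contained sketch (not part of the patch; `DummyManager`, `register`, and `lookup` are illustrative names) of the bookkeeping pattern the new field introduces. Task sets are now registered and looked up by the pair (stageId, stageAttemptId) through a nested mutable HashMap, instead of by the old string task-set id.

```scala
import scala.collection.mutable.HashMap

// Illustrative stand-in for TaskSetManager; the real class lives in the scheduler package.
case class DummyManager(stageId: Int, stageAttemptId: Int)

object NestedMapSketch {
  // Mirrors the new field: stageId -> (stageAttemptId -> manager).
  val taskSetsByStage = new HashMap[Int, HashMap[Int, DummyManager]]

  def register(m: DummyManager): Unit = {
    val attempts = taskSetsByStage.getOrElseUpdate(m.stageId, new HashMap[Int, DummyManager])
    attempts(m.stageAttemptId) = m
  }

  def lookup(stageId: Int, stageAttemptId: Int): Option[DummyManager] =
    taskSetsByStage.get(stageId).flatMap(_.get(stageAttemptId))

  def main(args: Array[String]): Unit = {
    register(DummyManager(stageId = 3, stageAttemptId = 0))
    register(DummyManager(stageId = 3, stageAttemptId = 1)) // a retried attempt of the same stage
    println(lookup(3, 1)) // Some(DummyManager(3,1))
    println(lookup(3, 2)) // None
  }
}
```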
@@ -163,10 +162,9 @@ private[spark] class TaskSchedulerImpl(
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      activeTaskSets(taskSet.id) = manager
       val stage = taskSet.stageId
       val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
-      stageTaskSets(taskSet.attempt) = manager
+      stageTaskSets(taskSet.stageAttemptId) = manager
       val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
         ts.taskSet != taskSet && !ts.isZombie
       }
@@ -203,19 +201,21 @@ private[spark] class TaskSchedulerImpl(
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the stage.
-      tsm.runningTasksSet.foreach { tid =>
-        val execId = taskIdToExecutorId(tid)
-        backend.killTask(tid, execId, interruptThread)
+    taskSetsByStage.get(stageId).foreach { attempts =>
+      attempts.foreach { case (_, tsm) =>
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the task and then abort
+        //    the stage.
+        // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+        //    simply abort the stage.
+        tsm.runningTasksSet.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId, interruptThread)
+        }
+        tsm.abort("Stage %s cancelled".format(stageId))
+        logInfo("Stage %d was cancelled".format(stageId))
       }
-      tsm.abort("Stage %s cancelled".format(stageId))
-      logInfo("Stage %d was cancelled".format(stageId))
     }
   }
 
@@ -225,9 +225,8 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    activeTaskSets -= manager.taskSet.id
     taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
-      taskSetsForStage -= manager.taskSet.attempt
+      taskSetsForStage -= manager.taskSet.stageAttemptId
       if (taskSetsForStage.isEmpty) {
         taskSetsByStage -= manager.taskSet.stageId
       }
@@ -252,7 +251,7 @@ private[spark] class TaskSchedulerImpl(
           for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
             tasks(i) += task
             val tid = task.taskId
-            taskIdToTaskSetId(tid) = taskSet.taskSet.id
+            taskIdToStageIdAndAttempt(tid) = (taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
             taskIdToExecutorId(tid) = execId
             executorsByHost(host) += execId
             availableCpus(i) -= CPUS_PER_TASK
@@ -336,26 +335,24 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
+        taskSetManagerForTask(tid) match {
+          case Some(taskSet) =>
             if (TaskState.isFinished(state)) {
-              taskIdToTaskSetId.remove(tid)
+              taskIdToStageIdAndAttempt.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
-            activeTaskSets.get(taskSetId).foreach { taskSet =>
-              if (state == TaskState.FINISHED) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
-              }
+            if (state == TaskState.FINISHED) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
             }
           case None =>
             logError(
               ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
-              "likely the result of receiving duplicate task finished status updates)")
-              .format(state, tid))
+                "likely the result of receiving duplicate task finished status updates)")
+                .format(state, tid))
         }
       } catch {
         case e: Exception => logError("Exception in statusUpdate", e)
@@ -380,9 +377,13 @@ private[spark] class TaskSchedulerImpl(
 
     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToTaskSetId.get(id)
-          .flatMap(activeTaskSets.get)
-          .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+        for {
+          (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
+          attempts <- taskSetsByStage.get(stageId)
+          taskSetMgr <- attempts.get(stageAttemptId)
+        } yield {
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        }
       }
     }
     dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
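As an aside on the pattern above: a for-comprehension over `Option` values desugars to the same `flatMap`/`map` chain it replaces, returning `None` as soon as any lookup misses, while keeping both the stage id and attempt id in scope for the final `yield`. A standalone sketch with toy data (the maps and values below are illustrative, not the scheduler's real state):

```scala
object OptionChainSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical lookup tables standing in for the scheduler's maps.
    val taskIdToStageIdAndAttempt = Map(42L -> (3, 1))
    val taskSetsByStage = Map(3 -> Map(1 -> "manager-for-stage-3-attempt-1"))

    def managerForTask(tid: Long): Option[String] =
      for {
        (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(tid)
        attempts <- taskSetsByStage.get(stageId)
        manager <- attempts.get(stageAttemptId)
      } yield manager

    println(managerForTask(42L)) // Some(manager-for-stage-3-attempt-1)
    println(managerForTask(7L))  // None: short-circuits at the first missing key
  }
}
```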
@@ -414,9 +415,12 @@ private[spark] class TaskSchedulerImpl(
 
   def error(message: String) {
     synchronized {
-      if (activeTaskSets.nonEmpty) {
+      if (taskSetsByStage.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
+        for {
+          attempts <- taskSetsByStage.values
+          manager <- attempts.values
+        } {
           try {
             manager.abort(message)
           } catch {
@@ -537,6 +541,21 @@ private[spark] class TaskSchedulerImpl(
 
   override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
 
+  private[scheduler] def taskSetManagerForTask(taskId: Long): Option[TaskSetManager] = {
+    taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
+      taskSetManagerForAttempt(stageId, stageAttemptId)
+    }
+  }
+
+  private[scheduler] def taskSetManagerForAttempt(stageId: Int, stageAttemptId: Int): Option[TaskSetManager] = {
+    for {
+      attempts <- taskSetsByStage.get(stageId)
+      manager <- attempts.get(stageAttemptId)
+    } yield {
+      manager
+    }
+  }
+
 }
 
 
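A hedged usage sketch for the two helpers added above: `taskSetManagerForTask` maps a task id to its (stageId, stageAttemptId) pair via `taskIdToStageIdAndAttempt`, then delegates to `taskSetManagerForAttempt`, which walks the nested `taskSetsByStage` map. The hypothetical caller below is not from this commit; the object and method names are made up, and it assumes it compiles inside the `org.apache.spark.scheduler` package because the helpers are `private[scheduler]`.

```scala
package org.apache.spark.scheduler

// Hypothetical helper, not part of this commit; relies only on members visible in this
// diff (taskSetManagerForAttempt and TaskSetManager.abort).
object StageAttemptUtils {
  def abortAttemptIfRegistered(
      scheduler: TaskSchedulerImpl,
      stageId: Int,
      stageAttemptId: Int): Unit = {
    scheduler.taskSetManagerForAttempt(stageId, stageAttemptId) match {
      case Some(tsm) => tsm.abort(s"Stage $stageId (attempt $stageAttemptId) aborted")
      case None => // no manager registered for this attempt, so nothing to do
    }
  }
}
```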