@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
7575
7676 // TaskSetManagers are not thread safe, so any access to one should be synchronized
7777 // on this class.
78- val taskSetsByStageIdAndAttempt = new HashMap [Int , HashMap [Int , TaskSetManager ]]
78+ private val taskSetsByStageIdAndAttempt = new HashMap [Int , HashMap [Int , TaskSetManager ]]
7979
80- val taskIdToStageIdAndAttempt = new HashMap [Long , ( Int , Int ) ]
80+ private [scheduler] val taskIdToTaskSetManager = new HashMap [Long , TaskSetManager ]
8181 val taskIdToExecutorId = new HashMap [Long , String ]
8282
8383 @ volatile private var hasReceivedTask = false
@@ -252,8 +252,7 @@ private[spark] class TaskSchedulerImpl(
252252 for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
253253 tasks(i) += task
254254 val tid = task.taskId
255- taskIdToStageIdAndAttempt(tid) =
256- (taskSet.taskSet.stageId, taskSet.taskSet.stageAttemptId)
255+ taskIdToTaskSetManager(tid) = taskSet
257256 taskIdToExecutorId(tid) = execId
258257 executorsByHost(host) += execId
259258 availableCpus(i) -= CPUS_PER_TASK
@@ -337,10 +336,10 @@ private[spark] class TaskSchedulerImpl(
337336 failedExecutor = Some (execId)
338337 }
339338 }
340- taskSetManagerForTask (tid) match {
339+ taskIdToTaskSetManager.get (tid) match {
341340 case Some (taskSet) =>
342341 if (TaskState .isFinished(state)) {
343- taskIdToStageIdAndAttempt .remove(tid)
342+ taskIdToTaskSetManager .remove(tid)
344343 taskIdToExecutorId.remove(tid)
345344 }
346345 if (state == TaskState .FINISHED ) {
@@ -379,12 +378,8 @@ private[spark] class TaskSchedulerImpl(
379378
380379 val metricsWithStageIds : Array [(Long , Int , Int , TaskMetrics )] = synchronized {
381380 taskMetrics.flatMap { case (id, metrics) =>
382- for {
383- (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
384- attempts <- taskSetsByStageIdAndAttempt.get(stageId)
385- taskSetMgr <- attempts.get(stageAttemptId)
386- } yield {
387- (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
381+ taskIdToTaskSetManager.get(id).map { taskSetMgr =>
382+ (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
388383 }
389384 }
390385 }
@@ -543,12 +538,6 @@ private[spark] class TaskSchedulerImpl(
543538
544539 override def applicationAttemptId (): Option [String ] = backend.applicationAttemptId()
545540
546- private [scheduler] def taskSetManagerForTask (taskId : Long ): Option [TaskSetManager ] = {
547- taskIdToStageIdAndAttempt.get(taskId).flatMap{ case (stageId, stageAttemptId) =>
548- taskSetManagerForAttempt(stageId, stageAttemptId)
549- }
550- }
551-
552541 private [scheduler] def taskSetManagerForAttempt (
553542 stageId : Int ,
554543 stageAttemptId : Int ): Option [TaskSetManager ] = {
0 commit comments