@@ -76,6 +76,7 @@ private[spark] class TaskSchedulerImpl(
7676 // TaskSetManagers are not thread safe, so any access to one should be synchronized
7777 // on this class.
// Active TaskSetManagers, keyed by TaskSet id.
7878 val activeTaskSets = new HashMap [String , TaskSetManager ]
// Per-stage index of TaskSetManagers: stage id -> (TaskSet attempt -> manager).
// A manager is registered here at the same point it is added to activeTaskSets, and this
// map is what is scanned to detect more than one non-zombie TaskSet for a single stage.
// taskSetFinished removes the manager's entry and drops the inner map once it is empty.
79+ val taskSetsByStage = new HashMap [Int , HashMap [Int , TaskSetManager ]]
7980
// NOTE(review): presumably task id -> id of the TaskSet that owns the task; the code that
// populates this map is not visible in this excerpt -- confirm.
8081 val taskIdToTaskSetId = new HashMap [Long , String ]
// NOTE(review): presumably task id -> id of the executor the task was launched on; the code
// that populates this map is not visible in this excerpt -- confirm.
8182 val taskIdToExecutorId = new HashMap [Long , String ]
@@ -164,13 +165,14 @@ private[spark] class TaskSchedulerImpl(
164165 val manager = createTaskSetManager(taskSet, maxTaskFailures)
165166 activeTaskSets(taskSet.id) = manager
166167 val stage = taskSet.stageId
167- val conflictingTaskSet = activeTaskSets.exists { case (id, ts) =>
168- // if the id matches, it really should be the same taskSet, but in some unit tests
169- // we add new taskSets with the same id
170- id != taskSet.id && ! ts.isZombie && ts.stageId == stage
168+ val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap [ Int , TaskSetManager ])
169+ stageTaskSets( taskSet.attempt) = manager
170+ val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
171+ ts.taskSet != taskSet && ! ts.isZombie
171172 }
172173 if (conflictingTaskSet) {
173- throw new SparkIllegalStateException (s " more than one active taskSet for stage $stage" )
174+ throw new SparkIllegalStateException (s " more than one active taskSet for stage $stage: " +
175+ s " ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(" ," )}" )
174176 }
175177 schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
176178
@@ -224,6 +226,12 @@ private[spark] class TaskSchedulerImpl(
224226 */
225227 def taskSetFinished (manager : TaskSetManager ): Unit = synchronized {
226228 activeTaskSets -= manager.taskSet.id
229+ taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
230+ taskSetsForStage -= manager.taskSet.attempt
231+ if (taskSetsForStage.isEmpty) {
232+ taskSetsByStage -= manager.taskSet.stageId
233+ }
234+ }
227235 manager.parent.removeSchedulable(manager)
228236 logInfo(" Removed TaskSet %s, whose tasks have all completed, from pool %s"
229237 .format(manager.taskSet.id, manager.parent.name))
0 commit comments