
Commit cbff280

Andy Sloane authored and srowen committed
[SPARK-13631][CORE] Thread-safe getLocationsWithLargestOutputs
## What changes were proposed in this pull request?

If a job is being scheduled in one thread which has a dependency on an RDD currently executing a shuffle in another thread, Spark would throw a NullPointerException. This patch synchronizes access to `mapStatuses` and skips null status entries (which are in-progress shuffle tasks).

## How was this patch tested?

Our client code unit test suite, which was reliably reproducing the race condition with 10 threads, shows that this fixes it. I have not found a minimal test case to add to Spark, but I will attempt to do so if desired.

The same test case was tripping up on SPARK-4454, which was fixed by making other DAGScheduler code thread-safe.

shivaram srowen

Author: Andy Sloane <[email protected]>

Closes #11505 from a1k0n/SPARK-13631.
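To make the race concrete, here is a minimal sketch of the lifecycle the patch guards against: a shuffle is registered with a status array whose entries are still null, the entries are filled in later, and a concurrent reader must therefore lock the array and skip nulls. The names below (`FakeMapStatus`, `RaceSketch`, `totalSizePerLocation`) are simplified stand-ins for illustration, not Spark's actual API.

```scala
import scala.collection.concurrent.TrieMap

// Simplified stand-in for MapStatus: where a map output lives and how big it is.
case class FakeMapStatus(location: String, sizeForReducer: Long)

object RaceSketch {
  // Analogue of the tracker's shuffleId -> statuses map.
  val mapStatuses = new TrieMap[Int, Array[FakeMapStatus]]()

  // Analogue of registerShuffle: allocates an array whose entries are all still null.
  def registerShuffle(shuffleId: Int, numMaps: Int): Unit = {
    mapStatuses.put(shuffleId, new Array[FakeMapStatus](numMaps))
  }

  // Analogue of registering a finished map output: fills entries in one at a time.
  def registerMapOutput(shuffleId: Int, mapId: Int, status: FakeMapStatus): Unit = {
    val statuses = mapStatuses(shuffleId)
    statuses.synchronized { statuses(mapId) = status }
  }

  // Analogue of the patched reader: lock the array and skip still-null entries
  // (map tasks whose output has not been registered yet).
  def totalSizePerLocation(shuffleId: Int): Map[String, Long] = {
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) return Map.empty
    statuses.synchronized {
      statuses.filter(_ != null)
        .groupBy(_.location)
        .map { case (loc, ss) => loc -> ss.map(_.sizeForReducer).sum }
    }
  }
}
```

Before this patch, the reader dereferenced entries without either the lock or the null check, which is the source of the NullPointerException described above.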
1 parent 2c5af7d commit cbff280

File tree

1 file changed: +29 −23 lines


core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 29 additions & 23 deletions
@@ -376,8 +376,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
    * @param numReducers total number of reducers in the shuffle
    * @param fractionThreshold fraction of total map output size that a location must have
    *                          for it to be considered large.
-   *
-   * This method is not thread-safe.
    */
   def getLocationsWithLargestOutputs(
       shuffleId: Int,
@@ -386,28 +384,36 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
       fractionThreshold: Double)
     : Option[Array[BlockManagerId]] = {
 
-    if (mapStatuses.contains(shuffleId)) {
-      val statuses = mapStatuses(shuffleId)
-      if (statuses.nonEmpty) {
-        // HashMap to add up sizes of all blocks at the same location
-        val locs = new HashMap[BlockManagerId, Long]
-        var totalOutputSize = 0L
-        var mapIdx = 0
-        while (mapIdx < statuses.length) {
-          val status = statuses(mapIdx)
-          val blockSize = status.getSizeForBlock(reducerId)
-          if (blockSize > 0) {
-            locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
-            totalOutputSize += blockSize
+    val statuses = mapStatuses.get(shuffleId).orNull
+    if (statuses != null) {
+      statuses.synchronized {
+        if (statuses.nonEmpty) {
+          // HashMap to add up sizes of all blocks at the same location
+          val locs = new HashMap[BlockManagerId, Long]
+          var totalOutputSize = 0L
+          var mapIdx = 0
+          while (mapIdx < statuses.length) {
+            val status = statuses(mapIdx)
+            // status may be null here if we are called between registerShuffle, which creates an
+            // array with null entries for each output, and registerMapOutputs, which populates it
+            // with valid status entries. This is possible if one thread schedules a job which
+            // depends on an RDD which is currently being computed by another thread.
+            if (status != null) {
+              val blockSize = status.getSizeForBlock(reducerId)
+              if (blockSize > 0) {
+                locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize
+                totalOutputSize += blockSize
+              }
+            }
+            mapIdx = mapIdx + 1
+          }
+          val topLocs = locs.filter { case (loc, size) =>
+            size.toDouble / totalOutputSize >= fractionThreshold
+          }
+          // Return if we have any locations which satisfy the required threshold
+          if (topLocs.nonEmpty) {
+            return Some(topLocs.keys.toArray)
           }
-          mapIdx = mapIdx + 1
-        }
-        val topLocs = locs.filter { case (loc, size) =>
-          size.toDouble / totalOutputSize >= fractionThreshold
-        }
-        // Return if we have any locations which satisfy the required threshold
-        if (topLocs.nonEmpty) {
-          return Some(topLocs.map(_._1).toArray)
         }
       }
     }
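The PR description mentions an external client test suite that reproduced the NullPointerException with 10 concurrent threads; no minimal test was added to Spark in this commit. A hypothetical repro along those lines (the object name, partition counts, and thread count below are illustrative assumptions, not the suite referenced above) would schedule jobs that share a shuffle dependency from several threads at once, so one thread can query reduce-task locality while another is still registering map outputs:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

object ConcurrentShuffleRepro {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[4]").setAppName("SPARK-13631-repro")
    val sc = new SparkContext(conf)
    // An RDD with a shuffle dependency, shared by every job below.
    val shuffled = sc.parallelize(1 to 100000, 16)
      .map(i => (i % 100, 1L))
      .reduceByKey(_ + _)
    // Kick off jobs over the same shuffle from ten threads; before this patch, one thread
    // could read the shuffle's status array while another was still populating it.
    val jobs = (1 to 10).map(_ => Future(shuffled.count()))
    Await.result(Future.sequence(jobs), 5.minutes)
    sc.stop()
  }
}
```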
