Skip to content

Commit c6cbb06

Browse files
committed
Still need mapId for the fetch fail scenario
1 parent 8c7460d commit c6cbb06

File tree

5 files changed

+99
-82
lines changed

5 files changed

+99
-82
lines changed

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
282282

283283
// For testing
284284
def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
285-
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
285+
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
286286
getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
287287
}
288288

@@ -292,11 +292,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
292292
* endPartition is excluded from the range).
293293
*
294294
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
295-
* and the second item is a sequence of (shuffle block id, shuffle block size) tuples
296-
* describing the shuffle blocks that are stored at that block manager.
295+
* and the second item is a sequence of (shuffle block id, shuffle block size, map id)
296+
* tuples describing the shuffle blocks that are stored at that block manager.
297297
*/
298298
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
299-
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
299+
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
300300

301301
/**
302302
* Deletes map output status information for the specified shuffle stage.
@@ -646,7 +646,7 @@ private[spark] class MapOutputTrackerMaster(
646646
// Get block sizes by executor Id. Note that zero-sized blocks are excluded in the result.
647647
// This method is only called in local-mode.
648648
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
649-
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
649+
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
650650
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
651651
shuffleStatuses.get(shuffleId) match {
652652
case Some (shuffleStatus) =>
@@ -683,7 +683,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
683683

684684
// Get block sizes by executor Id. Note that zero-sized blocks are excluded in the result.
685685
override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
686-
: Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
686+
: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
687687
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
688688
val statuses = getStatuses(shuffleId)
689689
try {
@@ -864,17 +864,17 @@ private[spark] object MapOutputTracker extends Logging {
864864
* @param endPartition End of map output partition ID range (excluded from range)
865865
* @param statuses List of map statuses, indexed by map ID.
866866
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
867-
* and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
868-
* describing the shuffle blocks that are stored at that block manager.
867+
* and the second item is a sequence of (shuffle block id, shuffle block size, map id)
868+
* tuples describing the shuffle blocks that are stored at that block manager.
869869
*/
870870
def convertMapStatuses(
871871
shuffleId: Int,
872872
startPartition: Int,
873873
endPartition: Int,
874-
statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
874+
statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
875875
assert (statuses != null)
876-
val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]]
877-
statuses.foreach { status =>
876+
val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
877+
for ((status, mapId) <- statuses.iterator.zipWithIndex) {
878878
if (status == null) {
879879
val errorMessage = s"Missing an output location for shuffle $shuffleId"
880880
logError(errorMessage)
@@ -884,7 +884,7 @@ private[spark] object MapOutputTracker extends Logging {
884884
val size = status.getSizeForBlock(part)
885885
if (size != 0) {
886886
splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
887-
((ShuffleBlockId(shuffleId, status.mapTaskAttemptId, part), size))
887+
((ShuffleBlockId(shuffleId, status.mapTaskAttemptId, part), size, mapId))
888888
}
889889
}
890890
}

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
4848
* @param shuffleClient [[BlockStoreClient]] for fetching remote blocks
4949
* @param blockManager [[BlockManager]] for reading local blocks
5050
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
51-
* For each block we also require the size (in bytes as a long field) in
52-
* order to throttle the memory usage. Note that zero-sized blocks are
53-
* already excluded, which happened in
51+
* For each block we also need two values: 1. the size (in bytes as a long
52+
* field) in order to throttle the memory usage; 2. the mapId for this
53+
* block, which indicates the index of the block in the map stage.
54+
* Note that zero-sized blocks are already excluded, which happened in
5455
* [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
5556
* @param streamWrapper A function to wrap the returned input stream.
5657
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
@@ -66,7 +67,7 @@ final class ShuffleBlockFetcherIterator(
6667
context: TaskContext,
6768
shuffleClient: BlockStoreClient,
6869
blockManager: BlockManager,
69-
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
70+
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
7071
streamWrapper: (BlockId, InputStream) => InputStream,
7172
maxBytesInFlight: Long,
7273
maxReqsInFlight: Int,
@@ -96,7 +97,7 @@ final class ShuffleBlockFetcherIterator(
9697
private[this] val startTimeNs = System.nanoTime()
9798

9899
/** Local blocks to fetch, excluding zero-sized blocks. */
99-
private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]()
100+
private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
100101

101102
/** Remote blocks to fetch, excluding zero-sized blocks. */
102103
private[this] val remoteBlocks = new HashSet[BlockId]()
@@ -198,7 +199,7 @@ final class ShuffleBlockFetcherIterator(
198199
while (iter.hasNext) {
199200
val result = iter.next()
200201
result match {
201-
case SuccessFetchResult(_, address, _, buf, _) =>
202+
case SuccessFetchResult(_, _, address, _, buf, _) =>
202203
if (address != blockManager.blockManagerId) {
203204
shuffleMetrics.incRemoteBytesRead(buf.size)
204205
if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
@@ -223,9 +224,11 @@ final class ShuffleBlockFetcherIterator(
223224
bytesInFlight += req.size
224225
reqsInFlight += 1
225226

226-
// so we can look up the size of each blockID
227-
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
228-
val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
227+
// so we can look up the block info of each blockID
228+
val infoMap = req.blocks.map {
229+
case (blockId, size, mapId) => (blockId.toString, (size, mapId))
230+
}.toMap
231+
val remainingBlocks = new HashSet[String]() ++= infoMap.keys
229232
val blockIds = req.blocks.map(_._1.toString)
230233
val address = req.address
231234

@@ -239,8 +242,8 @@ final class ShuffleBlockFetcherIterator(
239242
// This needs to be released after use.
240243
buf.retain()
241244
remainingBlocks -= blockId
242-
results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
243-
remainingBlocks.isEmpty))
245+
results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
246+
address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
244247
logDebug("remainingBlocks: " + remainingBlocks)
245248
}
246249
}
@@ -249,7 +252,7 @@ final class ShuffleBlockFetcherIterator(
249252

250253
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
251254
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
252-
results.put(new FailureFetchResult(BlockId(blockId), address, e))
255+
results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
253256
}
254257
}
255258

@@ -282,28 +285,28 @@ final class ShuffleBlockFetcherIterator(
282285
for ((address, blockInfos) <- blocksByAddress) {
283286
if (address.executorId == blockManager.blockManagerId.executorId) {
284287
blockInfos.find(_._2 <= 0) match {
285-
case Some((blockId, size)) if size < 0 =>
288+
case Some((blockId, size, _)) if size < 0 =>
286289
throw new BlockException(blockId, "Negative block size " + size)
287-
case Some((blockId, size)) if size == 0 =>
290+
case Some((blockId, size, _)) if size == 0 =>
288291
throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
289292
case None => // do nothing.
290293
}
291-
localBlocks ++= blockInfos.map(_._1)
294+
localBlocks ++= blockInfos.map(info => (info._1, info._3))
292295
localBlockBytes += blockInfos.map(_._2).sum
293296
numBlocksToFetch += localBlocks.size
294297
} else {
295298
val iterator = blockInfos.iterator
296299
var curRequestSize = 0L
297-
var curBlocks = new ArrayBuffer[(BlockId, Long)]
300+
var curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
298301
while (iterator.hasNext) {
299-
val (blockId, size) = iterator.next()
302+
val (blockId, size, mapId) = iterator.next()
300303
remoteBlockBytes += size
301304
if (size < 0) {
302305
throw new BlockException(blockId, "Negative block size " + size)
303306
} else if (size == 0) {
304307
throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
305308
} else {
306-
curBlocks += ((blockId, size))
309+
curBlocks += ((blockId, size, mapId))
307310
remoteBlocks += blockId
308311
numBlocksToFetch += 1
309312
curRequestSize += size
@@ -314,7 +317,7 @@ final class ShuffleBlockFetcherIterator(
314317
remoteRequests += new FetchRequest(address, curBlocks)
315318
logDebug(s"Creating fetch request of $curRequestSize at $address "
316319
+ s"with ${curBlocks.size} blocks")
317-
curBlocks = new ArrayBuffer[(BlockId, Long)]
320+
curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
318321
curRequestSize = 0
319322
}
320323
}
@@ -340,19 +343,19 @@ final class ShuffleBlockFetcherIterator(
340343
logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
341344
val iter = localBlocks.iterator
342345
while (iter.hasNext) {
343-
val blockId = iter.next()
346+
val (blockId, mapId) = iter.next()
344347
try {
345348
val buf = blockManager.getBlockData(blockId)
346349
shuffleMetrics.incLocalBlocksFetched(1)
347350
shuffleMetrics.incLocalBytesRead(buf.size)
348351
buf.retain()
349-
results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId,
352+
results.put(new SuccessFetchResult(blockId, mapId, blockManager.blockManagerId,
350353
buf.size(), buf, false))
351354
} catch {
352355
case e: Exception =>
353356
// If we see an exception, stop immediately.
354357
logError(s"Error occurred while fetching local blocks", e)
355-
results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
358+
results.put(new FailureFetchResult(blockId, mapId, blockManager.blockManagerId, e))
356359
return
357360
}
358361
}
@@ -412,7 +415,7 @@ final class ShuffleBlockFetcherIterator(
412415
shuffleMetrics.incFetchWaitTime(fetchWaitTime)
413416

414417
result match {
415-
case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
418+
case r @ SuccessFetchResult(blockId, mapId, address, size, buf, isNetworkReqDone) =>
416419
if (address != blockManager.blockManagerId) {
417420
numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
418421
shuffleMetrics.incRemoteBytesRead(buf.size)
@@ -421,7 +424,7 @@ final class ShuffleBlockFetcherIterator(
421424
}
422425
shuffleMetrics.incRemoteBlocksFetched(1)
423426
}
424-
if (!localBlocks.contains(blockId)) {
427+
if (!localBlocks.contains((blockId, mapId))) {
425428
bytesInFlight -= size
426429
}
427430
if (isNetworkReqDone) {
@@ -445,7 +448,7 @@ final class ShuffleBlockFetcherIterator(
445448
// since the last call.
446449
val msg = s"Received a zero-size buffer for block $blockId from $address " +
447450
s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
448-
throwFetchFailedException(blockId, address, new IOException(msg))
451+
throwFetchFailedException(blockId, mapId, address, new IOException(msg))
449452
}
450453

451454
val in = try {
@@ -456,7 +459,7 @@ final class ShuffleBlockFetcherIterator(
456459
assert(buf.isInstanceOf[FileSegmentManagedBuffer])
457460
logError("Failed to create input stream from local block", e)
458461
buf.release()
459-
throwFetchFailedException(blockId, address, e)
462+
throwFetchFailedException(blockId, mapId, address, e)
460463
}
461464
try {
462465
input = streamWrapper(blockId, in)
@@ -474,11 +477,11 @@ final class ShuffleBlockFetcherIterator(
474477
buf.release()
475478
if (buf.isInstanceOf[FileSegmentManagedBuffer]
476479
|| corruptedBlocks.contains(blockId)) {
477-
throwFetchFailedException(blockId, address, e)
480+
throwFetchFailedException(blockId, mapId, address, e)
478481
} else {
479482
logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
480483
corruptedBlocks += blockId
481-
fetchRequests += FetchRequest(address, Array((blockId, size)))
484+
fetchRequests += FetchRequest(address, Array((blockId, size, mapId)))
482485
result = null
483486
}
484487
} finally {
@@ -490,8 +493,8 @@ final class ShuffleBlockFetcherIterator(
490493
}
491494
}
492495

493-
case FailureFetchResult(blockId, address, e) =>
494-
throwFetchFailedException(blockId, address, e)
496+
case FailureFetchResult(blockId, mapId, address, e) =>
497+
throwFetchFailedException(blockId, mapId, address, e)
495498
}
496499

497500
// Send fetch requests up to maxBytesInFlight
@@ -504,6 +507,7 @@ final class ShuffleBlockFetcherIterator(
504507
input,
505508
this,
506509
currentResult.blockId,
510+
currentResult.mapId,
507511
currentResult.address,
508512
detectCorrupt && streamCompressedOrEncrypted))
509513
}
@@ -570,10 +574,11 @@ final class ShuffleBlockFetcherIterator(
570574

571575
private[storage] def throwFetchFailedException(
572576
blockId: BlockId,
577+
mapId: Int,
573578
address: BlockManagerId,
574579
e: Throwable) = {
575580
blockId match {
576-
case ShuffleBlockId(shufId, mapId, reduceId) =>
581+
case ShuffleBlockId(shufId, _, reduceId) =>
577582
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
578583
case _ =>
579584
throw new SparkException(
@@ -591,6 +596,7 @@ private class BufferReleasingInputStream(
591596
private[storage] val delegate: InputStream,
592597
private val iterator: ShuffleBlockFetcherIterator,
593598
private val blockId: BlockId,
599+
private val mapId: Int,
594600
private val address: BlockManagerId,
595601
private val detectCorruption: Boolean)
596602
extends InputStream {
@@ -602,7 +608,7 @@ private class BufferReleasingInputStream(
602608
} catch {
603609
case e: IOException if detectCorruption =>
604610
IOUtils.closeQuietly(this)
605-
iterator.throwFetchFailedException(blockId, address, e)
611+
iterator.throwFetchFailedException(blockId, mapId, address, e)
606612
}
607613
}
608614

@@ -624,7 +630,7 @@ private class BufferReleasingInputStream(
624630
} catch {
625631
case e: IOException if detectCorruption =>
626632
IOUtils.closeQuietly(this)
627-
iterator.throwFetchFailedException(blockId, address, e)
633+
iterator.throwFetchFailedException(blockId, mapId, address, e)
628634
}
629635
}
630636

@@ -636,7 +642,7 @@ private class BufferReleasingInputStream(
636642
} catch {
637643
case e: IOException if detectCorruption =>
638644
IOUtils.closeQuietly(this)
639-
iterator.throwFetchFailedException(blockId, address, e)
645+
iterator.throwFetchFailedException(blockId, mapId, address, e)
640646
}
641647
}
642648

@@ -646,7 +652,7 @@ private class BufferReleasingInputStream(
646652
} catch {
647653
case e: IOException if detectCorruption =>
648654
IOUtils.closeQuietly(this)
649-
iterator.throwFetchFailedException(blockId, address, e)
655+
iterator.throwFetchFailedException(blockId, mapId, address, e)
650656
}
651657
}
652658

@@ -681,9 +687,10 @@ object ShuffleBlockFetcherIterator {
681687
* A request to fetch blocks from a remote BlockManager.
682688
* @param address remote BlockManager to fetch from.
683689
* @param blocks Sequence of tuple, where the first element is the block id,
684-
* and the second element is the estimated size, used to calculate bytesInFlight.
690+
* the second element is the estimated size, used to calculate bytesInFlight,
691+
* and the third element is the mapId.
685692
*/
686-
case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
693+
case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long, Int)]) {
687694
val size = blocks.map(_._2).sum
688695
}
689696

@@ -698,6 +705,7 @@ object ShuffleBlockFetcherIterator {
698705
/**
699706
* Result of a fetch from a remote block successfully.
700707
* @param blockId block id
708+
* @param mapId index in the map stage of the map output this block belongs to
701709
* @param address BlockManager that the block was fetched from.
702710
* @param size estimated size of the block. Note that this is NOT the exact bytes.
703711
* Size of remote block is used to calculate bytesInFlight.
@@ -706,6 +714,7 @@ object ShuffleBlockFetcherIterator {
706714
*/
707715
private[storage] case class SuccessFetchResult(
708716
blockId: BlockId,
717+
mapId: Int,
709718
address: BlockManagerId,
710719
size: Long,
711720
buf: ManagedBuffer,
@@ -717,11 +726,13 @@ object ShuffleBlockFetcherIterator {
717726
/**
718727
* Result of a fetch from a remote block unsuccessfully.
719728
* @param blockId block id
729+
* @param mapId index in the map stage of the map output this block belongs to
720730
* @param address BlockManager that the block was attempted to be fetched from
721731
* @param e the failure exception
722732
*/
723733
private[storage] case class FailureFetchResult(
724734
blockId: BlockId,
735+
mapId: Int,
725736
address: BlockManagerId,
726737
e: Throwable)
727738
extends FetchResult

0 commit comments

Comments
 (0)