@@ -48,9 +48,10 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
4848 * @param shuffleClient [[BlockStoreClient ]] for fetching remote blocks
4949 * @param blockManager [[BlockManager ]] for reading local blocks
5050 * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId ]].
51- * For each block we also require the size (in bytes as a long field) in
52- * order to throttle the memory usage. Note that zero-sized blocks are
53- * already excluded, which happened in
51+ * For each block we also require two pieces of information: 1. the size (in
52+ * bytes as a long field) in order to throttle the memory usage; 2. the mapId for
53+ * this block, which indicates the index in the map stage of the block.
54+ * Note that zero-sized blocks are already excluded, which happened in
5455 * [[org.apache.spark.MapOutputTracker.convertMapStatuses ]].
5556 * @param streamWrapper A function to wrap the returned input stream.
5657 * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
@@ -66,7 +67,7 @@ final class ShuffleBlockFetcherIterator(
6667 context : TaskContext ,
6768 shuffleClient : BlockStoreClient ,
6869 blockManager : BlockManager ,
69- blocksByAddress : Iterator [(BlockManagerId , Seq [(BlockId , Long )])],
70+ blocksByAddress : Iterator [(BlockManagerId , Seq [(BlockId , Long , Int )])],
7071 streamWrapper : (BlockId , InputStream ) => InputStream ,
7172 maxBytesInFlight : Long ,
7273 maxReqsInFlight : Int ,
@@ -96,7 +97,7 @@ final class ShuffleBlockFetcherIterator(
9697 private [this ] val startTimeNs = System .nanoTime()
9798
9899 /** Local blocks to fetch, excluding zero-sized blocks. */
99- private [this ] val localBlocks = scala.collection.mutable.LinkedHashSet [BlockId ]()
100+ private [this ] val localBlocks = scala.collection.mutable.LinkedHashSet [( BlockId , Int ) ]()
100101
101102 /** Remote blocks to fetch, excluding zero-sized blocks. */
102103 private [this ] val remoteBlocks = new HashSet [BlockId ]()
@@ -198,7 +199,7 @@ final class ShuffleBlockFetcherIterator(
198199 while (iter.hasNext) {
199200 val result = iter.next()
200201 result match {
201- case SuccessFetchResult (_, address, _, buf, _) =>
202+ case SuccessFetchResult (_, _, address, _, buf, _) =>
202203 if (address != blockManager.blockManagerId) {
203204 shuffleMetrics.incRemoteBytesRead(buf.size)
204205 if (buf.isInstanceOf [FileSegmentManagedBuffer ]) {
@@ -223,9 +224,11 @@ final class ShuffleBlockFetcherIterator(
223224 bytesInFlight += req.size
224225 reqsInFlight += 1
225226
226- // so we can look up the size of each blockID
227- val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
228- val remainingBlocks = new HashSet [String ]() ++= sizeMap.keys
227+ // so we can look up the block info of each blockID
228+ val infoMap = req.blocks.map {
229+ case (blockId, size, mapId) => (blockId.toString, (size, mapId))
230+ }.toMap
231+ val remainingBlocks = new HashSet [String ]() ++= infoMap.keys
229232 val blockIds = req.blocks.map(_._1.toString)
230233 val address = req.address
231234
@@ -239,8 +242,8 @@ final class ShuffleBlockFetcherIterator(
239242 // This needs to be released after use.
240243 buf.retain()
241244 remainingBlocks -= blockId
242- results.put(new SuccessFetchResult (BlockId (blockId), address, sizeMap (blockId), buf ,
243- remainingBlocks.isEmpty))
245+ results.put(new SuccessFetchResult (BlockId (blockId), infoMap (blockId)._2 ,
246+ address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
244247 logDebug(" remainingBlocks: " + remainingBlocks)
245248 }
246249 }
@@ -249,7 +252,7 @@ final class ShuffleBlockFetcherIterator(
249252
250253 override def onBlockFetchFailure (blockId : String , e : Throwable ): Unit = {
251254 logError(s " Failed to get block(s) from ${req.address.host}: ${req.address.port}" , e)
252- results.put(new FailureFetchResult (BlockId (blockId), address, e))
255+ results.put(new FailureFetchResult (BlockId (blockId), infoMap(blockId)._2, address, e))
253256 }
254257 }
255258
@@ -282,28 +285,28 @@ final class ShuffleBlockFetcherIterator(
282285 for ((address, blockInfos) <- blocksByAddress) {
283286 if (address.executorId == blockManager.blockManagerId.executorId) {
284287 blockInfos.find(_._2 <= 0 ) match {
285- case Some ((blockId, size)) if size < 0 =>
288+ case Some ((blockId, size, _ )) if size < 0 =>
286289 throw new BlockException (blockId, " Negative block size " + size)
287- case Some ((blockId, size)) if size == 0 =>
290+ case Some ((blockId, size, _ )) if size == 0 =>
288291 throw new BlockException (blockId, " Zero-sized blocks should be excluded." )
289292 case None => // do nothing.
290293 }
291- localBlocks ++= blockInfos.map(_ ._1)
294+ localBlocks ++= blockInfos.map(info => (info ._1, info._3) )
292295 localBlockBytes += blockInfos.map(_._2).sum
293296 numBlocksToFetch += localBlocks.size
294297 } else {
295298 val iterator = blockInfos.iterator
296299 var curRequestSize = 0L
297- var curBlocks = new ArrayBuffer [(BlockId , Long )]
300+ var curBlocks = new ArrayBuffer [(BlockId , Long , Int )]
298301 while (iterator.hasNext) {
299- val (blockId, size) = iterator.next()
302+ val (blockId, size, mapId ) = iterator.next()
300303 remoteBlockBytes += size
301304 if (size < 0 ) {
302305 throw new BlockException (blockId, " Negative block size " + size)
303306 } else if (size == 0 ) {
304307 throw new BlockException (blockId, " Zero-sized blocks should be excluded." )
305308 } else {
306- curBlocks += ((blockId, size))
309+ curBlocks += ((blockId, size, mapId ))
307310 remoteBlocks += blockId
308311 numBlocksToFetch += 1
309312 curRequestSize += size
@@ -314,7 +317,7 @@ final class ShuffleBlockFetcherIterator(
314317 remoteRequests += new FetchRequest (address, curBlocks)
315318 logDebug(s " Creating fetch request of $curRequestSize at $address "
316319 + s " with ${curBlocks.size} blocks " )
317- curBlocks = new ArrayBuffer [(BlockId , Long )]
320+ curBlocks = new ArrayBuffer [(BlockId , Long , Int )]
318321 curRequestSize = 0
319322 }
320323 }
@@ -340,19 +343,19 @@ final class ShuffleBlockFetcherIterator(
340343 logDebug(s " Start fetching local blocks: ${localBlocks.mkString(" , " )}" )
341344 val iter = localBlocks.iterator
342345 while (iter.hasNext) {
343- val blockId = iter.next()
346+ val ( blockId, mapId) = iter.next()
344347 try {
345348 val buf = blockManager.getBlockData(blockId)
346349 shuffleMetrics.incLocalBlocksFetched(1 )
347350 shuffleMetrics.incLocalBytesRead(buf.size)
348351 buf.retain()
349- results.put(new SuccessFetchResult (blockId, blockManager.blockManagerId,
352+ results.put(new SuccessFetchResult (blockId, mapId, blockManager.blockManagerId,
350353 buf.size(), buf, false ))
351354 } catch {
352355 case e : Exception =>
353356 // If we see an exception, stop immediately.
354357 logError(s " Error occurred while fetching local blocks " , e)
355- results.put(new FailureFetchResult (blockId, blockManager.blockManagerId, e))
358+ results.put(new FailureFetchResult (blockId, mapId, blockManager.blockManagerId, e))
356359 return
357360 }
358361 }
@@ -412,7 +415,7 @@ final class ShuffleBlockFetcherIterator(
412415 shuffleMetrics.incFetchWaitTime(fetchWaitTime)
413416
414417 result match {
415- case r @ SuccessFetchResult (blockId, address, size, buf, isNetworkReqDone) =>
418+ case r @ SuccessFetchResult (blockId, mapId, address, size, buf, isNetworkReqDone) =>
416419 if (address != blockManager.blockManagerId) {
417420 numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
418421 shuffleMetrics.incRemoteBytesRead(buf.size)
@@ -421,7 +424,7 @@ final class ShuffleBlockFetcherIterator(
421424 }
422425 shuffleMetrics.incRemoteBlocksFetched(1 )
423426 }
424- if (! localBlocks.contains(blockId)) {
427+ if (! localBlocks.contains(( blockId, mapId) )) {
425428 bytesInFlight -= size
426429 }
427430 if (isNetworkReqDone) {
@@ -445,7 +448,7 @@ final class ShuffleBlockFetcherIterator(
445448 // since the last call.
446449 val msg = s " Received a zero-size buffer for block $blockId from $address " +
447450 s " (expectedApproxSize = $size, isNetworkReqDone= $isNetworkReqDone) "
448- throwFetchFailedException(blockId, address, new IOException (msg))
451+ throwFetchFailedException(blockId, mapId, address, new IOException (msg))
449452 }
450453
451454 val in = try {
@@ -456,7 +459,7 @@ final class ShuffleBlockFetcherIterator(
456459 assert(buf.isInstanceOf [FileSegmentManagedBuffer ])
457460 logError(" Failed to create input stream from local block" , e)
458461 buf.release()
459- throwFetchFailedException(blockId, address, e)
462+ throwFetchFailedException(blockId, mapId, address, e)
460463 }
461464 try {
462465 input = streamWrapper(blockId, in)
@@ -474,11 +477,11 @@ final class ShuffleBlockFetcherIterator(
474477 buf.release()
475478 if (buf.isInstanceOf [FileSegmentManagedBuffer ]
476479 || corruptedBlocks.contains(blockId)) {
477- throwFetchFailedException(blockId, address, e)
480+ throwFetchFailedException(blockId, mapId, address, e)
478481 } else {
479482 logWarning(s " got an corrupted block $blockId from $address, fetch again " , e)
480483 corruptedBlocks += blockId
481- fetchRequests += FetchRequest (address, Array ((blockId, size)))
484+ fetchRequests += FetchRequest (address, Array ((blockId, size, mapId )))
482485 result = null
483486 }
484487 } finally {
@@ -490,8 +493,8 @@ final class ShuffleBlockFetcherIterator(
490493 }
491494 }
492495
493- case FailureFetchResult (blockId, address, e) =>
494- throwFetchFailedException(blockId, address, e)
496+ case FailureFetchResult (blockId, mapId, address, e) =>
497+ throwFetchFailedException(blockId, mapId, address, e)
495498 }
496499
497500 // Send fetch requests up to maxBytesInFlight
@@ -504,6 +507,7 @@ final class ShuffleBlockFetcherIterator(
504507 input,
505508 this ,
506509 currentResult.blockId,
510+ currentResult.mapId,
507511 currentResult.address,
508512 detectCorrupt && streamCompressedOrEncrypted))
509513 }
@@ -570,10 +574,11 @@ final class ShuffleBlockFetcherIterator(
570574
571575 private [storage] def throwFetchFailedException (
572576 blockId : BlockId ,
577+ mapId : Int ,
573578 address : BlockManagerId ,
574579 e : Throwable ) = {
575580 blockId match {
576- case ShuffleBlockId (shufId, mapId , reduceId) =>
581+ case ShuffleBlockId (shufId, _ , reduceId) =>
577582 throw new FetchFailedException (address, shufId.toInt, mapId.toInt, reduceId, e)
578583 case _ =>
579584 throw new SparkException (
@@ -591,6 +596,7 @@ private class BufferReleasingInputStream(
591596 private [storage] val delegate : InputStream ,
592597 private val iterator : ShuffleBlockFetcherIterator ,
593598 private val blockId : BlockId ,
599+ private val mapId : Int ,
594600 private val address : BlockManagerId ,
595601 private val detectCorruption : Boolean )
596602 extends InputStream {
@@ -602,7 +608,7 @@ private class BufferReleasingInputStream(
602608 } catch {
603609 case e : IOException if detectCorruption =>
604610 IOUtils .closeQuietly(this )
605- iterator.throwFetchFailedException(blockId, address, e)
611+ iterator.throwFetchFailedException(blockId, mapId, address, e)
606612 }
607613 }
608614
@@ -624,7 +630,7 @@ private class BufferReleasingInputStream(
624630 } catch {
625631 case e : IOException if detectCorruption =>
626632 IOUtils .closeQuietly(this )
627- iterator.throwFetchFailedException(blockId, address, e)
633+ iterator.throwFetchFailedException(blockId, mapId, address, e)
628634 }
629635 }
630636
@@ -636,7 +642,7 @@ private class BufferReleasingInputStream(
636642 } catch {
637643 case e : IOException if detectCorruption =>
638644 IOUtils .closeQuietly(this )
639- iterator.throwFetchFailedException(blockId, address, e)
645+ iterator.throwFetchFailedException(blockId, mapId, address, e)
640646 }
641647 }
642648
@@ -646,7 +652,7 @@ private class BufferReleasingInputStream(
646652 } catch {
647653 case e : IOException if detectCorruption =>
648654 IOUtils .closeQuietly(this )
649- iterator.throwFetchFailedException(blockId, address, e)
655+ iterator.throwFetchFailedException(blockId, mapId, address, e)
650656 }
651657 }
652658
@@ -681,9 +687,10 @@ object ShuffleBlockFetcherIterator {
681687 * A request to fetch blocks from a remote BlockManager.
682688 * @param address remote BlockManager to fetch from.
683689 * @param blocks Sequence of tuple, where the first element is the block id,
684- * and the second element is the estimated size, used to calculate bytesInFlight.
690+ * and the second element is the estimated size, used to calculate bytesInFlight,
691+ * and the third element is the mapId.
685692 */
686- case class FetchRequest (address : BlockManagerId , blocks : Seq [(BlockId , Long )]) {
693+ case class FetchRequest (address : BlockManagerId , blocks : Seq [(BlockId , Long , Int )]) {
687694 val size = blocks.map(_._2).sum
688695 }
689696
@@ -698,6 +705,7 @@ object ShuffleBlockFetcherIterator {
698705 /**
699706 * Result of a fetch from a remote block successfully.
700707 * @param blockId block id
708+ * @param mapId mapId for this block
701709 * @param address BlockManager that the block was fetched from.
702710 * @param size estimated size of the block. Note that this is NOT the exact bytes.
703711 * Size of remote block is used to calculate bytesInFlight.
@@ -706,6 +714,7 @@ object ShuffleBlockFetcherIterator {
706714 */
707715 private [storage] case class SuccessFetchResult (
708716 blockId : BlockId ,
717+ mapId : Int ,
709718 address : BlockManagerId ,
710719 size : Long ,
711720 buf : ManagedBuffer ,
@@ -717,11 +726,13 @@ object ShuffleBlockFetcherIterator {
717726 /**
718727 * Result of a fetch from a remote block unsuccessfully.
719728 * @param blockId block id
729+ * @param mapId mapId for this block
720730 * @param address BlockManager that the block was attempted to be fetched from
721731 * @param e the failure exception
722732 */
723733 private [storage] case class FailureFetchResult (
724734 blockId : BlockId ,
735+ mapId : Int ,
725736 address : BlockManagerId ,
726737 e : Throwable )
727738 extends FetchResult
0 commit comments