 
 package org.apache.spark.storage
 
-import java.io.InputStream
+import java.io.{InputStream, IOException}
+import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBufferOutputStream
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -47,17 +49,21 @@ import org.apache.spark.util.Utils
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
  */
 private[spark]
 final class ShuffleBlockFetcherIterator(
     context: TaskContext,
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
-    maxReqsInFlight: Int)
+    maxReqsInFlight: Int,
+    detectCorrupt: Boolean)
   extends Iterator[(BlockId, InputStream)] with Logging {
 
   import ShuffleBlockFetcherIterator._
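
For context on the two new parameters: `streamWrapper` is supplied by the shuffle reader and typically applies decryption/decompression. A minimal sketch of a call site under that assumption (`serializerManager.wrapStream` and `blockManager.shuffleClient` are assumptions for illustration, not part of this diff):

```scala
// Hypothetical call site for the new signature; names other than the
// constructor parameters are assumptions for illustration.
val iterator = new ShuffleBlockFetcherIterator(
  context,
  blockManager.shuffleClient,
  blockManager,
  blocksByAddress,
  (blockId, in) => serializerManager.wrapStream(blockId, in), // streamWrapper
  maxBytesInFlight = 48 * 1024 * 1024,
  maxReqsInFlight = Int.MaxValue,
  detectCorrupt = true)
```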
@@ -94,7 +100,7 @@ final class ShuffleBlockFetcherIterator(
    * Current [[FetchResult]] being processed. We track this so we can release the current buffer
    * in case of a runtime exception when processing the current buffer.
    */
-  @volatile private[this] var currentResult: FetchResult = null
+  @volatile private[this] var currentResult: SuccessFetchResult = null
 
   /**
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -108,6 +114,12 @@ final class ShuffleBlockFetcherIterator(
   /** Current number of requests in flight */
   private[this] var reqsInFlight = 0
 
+  /**
+   * The blocks that can't be decompressed successfully; this is used to guarantee that we
+   * retry at most once for each corrupted block.
+   */
+  private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
   private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
 
   /**
@@ -123,9 +135,8 @@ final class ShuffleBlockFetcherIterator(
   // The currentResult is set to null to prevent releasing the buffer again on cleanup()
   private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
-    currentResult match {
-      case SuccessFetchResult(_, _, _, buf, _) => buf.release()
-      case _ =>
+    if (currentResult != null) {
+      currentResult.buf.release()
     }
     currentResult = null
   }
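
The simplified null check above relies on `currentResult` now being typed as `SuccessFetchResult`, so `buf` can be accessed directly instead of pattern matching. For reference, a minimal model of the result types matched in this diff, reconstructed from how their fields are used here (the actual definitions live elsewhere in this file):

```scala
// Reconstructed sketch; field order follows the pattern matches in this diff.
private[storage] sealed trait FetchResult {
  val blockId: BlockId
  val address: BlockManagerId
}

private[storage] case class SuccessFetchResult(
    blockId: BlockId,
    address: BlockManagerId,
    size: Long,
    buf: ManagedBuffer,
    isNetworkReqDone: Boolean) extends FetchResult

private[storage] case class FailureFetchResult(
    blockId: BlockId,
    address: BlockManagerId,
    e: Throwable) extends FetchResult
```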
@@ -305,40 +316,84 @@ final class ShuffleBlockFetcherIterator(
    */
   override def next(): (BlockId, InputStream) = {
     numBlocksProcessed += 1
-    val startFetchWait = System.currentTimeMillis()
-    currentResult = results.take()
-    val result = currentResult
-    val stopFetchWait = System.currentTimeMillis()
-    shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
-
-    result match {
-      case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
-        if (address != blockManager.blockManagerId) {
-          shuffleMetrics.incRemoteBytesRead(buf.size)
-          shuffleMetrics.incRemoteBlocksFetched(1)
-        }
-        bytesInFlight -= size
-        if (isNetworkReqDone) {
-          reqsInFlight -= 1
-          logDebug("Number of requests in flight " + reqsInFlight)
-        }
-      case _ =>
-    }
-    // Send fetch requests up to maxBytesInFlight
-    fetchUpToMaxBytes()
 
-    result match {
-      case FailureFetchResult(blockId, address, e) =>
-        throwFetchFailedException(blockId, address, e)
+    var result: FetchResult = null
+    var input: InputStream = null
+    // Take the next fetched result and try to decompress it to detect data corruption,
+    // then fetch it one more time if it's corrupt; throw a FetchFailedException if the second
+    // fetch is also corrupt, so the previous stage can be retried.
+    // For a local shuffle block, throw a FetchFailedException on the first IOException.
+    while (result == null) {
+      val startFetchWait = System.currentTimeMillis()
+      result = results.take()
+      val stopFetchWait = System.currentTimeMillis()
+      shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
 
-      case SuccessFetchResult(blockId, address, _, buf, _) =>
-        try {
-          (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
-        } catch {
-          case NonFatal(t) =>
-            throwFetchFailedException(blockId, address, t)
-        }
+      result match {
+        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+          if (address != blockManager.blockManagerId) {
+            shuffleMetrics.incRemoteBytesRead(buf.size)
+            shuffleMetrics.incRemoteBlocksFetched(1)
+          }
+          bytesInFlight -= size
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+
+          val in = try {
+            buf.createInputStream()
+          } catch {
+            // The exception can only be thrown by a local shuffle block
+            case e: IOException =>
+              assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+              logError("Failed to create input stream from local block", e)
+              buf.release()
+              throwFetchFailedException(blockId, address, e)
+          }
+
+          input = streamWrapper(blockId, in)
+          // Only copy the stream if it's wrapped by compression or encryption, and the
+          // block is small (the decompressed block is smaller than maxBytesInFlight)
+          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+            val originalInput = input
+            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+            try {
+              // Decompress the whole block at once to detect any corruption, which could
+              // increase the memory usage and potentially increase the chance of OOM.
+              // TODO: manage the memory used here, and spill it into disk in case of OOM.
+              Utils.copyStream(input, out)
+              out.close()
+              input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+            } catch {
+              case e: IOException =>
+                buf.release()
+                if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                    || corruptedBlocks.contains(blockId)) {
+                  throwFetchFailedException(blockId, address, e)
+                } else {
+                  logWarning(s"got a corrupted block $blockId from $address, fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(address, Array((blockId, size)))
+                  result = null
+                }
+            } finally {
+              // TODO: release the buf here to free memory earlier
+              originalInput.close()
+              in.close()
+            }
+          }
+
+        case FailureFetchResult(blockId, address, e) =>
+          throwFetchFailedException(blockId, address, e)
+      }
+
+      // Send fetch requests up to maxBytesInFlight
+      fetchUpToMaxBytes()
     }
+
+    currentResult = result.asInstanceOf[SuccessFetchResult]
+    (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }
 
   private def fetchUpToMaxBytes(): Unit = {
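
The new `next()` loop is easier to follow stripped of the metrics and buffer bookkeeping. A self-contained sketch of the same retry-at-most-once idea, with plain JDK streams and GZIP standing in for Spark's managed buffers and stream wrapper (all names here are illustrative, not Spark APIs):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, InputStream}
import java.util.zip.GZIPInputStream
import scala.collection.mutable

object RetryOnceSketch {
  private val corrupted = mutable.HashSet[String]()

  // Fully drain a stream; with a decompressing wrapper, an IOException here
  // signals a corrupted payload, mirroring the copy-to-detect check in next().
  private def readFully(in: InputStream): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val buf = new Array[Byte](64 * 1024)
    var n = in.read(buf)
    while (n != -1) { out.write(buf, 0, n); n = in.read(buf) }
    out.toByteArray
  }

  // fetch() stands in for re-enqueueing a FetchRequest for the block.
  def nextBlock(blockId: String, fetch: () => Array[Byte]): Array[Byte] = {
    try {
      readFully(new GZIPInputStream(new ByteArrayInputStream(fetch())))
    } catch {
      case _: IOException if !corrupted.contains(blockId) =>
        corrupted += blockId // retry at most once per block
        readFully(new GZIPInputStream(new ByteArrayInputStream(fetch())))
    }
  }
}
```

If the retried fetch is corrupt as well, the IOException propagates, which corresponds to throwFetchFailedException marking the previous stage for retry.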