Skip to content

Commit cf33a86

Browse files
Davies Liuzsxwing
authored andcommitted
[SPARK-4105] retry the fetch or stage if shuffle block is corrupt
## What changes were proposed in this pull request? There is an outstanding issue that existed for a long time: Sometimes the shuffle blocks are corrupt and can't be decompressed. We recently hit this in three different workloads, sometimes we can reproduce it by every try, sometimes can't. I also found that when the corruption happened, the beginning and end of the blocks are correct, the corruption happen in the middle. There was one case that the string of block id is corrupt by one character. It seems that it's very likely the corruption is introduced by some weird machine/hardware, also the checksum (16 bits) in TCP is not strong enough to identify all the corruption. Unfortunately, Spark does not have checksum for shuffle blocks or broadcast, the job will fail if any corruption happen in the shuffle block from disk, or broadcast blocks during network. This PR try to detect the corruption after fetching shuffle blocks by decompressing them, because most of the compression already have checksum in them. It will retry the block, or failed with FetchFailure, so the previous stage could be retried on different (still random) machines. Checksum for broadcast will be added by another PR. ## How was this patch tested? Added unit tests Author: Davies Liu <[email protected]> Closes #15923 from davies/detect_corrupt.
1 parent d60ab5f commit cf33a86

File tree

4 files changed

+263
-57
lines changed

4 files changed

+263
-57
lines changed

core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,21 @@ private[spark] class BlockStoreShuffleReader[K, C](
4242

4343
/** Read the combined key-values for this reduce task */
4444
override def read(): Iterator[Product2[K, C]] = {
45-
val blockFetcherItr = new ShuffleBlockFetcherIterator(
45+
val wrappedStreams = new ShuffleBlockFetcherIterator(
4646
context,
4747
blockManager.shuffleClient,
4848
blockManager,
4949
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
50+
serializerManager.wrapStream,
5051
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
5152
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
52-
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
53-
54-
// Wrap the streams for compression and encryption based on configuration
55-
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
56-
serializerManager.wrapStream(blockId, inputStream)
57-
}
53+
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
54+
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
5855

5956
val serializerInstance = dep.serializer.newInstance()
6057

6158
// Create a key/value iterator for each stream
62-
val recordIter = wrappedStreams.flatMap { wrappedStream =>
59+
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
6360
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
6461
// NextIterator. The NextIterator makes sure that close() is called on the
6562
// underlying InputStream when all records have been read.

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 94 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,21 @@
1717

1818
package org.apache.spark.storage
1919

20-
import java.io.InputStream
20+
import java.io.{InputStream, IOException}
21+
import java.nio.ByteBuffer
2122
import java.util.concurrent.LinkedBlockingQueue
2223
import javax.annotation.concurrent.GuardedBy
2324

25+
import scala.collection.mutable
2426
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
25-
import scala.util.control.NonFatal
2627

2728
import org.apache.spark.{SparkException, TaskContext}
2829
import org.apache.spark.internal.Logging
29-
import org.apache.spark.network.buffer.ManagedBuffer
30+
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
3031
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
3132
import org.apache.spark.shuffle.FetchFailedException
3233
import org.apache.spark.util.Utils
34+
import org.apache.spark.util.io.ChunkedByteBufferOutputStream
3335

3436
/**
3537
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -47,17 +49,21 @@ import org.apache.spark.util.Utils
4749
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
4850
* For each block we also require the size (in bytes as a long field) in
4951
* order to throttle the memory usage.
52+
* @param streamWrapper A function to wrap the returned input stream.
5053
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
5154
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
55+
* @param detectCorrupt whether to detect any corruption in fetched blocks.
5256
*/
5357
private[spark]
5458
final class ShuffleBlockFetcherIterator(
5559
context: TaskContext,
5660
shuffleClient: ShuffleClient,
5761
blockManager: BlockManager,
5862
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
63+
streamWrapper: (BlockId, InputStream) => InputStream,
5964
maxBytesInFlight: Long,
60-
maxReqsInFlight: Int)
65+
maxReqsInFlight: Int,
66+
detectCorrupt: Boolean)
6167
extends Iterator[(BlockId, InputStream)] with Logging {
6268

6369
import ShuffleBlockFetcherIterator._
@@ -94,7 +100,7 @@ final class ShuffleBlockFetcherIterator(
94100
* Current [[FetchResult]] being processed. We track this so we can release the current buffer
95101
* in case of a runtime exception when processing the current buffer.
96102
*/
97-
@volatile private[this] var currentResult: FetchResult = null
103+
@volatile private[this] var currentResult: SuccessFetchResult = null
98104

99105
/**
100106
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -108,6 +114,12 @@ final class ShuffleBlockFetcherIterator(
108114
/** Current number of requests in flight */
109115
private[this] var reqsInFlight = 0
110116

117+
/**
118+
* The blocks that can't be decompressed successfully, it is used to guarantee that we retry
119+
* at most once for those corrupted blocks.
120+
*/
121+
private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
122+
111123
private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
112124

113125
/**
@@ -123,9 +135,8 @@ final class ShuffleBlockFetcherIterator(
123135
// The currentResult is set to null to prevent releasing the buffer again on cleanup()
124136
private[storage] def releaseCurrentResultBuffer(): Unit = {
125137
// Release the current buffer if necessary
126-
currentResult match {
127-
case SuccessFetchResult(_, _, _, buf, _) => buf.release()
128-
case _ =>
138+
if (currentResult != null) {
139+
currentResult.buf.release()
129140
}
130141
currentResult = null
131142
}
@@ -305,40 +316,84 @@ final class ShuffleBlockFetcherIterator(
305316
*/
306317
override def next(): (BlockId, InputStream) = {
307318
numBlocksProcessed += 1
308-
val startFetchWait = System.currentTimeMillis()
309-
currentResult = results.take()
310-
val result = currentResult
311-
val stopFetchWait = System.currentTimeMillis()
312-
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
313-
314-
result match {
315-
case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
316-
if (address != blockManager.blockManagerId) {
317-
shuffleMetrics.incRemoteBytesRead(buf.size)
318-
shuffleMetrics.incRemoteBlocksFetched(1)
319-
}
320-
bytesInFlight -= size
321-
if (isNetworkReqDone) {
322-
reqsInFlight -= 1
323-
logDebug("Number of requests in flight " + reqsInFlight)
324-
}
325-
case _ =>
326-
}
327-
// Send fetch requests up to maxBytesInFlight
328-
fetchUpToMaxBytes()
329319

330-
result match {
331-
case FailureFetchResult(blockId, address, e) =>
332-
throwFetchFailedException(blockId, address, e)
320+
var result: FetchResult = null
321+
var input: InputStream = null
322+
// Take the next fetched result and try to decompress it to detect data corruption,
323+
// then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
324+
// is also corrupt, so the previous stage could be retried.
325+
// For local shuffle block, throw FailureFetchResult for the first IOException.
326+
while (result == null) {
327+
val startFetchWait = System.currentTimeMillis()
328+
result = results.take()
329+
val stopFetchWait = System.currentTimeMillis()
330+
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
333331

334-
case SuccessFetchResult(blockId, address, _, buf, _) =>
335-
try {
336-
(result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
337-
} catch {
338-
case NonFatal(t) =>
339-
throwFetchFailedException(blockId, address, t)
340-
}
332+
result match {
333+
case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
334+
if (address != blockManager.blockManagerId) {
335+
shuffleMetrics.incRemoteBytesRead(buf.size)
336+
shuffleMetrics.incRemoteBlocksFetched(1)
337+
}
338+
bytesInFlight -= size
339+
if (isNetworkReqDone) {
340+
reqsInFlight -= 1
341+
logDebug("Number of requests in flight " + reqsInFlight)
342+
}
343+
344+
val in = try {
345+
buf.createInputStream()
346+
} catch {
347+
// The exception could only be throwed by local shuffle block
348+
case e: IOException =>
349+
assert(buf.isInstanceOf[FileSegmentManagedBuffer])
350+
logError("Failed to create input stream from local block", e)
351+
buf.release()
352+
throwFetchFailedException(blockId, address, e)
353+
}
354+
355+
input = streamWrapper(blockId, in)
356+
// Only copy the stream if it's wrapped by compression or encryption, also the size of
357+
// block is small (the decompressed block is smaller than maxBytesInFlight)
358+
if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
359+
val originalInput = input
360+
val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
361+
try {
362+
// Decompress the whole block at once to detect any corruption, which could increase
363+
// the memory usage tne potential increase the chance of OOM.
364+
// TODO: manage the memory used here, and spill it into disk in case of OOM.
365+
Utils.copyStream(input, out)
366+
out.close()
367+
input = out.toChunkedByteBuffer.toInputStream(dispose = true)
368+
} catch {
369+
case e: IOException =>
370+
buf.release()
371+
if (buf.isInstanceOf[FileSegmentManagedBuffer]
372+
|| corruptedBlocks.contains(blockId)) {
373+
throwFetchFailedException(blockId, address, e)
374+
} else {
375+
logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
376+
corruptedBlocks += blockId
377+
fetchRequests += FetchRequest(address, Array((blockId, size)))
378+
result = null
379+
}
380+
} finally {
381+
// TODO: release the buf here to free memory earlier
382+
originalInput.close()
383+
in.close()
384+
}
385+
}
386+
387+
case FailureFetchResult(blockId, address, e) =>
388+
throwFetchFailedException(blockId, address, e)
389+
}
390+
391+
// Send fetch requests up to maxBytesInFlight
392+
fetchUpToMaxBytes()
341393
}
394+
395+
currentResult = result.asInstanceOf[SuccessFetchResult]
396+
(currentResult.blockId, new BufferReleasingInputStream(input, this))
342397
}
343398

344399
private def fetchUpToMaxBytes(): Unit = {

core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
151151
* @param dispose if true, `ChunkedByteBuffer.dispose()` will be called at the end of the stream
152152
* in order to close any memory-mapped files which back the buffer.
153153
*/
154-
private class ChunkedByteBufferInputStream(
154+
private[spark] class ChunkedByteBufferInputStream(
155155
var chunkedByteBuffer: ChunkedByteBuffer,
156156
dispose: Boolean)
157157
extends InputStream {

0 commit comments

Comments
 (0)