Commit c85a216
Davies Liu committed: add tests
1 parent 5c93aaf

File tree: 3 files changed (+144, -62 lines)


core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
Lines changed: 2 additions & 1 deletion

@@ -50,7 +50,8 @@ private[spark] class BlockStoreShuffleReader[K, C](
       serializerManager.wrapStream,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
-      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
+      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+      SparkEnv.get.conf.getBoolean("spark.shuffle.tryDecompress", true))

     val serializerInstance = dep.serializer.newInstance()
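The new constructor argument is read from the `spark.shuffle.tryDecompress` setting with a default of `true`. As a sketch of how a user could disable the new corruption check, assuming only standard SparkConf usage (the app name below is illustrative):

import org.apache.spark.{SparkConf, SparkContext}

// Disable eager decompression of fetched shuffle blocks, trading early
// corruption detection for lower memory usage during shuffle reads.
val conf = new SparkConf()
  .setAppName("shuffle-decompress-demo") // illustrative name
  .set("spark.shuffle.tryDecompress", "false")
val sc = new SparkContext(conf)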

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
Lines changed: 47 additions & 52 deletions

@@ -61,7 +61,8 @@ final class ShuffleBlockFetcherIterator(
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
     streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
-    maxReqsInFlight: Int)
+    maxReqsInFlight: Int,
+    detectCorrupt: Boolean)
   extends Iterator[(BlockId, InputStream)] with Logging {

   import ShuffleBlockFetcherIterator._

@@ -98,7 +99,7 @@ final class ShuffleBlockFetcherIterator(
    * Current [[FetchResult]] being processed. We track this so we can release the current buffer
    * in case of a runtime exception when processing the current buffer.
    */
-  @volatile private[this] var currentResult: FetchResult = null
+  @volatile private[this] var currentResult: SuccessFetchResult = null

   /**
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that

@@ -113,7 +114,7 @@ final class ShuffleBlockFetcherIterator(
   private[this] var reqsInFlight = 0

   /** The blocks that can't be decompressed successfully */
-  private[this] val corruptedBlocks = mutable.HashSet[String]()
+  private[this] val corruptedBlocks = mutable.HashSet[BlockId]()

   private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()

@@ -130,9 +131,8 @@ final class ShuffleBlockFetcherIterator(
   // The currentResult is set to null to prevent releasing the buffer again on cleanup()
   private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
-    currentResult match {
-      case SuccessFetchResult(_, _, _, buf, _) => buf.release()
-      case _ =>
+    if (currentResult != null) {
+      currentResult.buf.release()
     }
     currentResult = null
   }

@@ -315,14 +315,18 @@ final class ShuffleBlockFetcherIterator(

     var result: FetchResult = null
     var input: InputStream = null
+    // Take the next fetched result and try to decompress it to detect data corruption,
+    // then fetch it one more time if it's corrupt; throw a FetchFailedException if the second
+    // fetch is also corrupt, so the previous stage can be retried.
+    // For a local shuffle block, throw a FetchFailedException on the first IOException.
     while (result == null) {
       val startFetchWait = System.currentTimeMillis()
       result = results.take()
       val stopFetchWait = System.currentTimeMillis()
       shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)

       result match {
-        case SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
             shuffleMetrics.incRemoteBytesRead(buf.size)
             shuffleMetrics.incRemoteBlocksFetched(1)

@@ -337,62 +341,53 @@ final class ShuffleBlockFetcherIterator(
             buf.createInputStream()
           } catch {
             // The exception can only be thrown by a local shuffle block
-            case e: IOException if buf.isInstanceOf[FileSegmentManagedBuffer] =>
+            case e: IOException =>
+              assert(buf.isInstanceOf[FileSegmentManagedBuffer])
               logError("Failed to create input stream from local block", e)
               buf.release()
-              result = FailureFetchResult(blockId, address, e)
-              null
+              throwFetchFailedException(blockId, address, e)
           }
-          if (in != null) {
-            input = streamWrapper(blockId, in)
-            // Only copy the stream if it's wrapped by compression or encryption, and the size
-            // of the block is small (the decompressed block is smaller than maxBytesInFlight)
-            if (!input.eq(in) && size < maxBytesInFlight / 3) {
-              val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
-              try {
-                // Decompress the whole block at once to detect any corruption, which could
-                // increase the memory usage and potentially increase the chance of OOM.
-                // TODO: manage the memory used here, and spill it into disk in case of OOM.
-                Utils.copyStream(input, out)
-                input = out.toChunkedByteBuffer.toInputStream(true)
-              } catch {
-                case e: IOException =>
-                  buf.release()
-                  if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                    || corruptedBlocks.contains(blockId.toString)) {
-                    result = FailureFetchResult(blockId, address, e)
-                  } else {
-                    logWarning(s"got a corrupted block $blockId from $address, fetch again")
-                    fetchRequests += FetchRequest(address, Array((blockId, size)))
-                    result = null
-                  }
-              } finally {
-                // TODO: release the buf here (earlier)
-                in.close()
-              }
+
+          input = streamWrapper(blockId, in)
+          // Only copy the stream if it's wrapped by compression or encryption, and the size
+          // of the block is small (the decompressed block is smaller than maxBytesInFlight)
+          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+            try {
+              // Decompress the whole block at once to detect any corruption, which could
+              // increase the memory usage and potentially increase the chance of OOM.
+              // TODO: manage the memory used here, and spill it into disk in case of OOM.
+              Utils.copyStream(input, out)
+              out.close()
+              input = out.toChunkedByteBuffer.toInputStream(true)
+            } catch {
+              case e: IOException =>
+                buf.release()
+                if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                  || corruptedBlocks.contains(blockId)) {
+                  throwFetchFailedException(blockId, address, e)
+                } else {
+                  logWarning(s"got a corrupted block $blockId from $address, fetch again")
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(address, Array((blockId, size)))
+                  result = null
+                }
+            } finally {
+              // TODO: release the buf here to free memory earlier
+              in.close()
            }
           }

-        case _ =>
+        case FailureFetchResult(blockId, address, e) =>
+          throwFetchFailedException(blockId, address, e)
       }

       // Send fetch requests up to maxBytesInFlight
       fetchUpToMaxBytes()
     }
-    currentResult = result
-
-    result match {
-      case FailureFetchResult(blockId, address, e) =>
-        throwFetchFailedException(blockId, address, e)
-
-      case SuccessFetchResult(blockId, address, _, buf, _) =>
-        try {
-          (result.blockId, new BufferReleasingInputStream(input, this))
-        } catch {
-          case NonFatal(t) =>
-            throwFetchFailedException(blockId, address, t)
-        }
-    }
+
+    currentResult = result.asInstanceOf[SuccessFetchResult]
+    (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }

   private def fetchUpToMaxBytes(): Unit = {
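The heart of the change above: for a remote block whose stream is wrapped by compression or encryption and whose size is modest, the whole block is decompressed once up front, so corruption surfaces as an IOException while the block can still be refetched, instead of later during deserialization. A minimal standalone sketch of that detect-and-retry pattern, using only java.io (readOrRetry, canRetry, and fetchAgain are illustrative names, not Spark API):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, InputStream}

// Drain a wrapped (e.g. decompressing) stream into memory so corrupt input
// fails here, where the fetch can still be retried, rather than later in
// record deserialization.
def readOrRetry(
    wrapped: InputStream,
    canRetry: Boolean,
    fetchAgain: () => Unit): Option[InputStream] = {
  val out = new ByteArrayOutputStream()
  try {
    val buffer = new Array[Byte](64 * 1024)
    var n = wrapped.read(buffer)
    while (n != -1) {
      out.write(buffer, 0, n)
      n = wrapped.read(buffer)
    }
    // The whole block decompressed cleanly; hand back an in-memory copy.
    Some(new ByteArrayInputStream(out.toByteArray))
  } catch {
    case e: IOException if canRetry =>
      fetchAgain() // first corruption for a remote block: re-request it once
      None
  } finally {
    wrapped.close()
  }
}

When canRetry is false (a local block, or a block that already failed once), the IOException propagates instead, which is what the iterator above converts into a FetchFailedException so the previous map stage can be rerun.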

core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
Lines changed: 95 additions & 9 deletions

@@ -17,7 +17,7 @@

 package org.apache.spark.storage

-import java.io.InputStream
+import java.io.{File, InputStream, IOException}
 import java.util.concurrent.Semaphore

 import scala.concurrent.ExecutionContext.Implicits.global

@@ -31,8 +31,9 @@ import org.scalatest.PrivateMethodTester

 import org.apache.spark.{SparkFunSuite, TaskContext}
 import org.apache.spark.network._
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException

@@ -63,7 +64,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester
   // Create a mock managed buffer for testing
   def createMockManagedBuffer(): ManagedBuffer = {
     val mockManagedBuffer = mock(classOf[ManagedBuffer])
-    when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream]))
+    val in = mock(classOf[InputStream])
+    when(in.read(any())).thenReturn(1)
+    when(in.read(any(), any(), any())).thenReturn(1)
+    when(mockManagedBuffer.createInputStream()).thenReturn(in)
     mockManagedBuffer
   }

@@ -101,7 +105,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester
       blocksByAddress,
       (_, in) => in,
       48 * 1024 * 1024,
-      Int.MaxValue)
+      Int.MaxValue,
+      true)

     // 3 local blocks fetched in initialization
     verify(blockManager, times(3)).getBlockData(any())

@@ -175,7 +180,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester
       blocksByAddress,
       (_, in) => in,
       48 * 1024 * 1024,
-      Int.MaxValue)
+      Int.MaxValue,
+      true)

     verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
     iterator.next()._2.close() // close() the first block's input stream

@@ -203,9 +209,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
     )

     // Semaphore to coordinate event sequence in two different threads.

@@ -239,7 +245,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester
       blocksByAddress,
       (_, in) => in,
       48 * 1024 * 1024,
-      Int.MaxValue)
+      Int.MaxValue,
+      true)

     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()

@@ -250,4 +257,83 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester
     intercept[FetchFailedException] { iterator.next() }
     intercept[FetchFailedException] { iterator.next() }
   }
+
+  test("retry corrupt blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+    )
+
+    // Semaphore to coordinate event sequence in two different threads.
+    val sem = new Semaphore(0)
+
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    val corruptBuffer = mock(classOf[ManagedBuffer])
+    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+    val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
+
+    val transfer = mock(classOf[BlockTransferService])
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+      override def answer(invocation: InvocationOnMock): Unit = {
+        val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          // Return the first block as-is, then two corrupt buffers.
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
+          sem.release()
+        }
+      }
+    })

+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+
+    val taskContext = TaskContext.empty()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => new LimitedInputStream(in, 100),
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      true)
+
+    // Continue only after the mock has delivered all three blocks
+    sem.acquire()
+
+    // The first block should be returned without an exception
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 0))
+
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+      override def answer(invocation: InvocationOnMock): Unit = {
+        val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          // Serve the retried block; it is still corrupt.
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          sem.release()
+        }
+      }
+    })
+
+    // The next result is the corrupt local block (the corrupt remote block is refetched first)
+    intercept[FetchFailedException] { iterator.next() }
+
+    sem.acquire()
+    intercept[FetchFailedException] { iterator.next() }
+  }
 }
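To iterate on just this suite locally, something like the following should work from the repository root (a sketch assuming the standard Spark sbt launcher; exact task syntax varies across sbt versions):

build/sbt "core/testOnly org.apache.spark.storage.ShuffleBlockFetcherIteratorSuite"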
