1717
1818package org .apache .spark .storage
1919
20- import java .io .InputStream
20+ import java .io .{ File , InputStream , IOException }
2121import java .util .concurrent .Semaphore
2222
2323import scala .concurrent .ExecutionContext .Implicits .global
@@ -31,8 +31,9 @@ import org.scalatest.PrivateMethodTester
3131
3232import org .apache .spark .{SparkFunSuite , TaskContext }
3333import org .apache .spark .network ._
34- import org .apache .spark .network .buffer .ManagedBuffer
34+ import org .apache .spark .network .buffer .{ FileSegmentManagedBuffer , ManagedBuffer }
3535import org .apache .spark .network .shuffle .BlockFetchingListener
36+ import org .apache .spark .network .util .LimitedInputStream
3637import org .apache .spark .shuffle .FetchFailedException
3738
3839
@@ -63,7 +64,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
6364 // Create a mock managed buffer for testing
6465 def createMockManagedBuffer (): ManagedBuffer = {
6566 val mockManagedBuffer = mock(classOf [ManagedBuffer ])
66- when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf [InputStream ]))
67+ val in = mock(classOf [InputStream ])
68+ when(in.read(any())).thenReturn(1 )
69+ when(in.read(any(), any(), any())).thenReturn(1 )
70+ when(mockManagedBuffer.createInputStream()).thenReturn(in)
6771 mockManagedBuffer
6872 }
6973
@@ -101,7 +105,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
101105 blocksByAddress,
102106 (_, in) => in,
103107 48 * 1024 * 1024 ,
104- Int .MaxValue )
108+ Int .MaxValue ,
109+ true )
105110
106111 // 3 local blocks fetched in initialization
107112 verify(blockManager, times(3 )).getBlockData(any())
@@ -175,7 +180,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
175180 blocksByAddress,
176181 (_, in) => in,
177182 48 * 1024 * 1024 ,
178- Int .MaxValue )
183+ Int .MaxValue ,
184+ true )
179185
180186 verify(blocks(ShuffleBlockId (0 , 0 , 0 )), times(0 )).release()
181187 iterator.next()._2.close() // close() first block's input stream
@@ -203,9 +209,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
203209 // Make sure remote blocks would return
204210 val remoteBmId = BlockManagerId (" test-client-1" , " test-client-1" , 2 )
205211 val blocks = Map [BlockId , ManagedBuffer ](
206- ShuffleBlockId (0 , 0 , 0 ) -> mock( classOf [ ManagedBuffer ] ),
207- ShuffleBlockId (0 , 1 , 0 ) -> mock( classOf [ ManagedBuffer ] ),
208- ShuffleBlockId (0 , 2 , 0 ) -> mock( classOf [ ManagedBuffer ] )
212+ ShuffleBlockId (0 , 0 , 0 ) -> createMockManagedBuffer( ),
213+ ShuffleBlockId (0 , 1 , 0 ) -> createMockManagedBuffer( ),
214+ ShuffleBlockId (0 , 2 , 0 ) -> createMockManagedBuffer( )
209215 )
210216
211217 // Semaphore to coordinate event sequence in two different threads.
@@ -239,7 +245,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
239245 blocksByAddress,
240246 (_, in) => in,
241247 48 * 1024 * 1024 ,
242- Int .MaxValue )
248+ Int .MaxValue ,
249+ true )
243250
244251 // Continue only after the mock calls onBlockFetchFailure
245252 sem.acquire()
@@ -250,4 +257,83 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
250257 intercept[FetchFailedException ] { iterator.next() }
251258 intercept[FetchFailedException ] { iterator.next() }
252259 }
260+
261+ test(" retry corrupt blocks" ) {
262+ val blockManager = mock(classOf [BlockManager ])
263+ val localBmId = BlockManagerId (" test-client" , " test-client" , 1 )
264+ doReturn(localBmId).when(blockManager).blockManagerId
265+
266+ // Make sure remote blocks would return
267+ val remoteBmId = BlockManagerId (" test-client-1" , " test-client-1" , 2 )
268+ val blocks = Map [BlockId , ManagedBuffer ](
269+ ShuffleBlockId (0 , 0 , 0 ) -> createMockManagedBuffer(),
270+ ShuffleBlockId (0 , 1 , 0 ) -> createMockManagedBuffer(),
271+ ShuffleBlockId (0 , 2 , 0 ) -> createMockManagedBuffer()
272+ )
273+
274+ // Semaphore to coordinate event sequence in two different threads.
275+ val sem = new Semaphore (0 )
276+
277+ val corruptStream = mock(classOf [InputStream ])
278+ when(corruptStream.read(any(), any(), any())).thenThrow(new IOException (" corrupt" ))
279+ val corruptBuffer = mock(classOf [ManagedBuffer ])
280+ when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
281+ val corruptLocalBuffer = new FileSegmentManagedBuffer (null , new File (" a" ), 0 , 100 )
282+
283+ val transfer = mock(classOf [BlockTransferService ])
284+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer [Unit ] {
285+ override def answer (invocation : InvocationOnMock ): Unit = {
286+ val listener = invocation.getArguments()(4 ).asInstanceOf [BlockFetchingListener ]
287+ Future {
288+ // Return the first block, and then fail.
289+ listener.onBlockFetchSuccess(
290+ ShuffleBlockId (0 , 0 , 0 ).toString, blocks(ShuffleBlockId (0 , 0 , 0 )))
291+ listener.onBlockFetchSuccess(
292+ ShuffleBlockId (0 , 1 , 0 ).toString, corruptBuffer)
293+ listener.onBlockFetchSuccess(
294+ ShuffleBlockId (0 , 2 , 0 ).toString, corruptLocalBuffer)
295+ sem.release()
296+ }
297+ }
298+ })
299+
300+ val blocksByAddress = Seq [(BlockManagerId , Seq [(BlockId , Long )])](
301+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1 .asInstanceOf [Long ])).toSeq))
302+
303+ val taskContext = TaskContext .empty()
304+ val iterator = new ShuffleBlockFetcherIterator (
305+ taskContext,
306+ transfer,
307+ blockManager,
308+ blocksByAddress,
309+ (_, in) => new LimitedInputStream (in, 100 ),
310+ 48 * 1024 * 1024 ,
311+ Int .MaxValue ,
312+ true )
313+
314+ // Continue only after the mock calls onBlockFetchFailure
315+ sem.acquire()
316+
317+ // The first block should be returned without an exception
318+ val (id1, _) = iterator.next()
319+ assert(id1 === ShuffleBlockId (0 , 0 , 0 ))
320+
321+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer [Unit ] {
322+ override def answer (invocation : InvocationOnMock ): Unit = {
323+ val listener = invocation.getArguments()(4 ).asInstanceOf [BlockFetchingListener ]
324+ Future {
325+ // Return the first block, and then fail.
326+ listener.onBlockFetchSuccess(
327+ ShuffleBlockId (0 , 1 , 0 ).toString, corruptBuffer)
328+ sem.release()
329+ }
330+ }
331+ })
332+
333+ // The next block is corrupt local block (the second one is corrupt and retried)
334+ intercept[FetchFailedException ] { iterator.next() }
335+
336+ sem.acquire()
337+ intercept[FetchFailedException ] { iterator.next() }
338+ }
253339}
0 commit comments