@@ -28,10 +28,10 @@ import org.mockito.Mockito._
2828import org .mockito .invocation .InvocationOnMock
2929import org .mockito .stubbing .Answer
3030
31+ import org .apache .spark .{SparkFunSuite , TaskContextImpl }
3132import org .apache .spark .network ._
3233import org .apache .spark .network .buffer .ManagedBuffer
3334import org .apache .spark .network .shuffle .BlockFetchingListener
34- import org .apache .spark .{SparkFunSuite , TaskContextImpl }
3535
3636
3737class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
@@ -61,11 +61,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
6161 // Create a mock managed buffer for testing
6262 def createMockManagedBuffer (): ManagedBuffer = {
6363 val mockManagedBuffer = mock(classOf [ManagedBuffer ])
64- when(mockManagedBuffer.createInputStream()).thenAnswer(new Answer [InputStream ] {
65- override def answer (invocation : InvocationOnMock ): InputStream = {
66- mock(classOf [InputStream ])
67- }
68- })
64+ when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf [InputStream ]))
6965 mockManagedBuffer
7066 }
7167
@@ -76,19 +72,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
7672
7773 // Make sure blockManager.getBlockData would return the blocks
7874 val localBlocks = Map [BlockId , ManagedBuffer ](
79- ShuffleBlockId (0 , 0 , 0 ) -> mock( classOf [ ManagedBuffer ] ),
80- ShuffleBlockId (0 , 1 , 0 ) -> mock( classOf [ ManagedBuffer ] ),
81- ShuffleBlockId (0 , 2 , 0 ) -> mock( classOf [ ManagedBuffer ] ))
75+ ShuffleBlockId (0 , 0 , 0 ) -> createMockManagedBuffer( ),
76+ ShuffleBlockId (0 , 1 , 0 ) -> createMockManagedBuffer( ),
77+ ShuffleBlockId (0 , 2 , 0 ) -> createMockManagedBuffer( ))
8278 localBlocks.foreach { case (blockId, buf) =>
8379 doReturn(buf).when(blockManager).getBlockData(meq(blockId))
8480 }
8581
8682 // Make sure remote blocks would return
8783 val remoteBmId = BlockManagerId (" test-client-1" , " test-client-1" , 2 )
8884 val remoteBlocks = Map [BlockId , ManagedBuffer ](
89- ShuffleBlockId (0 , 3 , 0 ) -> mock(classOf [ManagedBuffer ]),
90- ShuffleBlockId (0 , 4 , 0 ) -> mock(classOf [ManagedBuffer ])
91- )
85+ ShuffleBlockId (0 , 3 , 0 ) -> createMockManagedBuffer(),
86+ ShuffleBlockId (0 , 4 , 0 ) -> createMockManagedBuffer())
9287
9388 val transfer = createMockTransfer(remoteBlocks)
9489
@@ -109,13 +104,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
109104
110105 for (i <- 0 until 5 ) {
111106 assert(iterator.hasNext, s " iterator should have 5 elements but actually has $i elements " )
112- val (blockId, subIterator ) = iterator.next()
113- assert(subIterator .isSuccess,
107+ val (blockId, inputStream ) = iterator.next()
108+ assert(inputStream .isSuccess,
114109 s " iterator should have 5 elements defined but actually has $i elements " )
115110
116111 // Make sure we release buffers when a wrapped input stream is closed.
117112 val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
118- val wrappedInputStream = new BufferReleasingInputStream (mock( classOf [ InputStream ]) , iterator)
113+ val wrappedInputStream = new BufferReleasingInputStream (inputStream.get , iterator)
119114 verify(mockBuf, times(0 )).release()
120115 wrappedInputStream.close()
121116 verify(mockBuf, times(1 )).release()
0 commit comments