Commit 290f1eb

Added test for HashShuffleReader.read()
1 parent 5186da0

3 files changed, +164 −10 lines

core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala

Lines changed: 8 additions & 6 deletions
@@ -20,24 +20,26 @@ package org.apache.spark.shuffle.hash
 import java.io.InputStream

 import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Success}

 import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
+  ShuffleBlockId}

 private[hash] object BlockStoreShuffleFetcher extends Logging {
   def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
-      context: TaskContext)
+      context: TaskContext,
+      blockManager: BlockManager,
+      mapOutputTracker: MapOutputTracker)
     : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val blockManager = SparkEnv.get.blockManager

     val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+    val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))

@@ -53,7 +55,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {

     val blockFetcherItr = new ShuffleBlockFetcherIterator(
       context,
-      SparkEnv.get.blockManager.shuffleClient,
+      blockManager.shuffleClient,
       blockManager,
       blocksByAddress,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
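
The substance of this change is dependency injection: fetchBlockStreams no longer reaches into SparkEnv.get for its collaborators but receives them as parameters, which is what lets the new test suite substitute mocks. A rough before/after sketch of the pattern, not taken from the commit (FetcherSketch, fetchOld, and fetchNew are made-up names):

import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.storage.BlockManager

object FetcherSketch {
  // Before: collaborators were looked up from global state, so a unit test had to
  // stand up a full SparkEnv to control what the fetcher saw.
  def fetchOld(shuffleId: Int, reduceId: Int): Unit = {
    val blockManager = SparkEnv.get.blockManager
    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
    // ... fetch the blocks described by statuses via blockManager ...
  }

  // After: collaborators are explicit parameters, so a test can pass Mockito mocks
  // while production callers pass the live SparkEnv components.
  def fetchNew(
      shuffleId: Int,
      reduceId: Int,
      blockManager: BlockManager,
      mapOutputTracker: MapOutputTracker): Unit = {
    val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
    // ... fetch the blocks described by statuses via blockManager ...
  }
}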

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 6 additions & 4 deletions
@@ -17,29 +17,31 @@

 package org.apache.spark.shuffle.hash

-import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter

 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext)
+    context: TaskContext,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
     "Hash shuffle currently only supports fetching one partition")

   private val dep = handle.dependency
-  private val blockManager = SparkEnv.get.blockManager

   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
     val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
-      handle.shuffleId, startPartition, context)
+      handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)

     // Wrap the streams for compression based on configuration
     val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
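
Because the two new constructor parameters default to the live SparkEnv components, existing call sites keep compiling unchanged while a test can override them. A small sketch of both call styles (the argument values such as handle, context, mockBlockManager, and mockMapOutputTracker are placeholders, not names from the commit):

// Production-style call: the defaults resolve to SparkEnv.get.blockManager and
// SparkEnv.get.mapOutputTracker when the reader is constructed.
val reader = new HashShuffleReader[Int, Int](handle, startPartition, endPartition, context)

// Test-style call: the defaults are overridden with mocks, as HashShuffleReaderSuite
// does in the new file below.
val testReader = new HashShuffleReader[Int, Int](
  handle, startPartition, endPartition, context,
  blockManager = mockBlockManager,
  mapOutputTracker = mockMapOutputTracker)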

core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.hash

import java.io.{ByteArrayOutputStream, InputStream}
import java.nio.ByteBuffer

import org.mockito.Matchers.{eq => meq, _}
import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

import org.apache.spark._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.BaseShuffleHandle
import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}

/**
 * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
 *
 * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
 * is final (final classes cannot be spied on).
 */
class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
  var callsToRetain = 0
  var callsToRelease = 0

  override def size() = underlyingBuffer.size()
  override def nioByteBuffer() = underlyingBuffer.nioByteBuffer()
  override def createInputStream() = underlyingBuffer.createInputStream()
  override def convertToNetty() = underlyingBuffer.convertToNetty()

  override def retain(): ManagedBuffer = {
    callsToRetain += 1
    underlyingBuffer.retain()
  }
  override def release(): ManagedBuffer = {
    callsToRelease += 1
    underlyingBuffer.release()
  }
}

class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {

  /**
   * This test makes sure that, when data is read from a HashShuffleReader, the underlying
   * ManagedBuffers that contain the data are eventually released.
   */
  test("read() releases resources on completion") {
    val testConf = new SparkConf(false)
    // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the
    // shuffle code calls SparkEnv.get()).
    sc = new SparkContext("local", "test", testConf)

    val reduceId = 15
    val shuffleId = 22
    val numMaps = 6
    val keyValuePairsPerMap = 10
    val serializer = new JavaSerializer(testConf)

    // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we
    // can ensure retain() and release() are properly called.
    val blockManager = mock(classOf[BlockManager])

    // Create a return function to use for the mocked wrapForCompression method that just returns
    // the original input stream.
    val dummyCompressionFunction = new Answer[InputStream] {
      override def answer(invocation: InvocationOnMock) =
        invocation.getArguments()(1).asInstanceOf[InputStream]
    }

    // Create a buffer with some randomly generated key-value pairs to use as the shuffle data
    // from each mappers (all mappers return the same shuffle data).
    val byteOutputStream = new ByteArrayOutputStream()
    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
    (0 until keyValuePairsPerMap).foreach { i =>
      serializationStream.writeKey(i)
      serializationStream.writeValue(2*i)
    }

    // Setup the mocked BlockManager to return RecordingManagedBuffers.
    val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
    when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
    val buffers = (0 until numMaps).map { mapId =>
      // Create a ManagedBuffer with the shuffle data.
      val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
      val managedBuffer = new RecordingManagedBuffer(nioBuffer)

      // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
      // fetch shuffle data.
      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
      when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
      when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
        .thenAnswer(dummyCompressionFunction)

      managedBuffer
    }

    // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
    // shuffle data to read.
    val mapOutputTracker = mock(classOf[MapOutputTracker])
    // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks
    // for the code to read data over the network.
    val statuses: Array[(BlockManagerId, Long)] =
      Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size()))
    when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses)

    // Create a mocked shuffle handle to pass into HashShuffleReader.
    val shuffleHandle = {
      val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
      when(dependency.serializer).thenReturn(Some(serializer))
      when(dependency.aggregator).thenReturn(None)
      when(dependency.keyOrdering).thenReturn(None)
      new BaseShuffleHandle(shuffleId, numMaps, dependency)
    }

    val shuffleReader = new HashShuffleReader(
      shuffleHandle,
      reduceId,
      reduceId + 1,
      new TaskContextImpl(0, 0, 0, 0, null),
      blockManager,
      mapOutputTracker)

    assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)

    // Calling .length above will have exhausted the iterator; make sure that exhausting the
    // iterator caused retain and release to be called on each buffer.
    buffers.foreach { buffer =>
      assert(buffer.callsToRetain === 1)
      assert(buffer.callsToRelease === 1)
    }
  }
}
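
RecordingManagedBuffer only adds bookkeeping on top of the wrapped NioManagedBuffer, so its counters can be sanity-checked in isolation, for example from a REPL or another test. A minimal usage sketch, not part of the committed suite:

import java.nio.ByteBuffer
import org.apache.spark.network.buffer.NioManagedBuffer

val buffer = new RecordingManagedBuffer(new NioManagedBuffer(ByteBuffer.wrap(Array[Byte](1, 2, 3))))
buffer.retain()   // delegates to the underlying buffer and bumps callsToRetain
buffer.release()  // delegates to the underlying buffer and bumps callsToRelease
assert(buffer.callsToRetain == 1 && buffer.callsToRelease == 1)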
