| 
17 | 17 | 
 
  | 
18 | 18 | package org.apache.spark.shuffle.hash  | 
19 | 19 | 
 
  | 
20 |  | -import java.io._  | 
21 |  | -import java.nio.ByteBuffer  | 
 | 20 | +import java.io.{File, FileWriter}  | 
22 | 21 | 
 
  | 
23 | 22 | import scala.language.reflectiveCalls  | 
24 | 23 | 
 
  | 
25 |  | -import org.mockito.Matchers.any  | 
26 |  | -import org.mockito.Mockito._  | 
27 |  | -import org.mockito.invocation.InvocationOnMock  | 
28 |  | -import org.mockito.stubbing.Answer  | 
29 |  | - | 
30 |  | -import org.apache.spark._  | 
31 |  | -import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics}  | 
 | 24 | +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}  | 
 | 25 | +import org.apache.spark.executor.ShuffleWriteMetrics  | 
32 | 26 | import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}  | 
33 |  | -import org.apache.spark.serializer._  | 
34 |  | -import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver}  | 
35 |  | -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment}  | 
 | 27 | +import org.apache.spark.serializer.JavaSerializer  | 
 | 28 | +import org.apache.spark.shuffle.FileShuffleBlockResolver  | 
 | 29 | +import org.apache.spark.storage.{ShuffleBlockId, FileSegment}  | 
36 | 30 | 
 
  | 
37 | 31 | class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {  | 
38 | 32 |   private val testConf = new SparkConf(false)  | 
@@ -113,100 +107,4 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {  | 
113 | 107 |     for (i <- 0 until numBytes) writer.write(i)  | 
114 | 108 |     writer.close()  | 
115 | 109 |   }  | 
116 |  | - | 
117 |  | -  test("HashShuffleReader.read() releases resources and tracks metrics") {  | 
118 |  | -    val shuffleId = 1  | 
119 |  | -    val numMaps = 2  | 
120 |  | -    val numKeyValuePairs = 10  | 
121 |  | - | 
122 |  | -    val mockContext = mock(classOf[TaskContext])  | 
123 |  | - | 
124 |  | -    val mockTaskMetrics = mock(classOf[TaskMetrics])  | 
125 |  | -    val mockReadMetrics = mock(classOf[ShuffleReadMetrics])  | 
126 |  | -    when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics)  | 
127 |  | -    when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics)  | 
128 |  | - | 
129 |  | -    val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher])  | 
130 |  | - | 
131 |  | -    val mockDep = mock(classOf[ShuffleDependency[_, _, _]])  | 
132 |  | -    when(mockDep.keyOrdering).thenReturn(None)  | 
133 |  | -    when(mockDep.aggregator).thenReturn(None)  | 
134 |  | -    when(mockDep.serializer).thenReturn(Some(new Serializer {  | 
135 |  | -      override def newInstance(): SerializerInstance = new SerializerInstance {  | 
136 |  | - | 
137 |  | -        override def deserializeStream(s: InputStream): DeserializationStream =  | 
138 |  | -          new DeserializationStream {  | 
139 |  | -            override def readObject[T: ClassManifest](): T = null.asInstanceOf[T]  | 
140 |  | - | 
141 |  | -            override def close(): Unit = s.close()  | 
142 |  | - | 
143 |  | -            private val values = {  | 
144 |  | -              for (i <- 0 to numKeyValuePairs * 2) yield i  | 
145 |  | -            }.iterator  | 
146 |  | - | 
147 |  | -            private def getValueOrEOF(): Int = {  | 
148 |  | -              if (values.hasNext) {  | 
149 |  | -                values.next()  | 
150 |  | -              } else {  | 
151 |  | -                throw new EOFException("End of the file: mock deserializeStream")  | 
152 |  | -              }  | 
153 |  | -            }  | 
154 |  | - | 
155 |  | -            // NOTE: the readKey and readValue methods are called by asKeyValueIterator()  | 
156 |  | -            // which is wrapped in a NextIterator  | 
157 |  | -            override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]  | 
158 |  | - | 
159 |  | -            override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]  | 
160 |  | -          }  | 
161 |  | - | 
162 |  | -        override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T =  | 
163 |  | -          null.asInstanceOf[T]  | 
164 |  | - | 
165 |  | -        override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0)  | 
166 |  | - | 
167 |  | -        override def serializeStream(s: OutputStream): SerializationStream =  | 
168 |  | -          null.asInstanceOf[SerializationStream]  | 
169 |  | - | 
170 |  | -        override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T]  | 
171 |  | -      }  | 
172 |  | -    }))  | 
173 |  | - | 
174 |  | -    val mockBlockManager = {  | 
175 |  | -      // Create a block manager that isn't configured for compression, just returns input stream  | 
176 |  | -      val blockManager = mock(classOf[BlockManager])  | 
177 |  | -      when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]()))  | 
178 |  | -        .thenAnswer(new Answer[InputStream] {  | 
179 |  | -        override def answer(invocation: InvocationOnMock): InputStream = {  | 
180 |  | -          val blockId = invocation.getArguments()(0).asInstanceOf[BlockId]  | 
181 |  | -          val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream]  | 
182 |  | -          inputStream  | 
183 |  | -        }  | 
184 |  | -      })  | 
185 |  | -      blockManager  | 
186 |  | -    }  | 
187 |  | - | 
188 |  | -    val mockInputStream = mock(classOf[InputStream])  | 
189 |  | -    when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))  | 
190 |  | -      .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream)))  | 
191 |  | - | 
192 |  | -    val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep)  | 
193 |  | - | 
194 |  | -    val reader = new HashShuffleReader(shuffleHandle, 0, 1,  | 
195 |  | -      mockContext, mockBlockManager, mockShuffleFetcher)  | 
196 |  | - | 
197 |  | -    val values = reader.read()  | 
198 |  | -    // Verify that we're reading the correct values  | 
199 |  | -    var numValuesRead = 0  | 
200 |  | -    for (((key: Int, value: Int), i) <- values.zipWithIndex) {  | 
201 |  | -      assert(key == i * 2)  | 
202 |  | -      assert(value == i * 2 + 1)  | 
203 |  | -      numValuesRead += 1  | 
204 |  | -    }  | 
205 |  | -    // Verify that we read the correct number of values  | 
206 |  | -    assert(numKeyValuePairs == numValuesRead)  | 
207 |  | -    // Verify that our input stream was closed  | 
208 |  | -    verify(mockInputStream, times(1)).close()  | 
209 |  | -    // Verify that we collected metrics for each key/value pair  | 
210 |  | -    verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1)  | 
211 |  | -  }  | 
212 | 110 | }  | 
0 commit comments