|  | 
| 17 | 17 | 
 | 
| 18 | 18 | package org.apache.spark.shuffle.hash | 
| 19 | 19 | 
 | 
| 20 |  | -import java.io.{File, FileWriter} | 
|  | 20 | +import java.io._ | 
|  | 21 | +import java.nio.ByteBuffer | 
| 21 | 22 | 
 | 
| 22 | 23 | import scala.language.reflectiveCalls | 
| 23 | 24 | 
 | 
| 24 |  | -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} | 
| 25 |  | -import org.apache.spark.executor.ShuffleWriteMetrics | 
|  | 25 | +import org.mockito.Matchers.any | 
|  | 26 | +import org.mockito.Mockito._ | 
|  | 27 | +import org.mockito.invocation.InvocationOnMock | 
|  | 28 | +import org.mockito.stubbing.Answer | 
|  | 29 | + | 
|  | 30 | +import org.apache.spark._ | 
|  | 31 | +import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics} | 
| 26 | 32 | import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} | 
| 27 |  | -import org.apache.spark.serializer.JavaSerializer | 
| 28 |  | -import org.apache.spark.shuffle.FileShuffleBlockResolver | 
| 29 |  | -import org.apache.spark.storage.{ShuffleBlockId, FileSegment} | 
|  | 33 | +import org.apache.spark.serializer._ | 
|  | 34 | +import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver} | 
|  | 35 | +import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment} | 
| 30 | 36 | 
 | 
| 31 | 37 | class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { | 
| 32 | 38 |   private val testConf = new SparkConf(false) | 
| @@ -107,4 +113,100 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { | 
| 107 | 113 |     for (i <- 0 until numBytes) writer.write(i) | 
| 108 | 114 |     writer.close() | 
| 109 | 115 |   } | 
|  | 116 | + | 
|  | 117 | +  test("HashShuffleReader.read() releases resources and tracks metrics") { | 
|  | 118 | +    val shuffleId = 1 | 
|  | 119 | +    val numMaps = 2 | 
|  | 120 | +    val numKeyValuePairs = 10 | 
|  | 121 | + | 
|  | 122 | +    val mockContext = mock(classOf[TaskContext]) | 
|  | 123 | + | 
|  | 124 | +    val mockTaskMetrics = mock(classOf[TaskMetrics]) | 
|  | 125 | +    val mockReadMetrics = mock(classOf[ShuffleReadMetrics]) | 
|  | 126 | +    when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics) | 
|  | 127 | +    when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics) | 
|  | 128 | + | 
|  | 129 | +    val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher]) | 
|  | 130 | + | 
|  | 131 | +    val mockDep = mock(classOf[ShuffleDependency[_, _, _]]) | 
|  | 132 | +    when(mockDep.keyOrdering).thenReturn(None) | 
|  | 133 | +    when(mockDep.aggregator).thenReturn(None) | 
|  | 134 | +    when(mockDep.serializer).thenReturn(Some(new Serializer { | 
|  | 135 | +      override def newInstance(): SerializerInstance = new SerializerInstance { | 
|  | 136 | + | 
|  | 137 | +        override def deserializeStream(s: InputStream): DeserializationStream = | 
|  | 138 | +          new DeserializationStream { | 
|  | 139 | +            override def readObject[T: ClassManifest](): T = null.asInstanceOf[T] | 
|  | 140 | + | 
|  | 141 | +            override def close(): Unit = s.close() | 
|  | 142 | + | 
|  | 143 | +            private val values = { | 
|  | 144 | +              for (i <- 0 to numKeyValuePairs * 2) yield i | 
|  | 145 | +            }.iterator | 
|  | 146 | + | 
|  | 147 | +            private def getValueOrEOF(): Int = { | 
|  | 148 | +              if (values.hasNext) { | 
|  | 149 | +                values.next() | 
|  | 150 | +              } else { | 
|  | 151 | +                throw new EOFException("End of the file: mock deserializeStream") | 
|  | 152 | +              } | 
|  | 153 | +            } | 
|  | 154 | + | 
|  | 155 | +            // NOTE: the readKey and readValue methods are called by asKeyValueIterator() | 
|  | 156 | +            // which is wrapped in a NextIterator | 
|  | 157 | +            override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] | 
|  | 158 | + | 
|  | 159 | +            override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] | 
|  | 160 | +          } | 
|  | 161 | + | 
|  | 162 | +        override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T = | 
|  | 163 | +          null.asInstanceOf[T] | 
|  | 164 | + | 
|  | 165 | +        override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0) | 
|  | 166 | + | 
|  | 167 | +        override def serializeStream(s: OutputStream): SerializationStream = | 
|  | 168 | +          null.asInstanceOf[SerializationStream] | 
|  | 169 | + | 
|  | 170 | +        override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T] | 
|  | 171 | +      } | 
|  | 172 | +    })) | 
|  | 173 | + | 
|  | 174 | +    val mockBlockManager = { | 
|  | 175 | +      // Create a block manager that isn't configured for compression, just returns input stream | 
|  | 176 | +      val blockManager = mock(classOf[BlockManager]) | 
|  | 177 | +      when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]())) | 
|  | 178 | +        .thenAnswer(new Answer[InputStream] { | 
|  | 179 | +        override def answer(invocation: InvocationOnMock): InputStream = { | 
|  | 180 | +          val blockId = invocation.getArguments()(0).asInstanceOf[BlockId] | 
|  | 181 | +          val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream] | 
|  | 182 | +          inputStream | 
|  | 183 | +        } | 
|  | 184 | +      }) | 
|  | 185 | +      blockManager | 
|  | 186 | +    } | 
|  | 187 | + | 
|  | 188 | +    val mockInputStream = mock(classOf[InputStream]) | 
|  | 189 | +    when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]())) | 
|  | 190 | +      .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream))) | 
|  | 191 | + | 
|  | 192 | +    val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep) | 
|  | 193 | + | 
|  | 194 | +    val reader = new HashShuffleReader(shuffleHandle, 0, 1, | 
|  | 195 | +      mockContext, mockBlockManager, mockShuffleFetcher) | 
|  | 196 | + | 
|  | 197 | +    val values = reader.read() | 
|  | 198 | +    // Verify that we're reading the correct values | 
|  | 199 | +    var numValuesRead = 0 | 
|  | 200 | +    for (((key: Int, value: Int), i) <- values.zipWithIndex) { | 
|  | 201 | +      assert(key == i * 2) | 
|  | 202 | +      assert(value == i * 2 + 1) | 
|  | 203 | +      numValuesRead += 1 | 
|  | 204 | +    } | 
|  | 205 | +    // Verify that we read the correct number of values | 
|  | 206 | +    assert(numKeyValuePairs == numValuesRead) | 
|  | 207 | +    // Verify that our input stream was closed | 
|  | 208 | +    verify(mockInputStream, times(1)).close() | 
|  | 209 | +    // Verify that we collected metrics for each key/value pair | 
|  | 210 | +    verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1) | 
|  | 211 | +  } | 
| 110 | 212 | } | 
0 commit comments