/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.hash

import java.io.{ByteArrayOutputStream, InputStream}
import java.nio.ByteBuffer

import org.mockito.Matchers.{eq => meq, _}
import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

import org.apache.spark._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.BaseShuffleHandle
import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}

/**
 * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
 *
 * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
 * is final (final classes cannot be spied on).
 */
class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
  var callsToRetain = 0
  var callsToRelease = 0

  override def size() = underlyingBuffer.size()
  override def nioByteBuffer() = underlyingBuffer.nioByteBuffer()
  override def createInputStream() = underlyingBuffer.createInputStream()
  override def convertToNetty() = underlyingBuffer.convertToNetty()

  override def retain(): ManagedBuffer = {
    callsToRetain += 1
    underlyingBuffer.retain()
  }
  override def release(): ManagedBuffer = {
    callsToRelease += 1
    underlyingBuffer.release()
  }
}
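
// Note: a single well-behaved consumer is expected to call retain() exactly once and
// release() exactly once on each buffer, so after the shuffle iterator in the test below
// has been exhausted both counters should equal 1 for every map output buffer.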

class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {

  /**
   * This test makes sure that, when data is read from a HashShuffleReader, the underlying
   * ManagedBuffers that contain the data are eventually released.
   */
  test("read() releases resources on completion") {
    val testConf = new SparkConf(false)
    // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the
    // shuffle code calls SparkEnv.get()).
    sc = new SparkContext("local", "test", testConf)

    val reduceId = 15
    val shuffleId = 22
    val numMaps = 6
    val keyValuePairsPerMap = 10
    val serializer = new JavaSerializer(testConf)

    // Make a mock BlockManager that will return RecordingManagedBuffers of data, so that we
    // can ensure retain() and release() are properly called.
    val blockManager = mock(classOf[BlockManager])

    // Create an Answer to use for the mocked wrapForCompression method that just returns
    // the original input stream.
    val dummyCompressionFunction = new Answer[InputStream] {
      override def answer(invocation: InvocationOnMock) =
        invocation.getArguments()(1).asInstanceOf[InputStream]
    }
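    // Argument index 1 above is the InputStream parameter of
    // wrapForCompression(blockId, inputStream) as stubbed below, so the stream is handed
    // back unchanged rather than being wrapped in a compression stream.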

    // Create a buffer with some key-value pairs to use as the shuffle data
    // from each mapper (all mappers return the same shuffle data).
    val byteOutputStream = new ByteArrayOutputStream()
    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
    (0 until keyValuePairsPerMap).foreach { i =>
      serializationStream.writeKey(i)
      serializationStream.writeValue(2 * i)
    }
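    // The same serialized bytes are wrapped in a fresh NioManagedBuffer for each map output
    // below.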

    // Set up the mocked BlockManager to return RecordingManagedBuffers.
    val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
    when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
    val buffers = (0 until numMaps).map { mapId =>
      // Create a ManagedBuffer with the shuffle data.
      val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
      val managedBuffer = new RecordingManagedBuffer(nioBuffer)

      // Set up the blockManager mock so the buffer gets returned when the shuffle code tries to
      // fetch shuffle data.
      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
      when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
      when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
        .thenAnswer(dummyCompressionFunction)

      managedBuffer
    }
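    // Keep references to the RecordingManagedBuffers so their retain/release counters can be
    // checked after the shuffle data has been read.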

    // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
    // shuffle data to read.
    val mapOutputTracker = mock(classOf[MapOutputTracker])
    // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks
    // for the code to read data over the network.
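    // Each status is the (location, output size) pair for one map task's output for this reduce
    // partition; because every location is the local BlockManagerId, all blocks are read via
    // blockManager.getBlockData() rather than over the network.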
    val statuses: Array[(BlockManagerId, Long)] =
      Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size()))
    when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses)

    // Create a mocked shuffle handle to pass into HashShuffleReader.
    val shuffleHandle = {
      val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
      when(dependency.serializer).thenReturn(Some(serializer))
      when(dependency.aggregator).thenReturn(None)
      when(dependency.keyOrdering).thenReturn(None)
      new BaseShuffleHandle(shuffleId, numMaps, dependency)
    }

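    // The (reduceId, reduceId + 1) arguments below are the start and (exclusive) end partition,
    // so the reader fetches only partition reduceId. With no aggregator and no key ordering on
    // the dependency, read() should simply deserialize and return the fetched records.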
    val shuffleReader = new HashShuffleReader(
      shuffleHandle,
      reduceId,
      reduceId + 1,
      new TaskContextImpl(0, 0, 0, 0, null),
      blockManager,
      mapOutputTracker)

    assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)

    // Calling .length above will have exhausted the iterator; make sure that exhausting the
    // iterator caused retain and release to be called on each buffer.
    buffers.foreach { buffer =>
      assert(buffer.callsToRetain === 1)
      assert(buffer.callsToRelease === 1)
    }
  }
}