From 19135f298e215ae11f4c8fd3b8c51147fd8bcc46 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Wed, 27 May 2015 16:16:19 -0700 Subject: [PATCH 01/19] [SPARK-7884] Allow Spark shuffle APIs to be more customizable This commit updates the shuffle read path to enable ShuffleReader implementations more control over the deserialization process. The BlockStoreShuffleFetcher.fetch() method has been renamed to BlockStoreShuffleFetcher.fetchBlockStreams(). Previously, this method returned a record iterator; now, it returns an iterator of (BlockId, Try[InputStream]). Deserialization of records is now handled in the ShuffleReader.read() method. This change creates a cleaner separation of concerns and allows implementations of ShuffleReader more flexibility in how records are deserialized. --- .../hash/BlockStoreShuffleFetcher.scala | 35 +++----- .../shuffle/hash/HashShuffleReader.scala | 27 +++++- .../storage/ShuffleBlockFetcherIterator.scala | 89 ++++++++++++------- .../ShuffleBlockFetcherIteratorSuite.scala | 39 ++++---- 4 files changed, 119 insertions(+), 71 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 597d46a3d2223..9a15f9bab834a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -17,23 +17,22 @@ package org.apache.spark.shuffle.hash -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap +import java.io.InputStream + +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.{Failure, Success, Try} import org.apache.spark._ -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.CompletionIterator private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetch[T]( + def fetchBlockStreams( shuffleId: Int, reduceId: Int, - context: TaskContext, - serializer: Serializer) - : Iterator[T] = + context: TaskContext) + : Iterator[(BlockId, InputStream)] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager @@ -53,12 +52,12 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { + def unpackBlock(blockPair: (BlockId, Try[InputStream])) : (BlockId, InputStream) = { val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Success(block) => { - block.asInstanceOf[Iterator[T]] + case Success(inputStream) => { + (blockId, inputStream) } case Failure(e) => { blockId match { @@ -78,21 +77,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { SparkEnv.get.blockManager.shuffleClient, blockManager, blocksByAddress, - serializer, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - val itr = blockFetcherItr.flatMap(unpackBlock) - val completionIter = CompletionIterator[T, Iterator[T]](itr, { - context.taskMetrics.updateShuffleReadMetrics() - }) + val itr = blockFetcherItr.map(unpackBlock) - new 
InterruptibleIterator[T](context, completionIter) { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): T = { - readMetrics.incRecordsRead(1) - delegate.next() - } - } + CompletionIterator[(BlockId, InputStream), Iterator[(BlockId, InputStream)]](itr, { + context.taskMetrics().updateShuffleReadMetrics() + }) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 41bafabde05b9..0f315b85bfca6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,10 +17,10 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext} private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], @@ -33,11 +33,34 @@ private[spark] class HashShuffleReader[K, C]( "Hash shuffle currently only supports fetching one partition") private val dep = handle.dependency + private val blockManager = SparkEnv.get.blockManager /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( + handle.shuffleId, startPartition, context) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIterator = wrappedStreams.flatMap { wrappedStream => + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update read metrics for each record materialized + val iter = new InterruptibleIterator[Any](context, recordIterator) { + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + override def next(): Any = { + readMetrics.incRecordsRead(1) + delegate.next() + } + }.asInstanceOf[Iterator[Nothing]] val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d0faab62c9e9e..3758a758943d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,24 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Try} -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.{SerializerInstance, 
Serializer} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, TaskContext} /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a - * pipelined fashion as they are received. + * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. @@ -44,7 +45,6 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] @@ -53,9 +53,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[InputStream])] with Logging { import ShuffleBlockFetcherIterator._ @@ -79,11 +78,11 @@ final class ShuffleBlockFetcherIterator( private[this] val localBlocks = new ArrayBuffer[BlockId]() /** Remote blocks to fetch, excluding zero-sized blocks. */ - private[this] val remoteBlocks = new HashSet[BlockId]() + private[this] val remoteBlocks = new mutable.HashSet[BlockId]() /** * A queue to hold our results. This turns the asynchronous model provided by - * [[BlockTransferService]] into a synchronous model (iterator). + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). */ private[this] val results = new LinkedBlockingQueue[FetchResult] @@ -97,14 +96,12 @@ final class ShuffleBlockFetcherIterator( * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that * the number of bytes in flight is limited to maxBytesInFlight. */ - private[this] val fetchRequests = new Queue[FetchRequest] + private[this] val fetchRequests = new mutable.Queue[FetchRequest] /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - - private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no @@ -114,17 +111,23 @@ final class ShuffleBlockFetcherIterator( initialize() - /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. - */ - private[this] def cleanup() { - isZombie = true + // Decrements the buffer reference count. 
+ // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { case SuccessFetchResult(_, _, buf) => buf.release() case _ => } + currentResult = null + } + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { @@ -272,7 +275,7 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + override def next(): (BlockId, Try[InputStream]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -290,22 +293,15 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[Iterator[Any]] = result match { + val iteratorTry: Try[InputStream] = result match { case FailureFetchResult(_, e) => Failure(e) case SuccessFetchResult(blockId, _, buf) => // There is a chance that createInputStream can fail (e.g. fetching a local file that does // not exist, SPARK-4085). In that case, we should propagate the right exception so // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asKeyValueIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. 
- currentResult = null - buf.release() - }) + Try(buf.createInputStream()).map { inputStream => + new WrappedInputStream(inputStream, this) } } @@ -313,6 +309,35 @@ final class ShuffleBlockFetcherIterator( } } +// Helper class that ensures a ManagerBuffer is released upon InputStream.close() +private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator) + extends InputStream { + private var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + iterator.releaseCurrentResultBuffer() + closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} private[storage] object ShuffleBlockFetcherIterator { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2a7fe67ad8585..60e6840cb00bd 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,21 +17,22 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.Semaphore -import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.serializer.TestSerializer +import org.apache.spark.{SparkFunSuite, TaskContextImpl} + class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Some of the tests are quite tricky because we are testing the cleanup behavior @@ -57,7 +58,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer } - private val conf = new SparkConf + // Create a mock managed buffer for testing + def createMockManagedBuffer(): ManagedBuffer = { + val mockManagedBuffer = mock(classOf[ManagedBuffer]) + when(mockManagedBuffer.createInputStream()).thenAnswer(new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = { + mock(classOf[InputStream]) + } + }) + mockManagedBuffer + } test("successful 3 local reads + 2 remote reads") { val blockManager = mock(classOf[BlockManager]) @@ -92,7 +102,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // 3 local blocks fetched in initialization @@ -104,10 +113,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { assert(subIterator.isSuccess, s"iterator should have 5 elements defined but actually has $i elements") - // Make sure we release the buffer once the iterator is exhausted. + // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + val wrappedInputStream = new WrappedInputStream(mock(classOf[InputStream]), iterator) verify(mockBuf, times(0)).release() - subIterator.get.foreach(_ => Unit) // exhaust the iterator + wrappedInputStream.close() + verify(mockBuf, times(1)).release() + wrappedInputStream.close() // close should be idempotent verify(mockBuf, times(1)).release() } @@ -125,10 +137,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) @@ -159,11 +170,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) - // Exhaust the first block, and then it should be released. - iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() + iterator.next()._2.get.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator @@ -222,7 +232,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // Continue only after the mock calls onBlockFetchFailure From b70c945220dacb81a1343f3b0ed4d2a13c8cc76e Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Tue, 2 Jun 2015 15:09:37 -0700 Subject: [PATCH 02/19] Make BlockStoreShuffleFetcher visible to shuffle package --- .../apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 9a15f9bab834a..e93d0ba950266 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -27,7 +27,7 @@ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.CompletionIterator -private[hash] object BlockStoreShuffleFetcher extends Logging { +private[shuffle] object BlockStoreShuffleFetcher extends Logging { def fetchBlockStreams( shuffleId: Int, reduceId: Int, From 208b7a51c41bead773ec6bc8d6fefdb7e127ba76 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Mon, 8 Jun 2015 16:02:51 -0700 Subject: [PATCH 03/19] Small code style changes --- .../org/apache/spark/shuffle/hash/HashShuffleReader.scala | 2 +- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 0f315b85bfca6..ea8aa346cd9f3 100644 --- 
a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -38,7 +38,7 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( - handle.shuffleId, startPartition, context) + handle.shuffleId, startPartition, context) // Wrap the streams for compression based on configuration val wrappedStreams = blockStreams.map { case (blockId, inputStream) => diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 3758a758943d4..771194225af2f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -21,7 +21,7 @@ import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashSet} import scala.util.{Failure, Try} import org.apache.spark.network.buffer.ManagedBuffer @@ -78,7 +78,7 @@ final class ShuffleBlockFetcherIterator( private[this] val localBlocks = new ArrayBuffer[BlockId]() /** Remote blocks to fetch, excluding zero-sized blocks. */ - private[this] val remoteBlocks = new mutable.HashSet[BlockId]() + private[this] val remoteBlocks = new HashSet[BlockId]() /** * A queue to hold our results. This turns the asynchronous model provided by From 7c8f73e8cc497f023beae93b0bebc536d50a51cb Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Tue, 9 Jun 2015 10:15:51 -0700 Subject: [PATCH 04/19] Close Block InputStream immediately after all records are read --- .../apache/spark/shuffle/hash/HashShuffleReader.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index ea8aa346cd9f3..d92a2a98a3588 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext} @@ -38,7 +39,7 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( - handle.shuffleId, startPartition, context) + handle.shuffleId, startPartition, context) // Wrap the streams for compression based on configuration val wrappedStreams = blockStreams.map { case (blockId, inputStream) => @@ -50,7 +51,11 @@ private[spark] class HashShuffleReader[K, C]( // Create a key/value iterator for each stream val recordIterator = wrappedStreams.flatMap { wrappedStream => - serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + val kvIter = serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + CompletionIterator[(Any, Any), 
Iterator[(Any, Any)]](kvIter, { + // Close the stream once all the records have been read from it + wrappedStream.close() + }) } // Update read metrics for each record materialized From 01e87211a00639369f8f4269a68745a603281293 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Tue, 9 Jun 2015 11:29:14 -0700 Subject: [PATCH 05/19] Explicitly cast iterator in branches for type clarity --- .../spark/shuffle/hash/HashShuffleReader.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index d92a2a98a3588..e0b8d58af490d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -65,13 +65,21 @@ private[spark] class HashShuffleReader[K, C]( readMetrics.incRecordsRead(1) delegate.next() } - }.asInstanceOf[Iterator[Nothing]] + } val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { - new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + // We are reading values that are already combined + val combinedKeyValuesIterator = iter.asInstanceOf[Iterator[(K,C)]] + new InterruptibleIterator(context, + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)) } else { - new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = iter.asInstanceOf[Iterator[(K,Nothing)]] + new InterruptibleIterator(context, + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") From 7e8e0fed6fba3d102f045f21e652d0a4376ac115 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Tue, 9 Jun 2015 13:12:21 -0700 Subject: [PATCH 06/19] Minor Scala style fixes --- .../spark/shuffle/hash/HashShuffleReader.scala | 12 +++++++----- .../spark/storage/ShuffleBlockFetcherIterator.scala | 11 ++++++++--- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index e0b8d58af490d..4efa9b0fc1871 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -53,15 +53,17 @@ private[spark] class HashShuffleReader[K, C]( val recordIterator = wrappedStreams.flatMap { wrappedStream => val kvIter = serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator CompletionIterator[(Any, Any), Iterator[(Any, Any)]](kvIter, { - // Close the stream once all the records have been read from it + // Close the stream once all the records have been read from it to free underlying + // ManagedBuffer as soon as possible. Note that in case of task failure, the task's + // TaskCompletionListener will make sure this is released. 
wrappedStream.close() }) } // Update read metrics for each record materialized - val iter = new InterruptibleIterator[Any](context, recordIterator) { + val iter = new InterruptibleIterator[(Any, Any)](context, recordIterator) { val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): Any = { + override def next(): (Any, Any) = { readMetrics.incRecordsRead(1) delegate.next() } @@ -70,14 +72,14 @@ private[spark] class HashShuffleReader[K, C]( val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { // We are reading values that are already combined - val combinedKeyValuesIterator = iter.asInstanceOf[Iterator[(K,C)]] + val combinedKeyValuesIterator = iter.asInstanceOf[Iterator[(K, C)]] new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)) } else { // We don't know the value type, but also don't care -- the dependency *should* // have made sure its compatible w/ this aggregator, which will convert the value // type to the combined type C - val keyValuesIterator = iter.asInstanceOf[Iterator[(K,Nothing)]] + val keyValuesIterator = iter.asInstanceOf[Iterator[(K, Nothing)]] new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 771194225af2f..68f6b47fffc38 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -20,8 +20,7 @@ package org.apache.spark.storage import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashSet} +import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.{Failure, Try} import org.apache.spark.network.buffer.ManagedBuffer @@ -96,7 +95,7 @@ final class ShuffleBlockFetcherIterator( * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that * the number of bytes in flight is limited to maxBytesInFlight. */ - private[this] val fetchRequests = new mutable.Queue[FetchRequest] + private[this] val fetchRequests = new Queue[FetchRequest] /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L @@ -275,6 +274,12 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + /** + * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + */ override def next(): (BlockId, Try[InputStream]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() From f93841e9c809da47ea76d876e6f5b70bbe9062ff Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Tue, 9 Jun 2015 16:54:30 -0700 Subject: [PATCH 07/19] Update shuffle read metrics in ShuffleReader instead of BlockStoreShuffleFetcher. This commit also includes Scala style cleanup. 
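
The metrics consolidation is easiest to see outside of Spark's plumbing. The sketch below uses illustrative names only (RecordCounter and countRecords are stand-ins, not Spark APIs): each next() on the wrapped iterator bumps a per-record counter, and a completion hook flushes the aggregate once the iterator is exhausted, mirroring the InterruptibleIterator-plus-CompletionIterator combination in the diff that follows.

    // Illustrative stand-in for the shuffle-read metrics object.
    final class RecordCounter {
      private var records = 0L
      def incRecordsRead(n: Long): Unit = records += n
      def recordsRead: Long = records
    }

    // Wrap an iterator so every record read is counted, and run `onComplete`
    // exactly once when the underlying iterator is exhausted (the role played
    // by InterruptibleIterator + CompletionIterator in the actual patch).
    def countRecords[T](underlying: Iterator[T], counter: RecordCounter)
                       (onComplete: => Unit): Iterator[T] = new Iterator[T] {
      private var completed = false
      override def hasNext: Boolean = {
        val more = underlying.hasNext
        if (!more && !completed) { completed = true; onComplete }
        more
      }
      override def next(): T = {
        counter.incRecordsRead(1)
        underlying.next()
      }
    }

    // Usage: two records are counted as they are materialized; the hook fires once at the end.
    val counter = new RecordCounter
    val records = countRecords(Iterator("k1" -> 1, "k2" -> 2), counter) {
      println(s"flushing shuffle-read metrics: ${counter.recordsRead} records")
    }
    records.foreach(_ => ())   // prints: flushing shuffle-read metrics: 2 records
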
--- .../hash/BlockStoreShuffleFetcher.scala | 7 +---- .../shuffle/hash/HashShuffleReader.scala | 28 +++++++++---------- .../storage/ShuffleBlockFetcherIterator.scala | 10 ++++--- .../ShuffleBlockFetcherIteratorSuite.scala | 2 +- 4 files changed, 21 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index e93d0ba950266..16d206a6f8043 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -25,7 +25,6 @@ import scala.util.{Failure, Success, Try} import org.apache.spark._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator private[shuffle] object BlockStoreShuffleFetcher extends Logging { def fetchBlockStreams( @@ -80,10 +79,6 @@ private[shuffle] object BlockStoreShuffleFetcher extends Logging { // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - val itr = blockFetcherItr.map(unpackBlock) - - CompletionIterator[(BlockId, InputStream), Iterator[(BlockId, InputStream)]](itr, { - context.taskMetrics().updateShuffleReadMetrics() - }) + blockFetcherItr.map(unpackBlock) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 4efa9b0fc1871..40e54ca0a3ab2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,11 +17,11 @@ package org.apache.spark.shuffle.hash +import org.apache.spark.{SparkEnv, TaskContext, InterruptibleIterator} import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.shuffle.{ShuffleReader, BaseShuffleHandle} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext} private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], @@ -51,24 +51,22 @@ private[spark] class HashShuffleReader[K, C]( // Create a key/value iterator for each stream val recordIterator = wrappedStreams.flatMap { wrappedStream => - val kvIter = serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator - CompletionIterator[(Any, Any), Iterator[(Any, Any)]](kvIter, { - // Close the stream once all the records have been read from it to free underlying - // ManagedBuffer as soon as possible. Note that in case of task failure, the task's - // TaskCompletionListener will make sure this is released. 
- wrappedStream.close() - }) + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator } + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() // Update read metrics for each record materialized - val iter = new InterruptibleIterator[(Any, Any)](context, recordIterator) { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): (Any, Any) = { - readMetrics.incRecordsRead(1) - delegate.next() - } + val metricIter = new InterruptibleIterator[(Any, Any)](context, recordIterator) { + override def next(): (Any, Any) = { + readMetrics.incRecordsRead(1) + delegate.next() + } } + val iter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](metricIter, { + context.taskMetrics().updateShuffleReadMetrics() + }) + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { // We are reading values that are already combined diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 68f6b47fffc38..a1376c8f4e484 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,10 +23,10 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.{Failure, Try} +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.util.Utils -import org.apache.spark.{Logging, TaskContext} /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -306,7 +306,7 @@ final class ShuffleBlockFetcherIterator( // not exist, SPARK-4085). In that case, we should propagate the right exception so // the scheduler gets a FetchFailedException. Try(buf.createInputStream()).map { inputStream => - new WrappedInputStream(inputStream, this) + new BufferReleasingInputStream(inputStream, this) } } @@ -314,8 +314,10 @@ final class ShuffleBlockFetcherIterator( } } -// Helper class that ensures a ManagerBuffer is released upon InputStream.close() -private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator) +/** Helper class that ensures a ManagerBuffer is released upon InputStream.close() */ +private class BufferReleasingInputStream( + delegate: InputStream, + iterator: ShuffleBlockFetcherIterator) extends InputStream { private var closed = false diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 60e6840cb00bd..f7dc651e6d5d0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -115,7 +115,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) - val wrappedInputStream = new WrappedInputStream(mock(classOf[InputStream]), iterator) + val wrappedInputStream = new BufferReleasingInputStream(mock(classOf[InputStream]), iterator) verify(mockBuf, times(0)).release() wrappedInputStream.close() verify(mockBuf, times(1)).release() From 28f8085c769612af65e1d46bc11b73c22a42169c Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Tue, 9 Jun 2015 17:01:39 -0700 Subject: [PATCH 08/19] Small import nit --- .../scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 40e54ca0a3ab2..6d1649f7925f3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle.hash import org.apache.spark.{SparkEnv, TaskContext, InterruptibleIterator} import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{ShuffleReader, BaseShuffleHandle} +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter From 7eedd1daac8f748a7dd7d1e4b94758b5c980c956 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Wed, 10 Jun 2015 10:52:34 -0700 Subject: [PATCH 09/19] Small Scala import cleanup --- .../scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 6d1649f7925f3..e6df70dbdb9f2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{SparkEnv, TaskContext, InterruptibleIterator} +import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.util.CompletionIterator From 5c3040521358ea3c76cc20ca2de4966ff69845f0 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Wed, 10 Jun 2015 15:28:44 -0700 Subject: [PATCH 10/19] Return visibility of BlockStoreShuffleFetcher to private[hash] --- .../apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 16d206a6f8043..c7420bd9a4cdc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -private[shuffle] object BlockStoreShuffleFetcher extends Logging { +private[hash] object BlockStoreShuffleFetcher extends Logging { def fetchBlockStreams( shuffleId: Int, reduceId: Int, From 
4abb855fd22b560c8539b8cbbdb205349d7fb57c Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Wed, 10 Jun 2015 16:37:59 -0700 Subject: [PATCH 11/19] Consolidate metric code. Make it clear why InterrubtibleIterator is needed. There is also some Scala style cleanup in this commit. --- .../shuffle/hash/HashShuffleReader.scala | 31 +++++++++---------- .../ShuffleBlockFetcherIteratorSuite.scala | 25 ++++++--------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index e6df70dbdb9f2..9a60d94af0cc5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -50,42 +50,39 @@ private[spark] class HashShuffleReader[K, C]( val serializerInstance = ser.newInstance() // Create a key/value iterator for each stream - val recordIterator = wrappedStreams.flatMap { wrappedStream => + val recordIter = wrappedStreams.flatMap { wrappedStream => serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator } + // Update the context task metrics for each record read. val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - // Update read metrics for each record materialized - val metricIter = new InterruptibleIterator[(Any, Any)](context, recordIterator) { - override def next(): (Any, Any) = { + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { readMetrics.incRecordsRead(1) - delegate.next() - } - } + record + }), + context.taskMetrics().updateShuffleReadMetrics()) - val iter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](metricIter, { - context.taskMetrics().updateShuffleReadMetrics() - }) + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { // We are reading values that are already combined - val combinedKeyValuesIterator = iter.asInstanceOf[Iterator[(K, C)]] - new InterruptibleIterator(context, - dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)) + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { // We don't know the value type, but also don't care -- the dependency *should* // have made sure its compatible w/ this aggregator, which will convert the value // type to the combined type C - val keyValuesIterator = iter.asInstanceOf[Iterator[(K, Nothing)]] - new InterruptibleIterator(context, - dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)) + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) } // Sort the output if there is a sort ordering defined. 
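
For readers unfamiliar with the aggregator types referenced above: the two branches cast the interruptible iterator to Iterator[(K, C)] versus Iterator[(K, Nothing)] because, with map-side combine, the fetched values are already combiners of type C and only need merging, while without it they are raw values that must first be turned into combiners. A minimal, self-contained sketch of the reduce-side combine follows (the combineValues helper is an assumption for illustration, not Spark's Aggregator):

    import scala.collection.mutable

    // Reduce-side combine: fold each raw value into a per-key combiner,
    // roughly what Aggregator.combineValuesByKey does with a hash map.
    def combineValues[K, V, C](records: Iterator[(K, V)],
                               createCombiner: V => C,
                               mergeValue: (C, V) => C): Iterator[(K, C)] = {
      val combiners = mutable.HashMap.empty[K, C]
      records.foreach { case (k, v) =>
        combiners(k) = combiners.get(k) match {
          case Some(c) => mergeValue(c, v)
          case None    => createCombiner(v)
        }
      }
      combiners.iterator
    }

    // With map-side combine the same loop would merge already-built combiners
    // (mergeCombiners), which is why that branch expects Iterator[(K, C)].
    val wordCounts = combineValues(
      Iterator("a" -> 1, "b" -> 1, "a" -> 1),
      createCombiner = (v: Int) => v,
      mergeValue = (c: Int, v: Int) => c + v)
    // wordCounts.toMap == Map("a" -> 2, "b" -> 1)
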
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index f7dc651e6d5d0..89f6713946b4e 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -28,10 +28,10 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.{SparkFunSuite, TaskContextImpl} class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { @@ -61,11 +61,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Create a mock managed buffer for testing def createMockManagedBuffer(): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) - when(mockManagedBuffer.createInputStream()).thenAnswer(new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = { - mock(classOf[InputStream]) - } - }) + when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) mockManagedBuffer } @@ -76,9 +72,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure blockManager.getBlockData would return the blocks val localBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } @@ -86,9 +82,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer()) val transfer = createMockTransfer(remoteBlocks) @@ -109,13 +104,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, subIterator) = iterator.next() - assert(subIterator.isSuccess, + val (blockId, inputStream) = iterator.next() + assert(inputStream.isSuccess, s"iterator should have 5 elements defined but actually has $i elements") // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) - val wrappedInputStream = new BufferReleasingInputStream(mock(classOf[InputStream]), iterator) + val wrappedInputStream = new BufferReleasingInputStream(inputStream.get, iterator) verify(mockBuf, times(0)).release() wrappedInputStream.close() verify(mockBuf, times(1)).release() From f45848936d7df34d02162ed589da0d085112cb48 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Thu, 11 Jun 2015 09:59:47 -0700 Subject: [PATCH 12/19] Remove unnecessary map() on return Iterator --- .../org/apache/spark/shuffle/hash/HashShuffleReader.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 9a60d94af0cc5..23bae9223bb7c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -80,9 +80,7 @@ private[spark] class HashShuffleReader[K, C]( } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - - // Convert the Product2s to pairs since this is what downstream RDDs currently expect - interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. From 7429a985cef1c3530fb147f68eabf12aae613a4a Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Fri, 12 Jun 2015 12:13:36 -0700 Subject: [PATCH 13/19] Update tests to check that BufferReleasingStream is closing delegate InputStream --- .../org/apache/spark/shuffle/hash/HashShuffleReader.scala | 3 +++ .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 7 +++++-- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 6 +++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 23bae9223bb7c..ca6eddf8d5c12 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -51,6 +51,9 @@ private[spark] class HashShuffleReader[K, C]( // Create a key/value iterator for each stream val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a1376c8f4e484..78361f2df6d3a 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -314,9 +314,12 @@ final class ShuffleBlockFetcherIterator( } } -/** Helper class that ensures a ManagerBuffer is released upon InputStream.close() */ +/** + * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + * Note: the delegate parameter is private[storage] to make it available to tests. 
+ */ private class BufferReleasingInputStream( - delegate: InputStream, + private[storage] val delegate: InputStream, iterator: ShuffleBlockFetcherIterator) extends InputStream { private var closed = false diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 89f6713946b4e..4657caf332c5c 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -110,12 +110,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) - val wrappedInputStream = new BufferReleasingInputStream(inputStream.get, iterator) + // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream + val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() + verify(wrappedInputStream.delegate, times(0)).close() wrappedInputStream.close() verify(mockBuf, times(1)).release() + verify(wrappedInputStream.delegate, times(1)).close() wrappedInputStream.close() // close should be idempotent verify(mockBuf, times(1)).release() + verify(wrappedInputStream.delegate, times(1)).close() } // 3 local blocks, and 2 remote blocks From 4ea17129098696eaf59967e335eda923f3e8d341 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Fri, 12 Jun 2015 12:41:10 -0700 Subject: [PATCH 14/19] Small code cleanup for readability --- .../hash/BlockStoreShuffleFetcher.scala | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index c7420bd9a4cdc..0635b98742096 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -51,7 +51,16 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Try[InputStream])) : (BlockId, InputStream) = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + SparkEnv.get.blockManager.shuffleClient, + blockManager, + blocksByAddress, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler + blockFetcherItr.map { blockPair => val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { @@ -70,15 +79,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.shuffleClient, - blockManager, - blocksByAddress, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - - blockFetcherItr.map(unpackBlock) } } From a011bfabddc0ca6705a6d59a9112cd4216d0241c Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Thu, 18 Jun 2015 14:51:19 -0700 
Subject: [PATCH 15/19] Use PrivateMethodTester on check that delegate stream is closed --- .../spark/storage/ShuffleBlockFetcherIterator.scala | 6 +++--- .../storage/ShuffleBlockFetcherIteratorSuite.scala | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 78361f2df6d3a..6a9771777776d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -319,10 +319,10 @@ final class ShuffleBlockFetcherIterator( * Note: the delegate parameter is private[storage] to make it available to tests. */ private class BufferReleasingInputStream( - private[storage] val delegate: InputStream, - iterator: ShuffleBlockFetcherIterator) + private val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator) extends InputStream { - private var closed = false + private[this] var closed = false override def read(): Int = delegate.read() diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 4657caf332c5c..9ced4148d7206 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -27,6 +27,7 @@ import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ @@ -34,7 +35,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. 
@@ -113,13 +114,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() - verify(wrappedInputStream.delegate, times(0)).close() + val delegateAccess = PrivateMethod[InputStream]('delegate) + + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close() wrappedInputStream.close() verify(mockBuf, times(1)).release() - verify(wrappedInputStream.delegate, times(1)).close() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() wrappedInputStream.close() // close should be idempotent verify(mockBuf, times(1)).release() - verify(wrappedInputStream.delegate, times(1)).close() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } // 3 local blocks, and 2 remote blocks From f98a1b9503ed8fb4c5fc3e9033a744c254237c45 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Mon, 22 Jun 2015 15:00:28 -0700 Subject: [PATCH 16/19] Add test to ensure HashShuffleReader is freeing resources --- .../hash/BlockStoreShuffleFetcher.scala | 3 +- .../shuffle/hash/HashShuffleReader.scala | 8 +- .../hash/HashShuffleManagerSuite.scala | 114 +++++++++++++++++- 3 files changed, 115 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 0635b98742096..aefb2f5685537 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -26,7 +26,8 @@ import org.apache.spark._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -private[hash] object BlockStoreShuffleFetcher extends Logging { +private[hash] class BlockStoreShuffleFetcher extends Logging { + def fetchBlockStreams( shuffleId: Int, reduceId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index ca6eddf8d5c12..b868f32f5cce1 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.hash +import org.apache.spark.storage.BlockManager import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} @@ -27,18 +28,19 @@ private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") private val dep = handle.dependency - private val blockManager = SparkEnv.get.blockManager /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( + val 
blockStreams = blockStoreShuffleFetcher.fetchBlockStreams( handle.shuffleId, startPartition, context) // Wrap the streams for compression based on configuration diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 491dc3659e184..53b2b89a5e641 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -17,16 +17,22 @@ package org.apache.spark.shuffle.hash -import java.io.{File, FileWriter} +import java.io._ +import java.nio.ByteBuffer import scala.language.reflectiveCalls -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.executor.ShuffleWriteMetrics +import org.mockito.Matchers.any +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.FileShuffleBlockResolver -import org.apache.spark.storage.{ShuffleBlockId, FileSegment} +import org.apache.spark.serializer._ +import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver} +import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment} class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { private val testConf = new SparkConf(false) @@ -107,4 +113,100 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until numBytes) writer.write(i) writer.close() } + + test("HashShuffleReader.read() releases resources and tracks metrics") { + val shuffleId = 1 + val numMaps = 2 + val numKeyValuePairs = 10 + + val mockContext = mock(classOf[TaskContext]) + + val mockTaskMetrics = mock(classOf[TaskMetrics]) + val mockReadMetrics = mock(classOf[ShuffleReadMetrics]) + when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics) + when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics) + + val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher]) + + val mockDep = mock(classOf[ShuffleDependency[_, _, _]]) + when(mockDep.keyOrdering).thenReturn(None) + when(mockDep.aggregator).thenReturn(None) + when(mockDep.serializer).thenReturn(Some(new Serializer { + override def newInstance(): SerializerInstance = new SerializerInstance { + + override def deserializeStream(s: InputStream): DeserializationStream = + new DeserializationStream { + override def readObject[T: ClassManifest](): T = null.asInstanceOf[T] + + override def close(): Unit = s.close() + + private val values = { + for (i <- 0 to numKeyValuePairs * 2) yield i + }.iterator + + private def getValueOrEOF(): Int = { + if (values.hasNext) { + values.next() + } else { + throw new EOFException("End of the file: mock deserializeStream") + } + } + + // NOTE: the readKey and readValue methods are called by asKeyValueIterator() + // which is wrapped in a NextIterator + override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] + + override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] + } + + override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T = + 
null.asInstanceOf[T] + + override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0) + + override def serializeStream(s: OutputStream): SerializationStream = + null.asInstanceOf[SerializationStream] + + override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T] + } + })) + + val mockBlockManager = { + // Create a block manager that isn't configured for compression, just returns input stream + val blockManager = mock(classOf[BlockManager]) + when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]())) + .thenAnswer(new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = { + val blockId = invocation.getArguments()(0).asInstanceOf[BlockId] + val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream] + inputStream + } + }) + blockManager + } + + val mockInputStream = mock(classOf[InputStream]) + when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]())) + .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream))) + + val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep) + + val reader = new HashShuffleReader(shuffleHandle, 0, 1, + mockContext, mockBlockManager, mockShuffleFetcher) + + val values = reader.read() + // Verify that we're reading the correct values + var numValuesRead = 0 + for (((key: Int, value: Int), i) <- values.zipWithIndex) { + assert(key == i * 2) + assert(value == i * 2 + 1) + numValuesRead += 1 + } + // Verify that we read the correct number of values + assert(numKeyValuePairs == numValuesRead) + // Verify that our input stream was closed + verify(mockInputStream, times(1)).close() + // Verify that we collected metrics for each key/value pair + verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1) + } } From 5186da0697d5a1efe4229b7b0a224979ce7f2bc7 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 23 Jun 2015 11:34:58 -0700 Subject: [PATCH 17/19] Revert "Add test to ensure HashShuffleReader is freeing resources" This reverts commit f98a1b9503ed8fb4c5fc3e9033a744c254237c45. 
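This is a mechanical revert: the diff below is the exact inverse of PATCH 16, and PATCH 18 then restores the coverage by injecting a BlockManager and MapOutputTracker into HashShuffleReader rather than turning BlockStoreShuffleFetcher into a class. A commit whose body reads "This reverts commit <sha>." is the default output of git revert, for example:

    git revert f98a1b9503ed8fb4c5fc3e9033a744c254237c45

which applies the inverse patch and pre-fills a message of exactly this shape.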
--- .../hash/BlockStoreShuffleFetcher.scala | 3 +- .../shuffle/hash/HashShuffleReader.scala | 8 +- .../hash/HashShuffleManagerSuite.scala | 114 +----------------- 3 files changed, 10 insertions(+), 115 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index aefb2f5685537..0635b98742096 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -26,8 +26,7 @@ import org.apache.spark._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -private[hash] class BlockStoreShuffleFetcher extends Logging { - +private[hash] object BlockStoreShuffleFetcher extends Logging { def fetchBlockStreams( shuffleId: Int, reduceId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index b868f32f5cce1..ca6eddf8d5c12 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,7 +17,6 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.storage.BlockManager import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} @@ -28,19 +27,18 @@ private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext, - blockManager: BlockManager = SparkEnv.get.blockManager, - blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher) + context: TaskContext) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") private val dep = handle.dependency + private val blockManager = SparkEnv.get.blockManager /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val blockStreams = blockStoreShuffleFetcher.fetchBlockStreams( + val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( handle.shuffleId, startPartition, context) // Wrap the streams for compression based on configuration diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 53b2b89a5e641..491dc3659e184 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -17,22 +17,16 @@ package org.apache.spark.shuffle.hash -import java.io._ -import java.nio.ByteBuffer +import java.io.{File, FileWriter} import scala.language.reflectiveCalls -import org.mockito.Matchers.any -import org.mockito.Mockito._ -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer - -import org.apache.spark._ -import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.executor.ShuffleWriteMetrics import 
org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.serializer._ -import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver} -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.FileShuffleBlockResolver +import org.apache.spark.storage.{ShuffleBlockId, FileSegment} class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { private val testConf = new SparkConf(false) @@ -113,100 +107,4 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until numBytes) writer.write(i) writer.close() } - - test("HashShuffleReader.read() releases resources and tracks metrics") { - val shuffleId = 1 - val numMaps = 2 - val numKeyValuePairs = 10 - - val mockContext = mock(classOf[TaskContext]) - - val mockTaskMetrics = mock(classOf[TaskMetrics]) - val mockReadMetrics = mock(classOf[ShuffleReadMetrics]) - when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics) - when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics) - - val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher]) - - val mockDep = mock(classOf[ShuffleDependency[_, _, _]]) - when(mockDep.keyOrdering).thenReturn(None) - when(mockDep.aggregator).thenReturn(None) - when(mockDep.serializer).thenReturn(Some(new Serializer { - override def newInstance(): SerializerInstance = new SerializerInstance { - - override def deserializeStream(s: InputStream): DeserializationStream = - new DeserializationStream { - override def readObject[T: ClassManifest](): T = null.asInstanceOf[T] - - override def close(): Unit = s.close() - - private val values = { - for (i <- 0 to numKeyValuePairs * 2) yield i - }.iterator - - private def getValueOrEOF(): Int = { - if (values.hasNext) { - values.next() - } else { - throw new EOFException("End of the file: mock deserializeStream") - } - } - - // NOTE: the readKey and readValue methods are called by asKeyValueIterator() - // which is wrapped in a NextIterator - override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] - - override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] - } - - override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T = - null.asInstanceOf[T] - - override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0) - - override def serializeStream(s: OutputStream): SerializationStream = - null.asInstanceOf[SerializationStream] - - override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T] - } - })) - - val mockBlockManager = { - // Create a block manager that isn't configured for compression, just returns input stream - val blockManager = mock(classOf[BlockManager]) - when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]())) - .thenAnswer(new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = { - val blockId = invocation.getArguments()(0).asInstanceOf[BlockId] - val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream] - inputStream - } - }) - blockManager - } - - val mockInputStream = mock(classOf[InputStream]) - when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]())) - .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream))) - - val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, 
mockDep) - - val reader = new HashShuffleReader(shuffleHandle, 0, 1, - mockContext, mockBlockManager, mockShuffleFetcher) - - val values = reader.read() - // Verify that we're reading the correct values - var numValuesRead = 0 - for (((key: Int, value: Int), i) <- values.zipWithIndex) { - assert(key == i * 2) - assert(value == i * 2 + 1) - numValuesRead += 1 - } - // Verify that we read the correct number of values - assert(numKeyValuePairs == numValuesRead) - // Verify that our input stream was closed - verify(mockInputStream, times(1)).close() - // Verify that we collected metrics for each key/value pair - verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1) - } } From 290f1eb356024fb58a209e9fc6c8800bfc0e6688 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 23 Jun 2015 15:41:06 -0700 Subject: [PATCH 18/19] Added test for HashShuffleReader.read() --- .../hash/BlockStoreShuffleFetcher.scala | 14 +- .../shuffle/hash/HashShuffleReader.scala | 10 +- .../shuffle/hash/HashShuffleReaderSuite.scala | 150 ++++++++++++++++++ 3 files changed, 164 insertions(+), 10 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 0635b98742096..9d8e7e9f03aea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -20,24 +20,26 @@ package org.apache.spark.shuffle.hash import java.io.InputStream import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.{Failure, Success, Try} +import scala.util.{Failure, Success} import org.apache.spark._ import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, + ShuffleBlockId} private[hash] object BlockStoreShuffleFetcher extends Logging { def fetchBlockStreams( shuffleId: Int, reduceId: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) : Iterator[(BlockId, InputStream)] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) @@ -53,7 +55,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { val blockFetcherItr = new ShuffleBlockFetcherIterator( context, - SparkEnv.get.blockManager.shuffleClient, + blockManager.shuffleClient, blockManager, blocksByAddress, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index ca6eddf8d5c12..d5c9880659dd3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ 
b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,9 +17,10 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext} +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.BlockManager import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -27,19 +28,20 @@ private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") private val dep = handle.dependency - private val blockManager = SparkEnv.get.blockManager /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( - handle.shuffleId, startPartition, context) + handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) // Wrap the streams for compression based on configuration val wrappedStreams = blockStreams.map { case (blockId, inputStream) => diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala new file mode 100644 index 0000000000000..0add85c6377dc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import java.io.{ByteArrayOutputStream, InputStream} +import java.nio.ByteBuffer + +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} + +/** + * Wrapper for a managed buffer that keeps track of how many times retain and release are called. + * + * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class + * is final (final classes cannot be spied on). 
+ */ +class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer { + var callsToRetain = 0 + var callsToRelease = 0 + + override def size() = underlyingBuffer.size() + override def nioByteBuffer() = underlyingBuffer.nioByteBuffer() + override def createInputStream() = underlyingBuffer.createInputStream() + override def convertToNetty() = underlyingBuffer.convertToNetty() + + override def retain(): ManagedBuffer = { + callsToRetain += 1 + underlyingBuffer.retain() + } + override def release(): ManagedBuffer = { + callsToRelease += 1 + underlyingBuffer.release() + } +} + +class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + + /** + * This test makes sure that, when data is read from a HashShuffleReader, the underlying + * ManagedBuffers that contain the data are eventually released. + */ + test("read() releases resources on completion") { + val testConf = new SparkConf(false) + // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the + // shuffle code calls SparkEnv.get()). + sc = new SparkContext("local", "test", testConf) + + val reduceId = 15 + val shuffleId = 22 + val numMaps = 6 + val keyValuePairsPerMap = 10 + val serializer = new JavaSerializer(testConf) + + // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we + // can ensure retain() and release() are properly called. + val blockManager = mock(classOf[BlockManager]) + + // Create a return function to use for the mocked wrapForCompression method that just returns + // the original input stream. + val dummyCompressionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock) = + invocation.getArguments()(1).asInstanceOf[InputStream] + } + + // Create a buffer with some randomly generated key-value pairs to use as the shuffle data + // from each mappers (all mappers return the same shuffle data). + val byteOutputStream = new ByteArrayOutputStream() + val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + (0 until keyValuePairsPerMap).foreach { i => + serializationStream.writeKey(i) + serializationStream.writeValue(2*i) + } + + // Setup the mocked BlockManager to return RecordingManagedBuffers. + val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) + when(blockManager.blockManagerId).thenReturn(localBlockManagerId) + val buffers = (0 until numMaps).map { mapId => + // Create a ManagedBuffer with the shuffle data. + val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray)) + val managedBuffer = new RecordingManagedBuffer(nioBuffer) + + // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to + // fetch shuffle data. + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) + .thenAnswer(dummyCompressionFunction) + + managedBuffer + } + + // Make a mocked MapOutputTracker for the shuffle reader to use to determine what + // shuffle data to read. + val mapOutputTracker = mock(classOf[MapOutputTracker]) + // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks + // for the code to read data over the network. 
+ val statuses: Array[(BlockManagerId, Long)] = + Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size())) + when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) + + // Create a mocked shuffle handle to pass into HashShuffleReader. + val shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.serializer).thenReturn(Some(serializer)) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + val shuffleReader = new HashShuffleReader( + shuffleHandle, + reduceId, + reduceId + 1, + new TaskContextImpl(0, 0, 0, 0, null), + blockManager, + mapOutputTracker) + + assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) + + // Calling .length above will have exhausted the iterator; make sure that exhausting the + // iterator caused retain and release to be called on each buffer. + buffers.foreach { buffer => + assert(buffer.callsToRetain === 1) + assert(buffer.callsToRelease === 1) + } + } +} From 8b0632ca78492f80e26d4b3493296b3b04b55866 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Tue, 23 Jun 2015 16:55:11 -0700 Subject: [PATCH 19/19] Minor Scala style fixes --- .../spark/shuffle/hash/HashShuffleReaderSuite.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala index 0add85c6377dc..28ca68698e3dc 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -41,10 +41,10 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed var callsToRetain = 0 var callsToRelease = 0 - override def size() = underlyingBuffer.size() - override def nioByteBuffer() = underlyingBuffer.nioByteBuffer() - override def createInputStream() = underlyingBuffer.createInputStream() - override def convertToNetty() = underlyingBuffer.convertToNetty() + override def size(): Long = underlyingBuffer.size() + override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() + override def createInputStream(): InputStream = underlyingBuffer.createInputStream() + override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty() override def retain(): ManagedBuffer = { callsToRetain += 1 @@ -81,7 +81,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { // Create a return function to use for the mocked wrapForCompression method that just returns // the original input stream. val dummyCompressionFunction = new Answer[InputStream] { - override def answer(invocation: InvocationOnMock) = + override def answer(invocation: InvocationOnMock): InputStream = invocation.getArguments()(1).asInstanceOf[InputStream] } @@ -118,7 +118,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks // for the code to read data over the network. 
val statuses: Array[(BlockManagerId, Long)] = - Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size())) + Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong)) when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) // Create a mocked shuffle handle to pass into HashShuffleReader.
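The new HashShuffleReaderSuite builds its shuffle blocks by writing key/value pairs with JavaSerializer's serializeStream, and HashShuffleReader.read() consumes them again through deserializeStream(...).asKeyValueIterator. That round trip, stripped of all the mocks, looks roughly like the sketch below (assumes spark-core of this era on the classpath; SerializerRoundTripSketch and the pair count are made up for illustration):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.JavaSerializer

    object SerializerRoundTripSketch {
      def main(args: Array[String]): Unit = {
        val ser = new JavaSerializer(new SparkConf(false)).newInstance()

        // Write ten (i, 2 * i) pairs, the same way the suite builds its shuffle data.
        val out = new ByteArrayOutputStream()
        val serStream = ser.serializeStream(out)
        (0 until 10).foreach { i =>
          serStream.writeKey(i)
          serStream.writeValue(2 * i)
        }
        serStream.close()

        // Read them back the way HashShuffleReader.read() does: asKeyValueIterator keeps
        // yielding (key, value) pairs until the underlying stream is exhausted.
        val pairs = ser.deserializeStream(new ByteArrayInputStream(out.toByteArray))
          .asKeyValueIterator
          .toList
        assert(pairs.length == 10)
        assert(pairs.head == (0, 0))
      }
    }

Exhausting that iterator is what drives the retain and release counts the suite asserts on at the end of the test.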