 
 package org.apache.spark.storage
 
+import java.io.InputStream
 import java.util.concurrent.LinkedBlockingQueue
 
-import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Try}
 
-import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, TaskContext}
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
  * manager. For remote blocks, it fetches them using the provided BlockTransferService.
  *
- * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
- * pipelined fashion as they are received.
+ * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
  *
  * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
  * using too much memory.
@@ -44,7 +45,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
- * @param serializer serializer used to deserialize the data.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  */
 private[spark]
@@ -53,9 +53,8 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-    serializer: Serializer,
     maxBytesInFlight: Long)
-  extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
+  extends Iterator[(BlockId, Try[InputStream])] with Logging {
 
   import ShuffleBlockFetcherIterator._
 
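With the serializer parameter removed, decompression and deserialization move out of this iterator and into its caller. A minimal sketch (not part of this patch) of what a consumer such as a shuffle reader might now do with the returned streams, reusing the wrapForCompression and deserializeStream calls that the removed code later in this diff performed internally; the val names are illustrative:

    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context, shuffleClient, blockManager, blocksByAddress, maxBytesInFlight)

    val recordIter = blockFetcherItr.flatMap { case (blockId, streamTry) =>
      // streamTry.get rethrows a failed fetch, so the scheduler still sees the
      // FetchFailedException discussed in the SPARK-4085 comment below.
      val wrapped = blockManager.wrapForCompression(blockId, streamTry.get)
      serializer.newInstance().deserializeStream(wrapped).asKeyValueIterator
    }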
@@ -79,11 +78,11 @@ final class ShuffleBlockFetcherIterator(
   private[this] val localBlocks = new ArrayBuffer[BlockId]()
 
   /** Remote blocks to fetch, excluding zero-sized blocks. */
-  private[this] val remoteBlocks = new HashSet[BlockId]()
+  private[this] val remoteBlocks = new mutable.HashSet[BlockId]()
 
   /**
    * A queue to hold our results. This turns the asynchronous model provided by
-   * [[BlockTransferService]] into a synchronous model (iterator).
+   * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
    */
   private[this] val results = new LinkedBlockingQueue[FetchResult]
 
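The results queue is what turns the push-style fetch callbacks into a pull-style iterator. A simplified sketch of that pattern, assuming the BlockFetchingListener callback shapes from org.apache.spark.network.shuffle and the FetchResult constructors matched elsewhere in this file (host, port, execId and blockIds stand in for the real request wiring):

    // Producer side: network callbacks enqueue a result as each block arrives.
    shuffleClient.fetchBlocks(host, port, execId, blockIds,
      new BlockFetchingListener {
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          buf.retain()  // keep the buffer alive until the consumer releases it
          results.put(SuccessFetchResult(BlockId(blockId), buf.size, buf))
        }
        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          results.put(FailureFetchResult(BlockId(blockId), e))
        }
      })

    // Consumer side: next() simply blocks until a result is available.
    val result = results.take()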
@@ -97,14 +96,12 @@ final class ShuffleBlockFetcherIterator(
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
    * the number of bytes in flight is limited to maxBytesInFlight.
    */
-  private[this] val fetchRequests = new Queue[FetchRequest]
+  private[this] val fetchRequests = new mutable.Queue[FetchRequest]
 
   /** Current bytes in flight from our requests */
   private[this] var bytesInFlight = 0L
 
-  private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-
-  private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
+  private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
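For context, the throttling that fetchRequests and bytesInFlight exist for looks roughly like this (a sketch of logic living elsewhere in this file, assuming each FetchRequest exposes the summed size of its blocks as size):

    // Issue queued requests only while the in-flight total stays under the cap;
    // bytesInFlight grows in sendRequest and shrinks as results are consumed.
    while (fetchRequests.nonEmpty &&
        (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
      sendRequest(fetchRequests.dequeue())
    }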
@@ -114,17 +111,23 @@ final class ShuffleBlockFetcherIterator(
 
   initialize()
 
-  /**
-   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
-   */
-  private[this] def cleanup() {
-    isZombie = true
+  // Decrements the buffer reference count.
+  // The currentResult is set to null to prevent releasing the buffer again on cleanup()
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
     currentResult match {
       case SuccessFetchResult(_, _, buf) => buf.release()
       case _ =>
     }
+    currentResult = null
+  }
 
+  /**
+   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+   */
+  private[this] def cleanup() {
+    isZombie = true
+    releaseCurrentResultBuffer()
     // Release buffers in the results queue
     val iter = results.iterator()
     while (iter.hasNext) {
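The net effect is a single-release discipline per ManagedBuffer: whichever of close() or cleanup() runs first releases the buffer, and nulling currentResult keeps the other path from releasing it again. A sketch of the resulting caller-side contract (names illustrative):

    val (blockId, streamTry) = iterator.next()
    streamTry.foreach { stream =>
      try {
        // ... consume the block's bytes ...
      } finally {
        stream.close()  // triggers releaseCurrentResultBuffer() exactly once
      }
    }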
@@ -272,7 +275,7 @@ final class ShuffleBlockFetcherIterator(
 
   override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
 
-  override def next(): (BlockId, Try[Iterator[Any]]) = {
+  override def next(): (BlockId, Try[InputStream]) = {
     numBlocksProcessed += 1
     val startFetchWait = System.currentTimeMillis()
     currentResult = results.take()
@@ -290,29 +293,51 @@ final class ShuffleBlockFetcherIterator(
       sendRequest(fetchRequests.dequeue())
     }
 
-    val iteratorTry: Try[Iterator[Any]] = result match {
+    val iteratorTry: Try[InputStream] = result match {
       case FailureFetchResult(_, e) =>
         Failure(e)
       case SuccessFetchResult(blockId, _, buf) =>
         // There is a chance that createInputStream can fail (e.g. fetching a local file that does
         // not exist, SPARK-4085). In that case, we should propagate the right exception so
         // the scheduler gets a FetchFailedException.
-        Try(buf.createInputStream()).map { is0 =>
-          val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
-          CompletionIterator[Any, Iterator[Any]](iter, {
-            // Once the iterator is exhausted, release the buffer and set currentResult to null
-            // so we don't release it again in cleanup.
-            currentResult = null
-            buf.release()
-          })
+        Try(buf.createInputStream()).map { inputStream =>
+          new WrappedInputStream(inputStream, this)
         }
     }
 
     (result.blockId, iteratorTry)
   }
 }
 
+// Helper class that ensures a ManagedBuffer is released upon InputStream.close()
+private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator)
+  extends InputStream {
+  private var closed = false
+
+  override def read(): Int = delegate.read()
+
+  override def close(): Unit = {
+    if (!closed) {
+      delegate.close()
+      iterator.releaseCurrentResultBuffer()
+      closed = true
+    }
+  }
+
+  override def available(): Int = delegate.available()
+
+  override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+  override def skip(n: Long): Long = delegate.skip(n)
+
+  override def markSupported(): Boolean = delegate.markSupported()
+
+  override def read(b: Array[Byte]): Int = delegate.read(b)
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+
+  override def reset(): Unit = delegate.reset()
+}
 
 private[storage]
 object ShuffleBlockFetcherIterator {
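A note on the delegation boilerplate in WrappedInputStream: java.io.FilterInputStream already forwards every InputStream method to a wrapped stream, so a hypothetical equivalent would only need to override close():

    import java.io.{FilterInputStream, InputStream}

    // Sketch of an alternative, not part of the patch: FilterInputStream
    // supplies the read/skip/mark/reset forwarding written out above.
    private class ReleasingInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator)
      extends FilterInputStream(delegate) {
      private var closed = false

      override def close(): Unit = {
        if (!closed) {
          super.close()                          // closes the delegate
          iterator.releaseCurrentResultBuffer()  // then drops the buffer refcount
          closed = true
        }
      }
    }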