@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash
 
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
 
@@ -38,7 +39,7 @@ private[spark] class HashShuffleReader[K, C](
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
     val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
-      handle.shuffleId, startPartition, context)
+      handle.shuffleId, startPartition, context)
 
     // Wrap the streams for compression based on configuration
     val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
@@ -50,7 +51,11 @@ private[spark] class HashShuffleReader[K, C](
 
     // Create a key/value iterator for each stream
     val recordIterator = wrappedStreams.flatMap { wrappedStream =>
-      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+      val kvIter = serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+      CompletionIterator[(Any, Any), Iterator[(Any, Any)]](kvIter, {
+        // Close the stream once all the records have been read from it
+        wrappedStream.close()
+      })
     }
 
     // Update read metrics for each record materialized
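For context on the pattern introduced above: CompletionIterator delegates to an underlying iterator and runs a callback once that iterator is exhausted, which is what lets this change tie closing each wrapped stream to the consumption of its last record. Below is a minimal sketch of that pattern, not part of the patch; it assumes the code is compiled within the org.apache.spark package scope (the class is private[spark]), and the object name, stream, and data are hypothetical.

// Hypothetical illustration of the CompletionIterator pattern used in the patch.
package org.apache.spark.example

import java.io.ByteArrayInputStream

import org.apache.spark.util.CompletionIterator

object CompletionIteratorSketch {
  def main(args: Array[String]): Unit = {
    // A stand-in for a fetched shuffle block stream.
    val stream = new ByteArrayInputStream(Array[Byte](1, 2, 3))
    // A stand-in for the key/value records deserialized from that stream.
    val records = Iterator(1, 2, 3)

    // Wrap the record iterator so the stream is closed once all records
    // have been read from it, mirroring the flatMap body in the diff.
    val iter = CompletionIterator[Int, Iterator[Int]](records, {
      stream.close()
    })

    iter.foreach(println)  // prints 1, 2, 3; the stream is closed after the last element
  }
}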