@@ -26,9 +26,8 @@ import scala.collection.mutable.ArrayBuffer
 import it.unimi.dsi.fastutil.io.FastBufferedInputStream
 
 import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.io.LZFCompressionCodec
-import org.apache.spark.serializer.{KryoDeserializationStream, Serializer}
-import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockObjectWriter}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{BlockId, BlockManager}
 
 /**
  * An append-only map that spills sorted content to disk when there is insufficient space for it
@@ -84,12 +83,14 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
   private val trackMemoryThreshold = 1000
 
-  // Size of object batches when reading/writing from serializers. Objects are written in
-  // batches, with each batch using its own serialization stream. This cuts down on the size
-  // of reference-tracking maps constructed when deserializing a stream.
-  //
-  // NOTE: Setting this too low can cause excess copying when serializing, since some serializers
-  // grow internal data structures by growing + copying every time the number of objects doubles.
+  /* Size of object batches when reading/writing from serializers.
+   *
+   * Objects are written in batches, with each batch using its own serialization stream. This
+   * cuts down on the size of reference-tracking maps constructed when deserializing a stream.
+   *
+   * NOTE: Setting this too low can cause excess copying when serializing, since some serializers
+   * grow internal data structures by growing + copying every time the number of objects doubles.
+   */
   private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
 
   // How many times we have spilled so far
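
For intuition, here is a minimal, self-contained sketch of the batching scheme the new comment describes, using plain Java serialization rather than Spark's Serializer, compression, or disk writer; demoBatchSize and the sample data are made up for illustration, not taken from the patch.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer

object BatchWriteSketch {
  val demoBatchSize = 3  // stands in for spark.shuffle.spill.batchSize

  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    val batchSizes = new ArrayBuffer[Long]()   // byte length of each batch, in write order
    var objectsWritten = 0
    var batchStart = 0L
    var oos = new ObjectOutputStream(out)      // each batch gets its own serialization stream

    for (pair <- (1 to 10).map(i => (i.toString, i))) {
      oos.writeObject(pair)
      objectsWritten += 1
      if (objectsWritten == demoBatchSize) {
        oos.close()                            // finish this batch
        batchSizes += out.size() - batchStart  // record how many bytes it took
        batchStart = out.size()
        objectsWritten = 0
        oos = new ObjectOutputStream(out)      // fresh stream, so reference tracking resets
      }
    }
    if (objectsWritten > 0) {                  // flush the final partial batch
      oos.close()
      batchSizes += out.size() - batchStart
    }
    println(s"batch sizes in bytes: ${batchSizes.mkString(", ")}")
  }
}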
@@ -100,7 +101,6 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   private var _diskBytesSpilled = 0L
 
   private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
-  private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false)
   private val comparator = new KCComparator[K, C]
   private val ser = serializer.newInstance()
 
@@ -153,37 +153,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
       .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
     val (blockId, file) = diskBlockManager.createTempBlock()
+    var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+    var objectsWritten = 0
 
-    /* IMPORTANT NOTE: To avoid having to keep large object graphs in memory, this approach
-     * closes and re-opens serialization and compression streams within each file. This makes some
-     * assumptions about the way that serialization and compression streams work, specifically:
-     *
-     * 1) The serializer input streams do not pre-fetch data from the underlying stream.
-     *
-     * 2) Several compression streams can be opened, written to, and flushed on the write path
-     *    while only one compression input stream is created on the read path
-     *
-     * In practice (1) is only true for Java, so we add a special fix below to make it work for
-     * Kryo. (2) is only true for LZF and not Snappy, so we coerce this to use LZF.
-     *
-     * To avoid making these assumptions we should create an intermediate stream that batches
-     * objects and sends an EOF to the higher layer streams to make sure they never prefetch data.
-     * This is a bit tricky because, within each segment, you'd need to track the total number
-     * of bytes written and then re-wind and write it at the beginning of the segment. This will
-     * most likely require using the file channel API.
-     */
+    // List of batch sizes (bytes) in the order they are written to disk
+    val batchSizes = new ArrayBuffer[Long]
 
-    val shouldCompress = blockManager.shouldCompress(blockId)
-    val compressionCodec = new LZFCompressionCodec(sparkConf)
-    def wrapForCompression(outputStream: OutputStream) = {
-      if (shouldCompress) compressionCodec.compressedOutputStream(outputStream) else outputStream
+    // Flush the disk writer's contents to disk, and update relevant variables
+    def flush() = {
+      writer.commit()
+      val bytesWritten = writer.bytesWritten
+      batchSizes.append(bytesWritten)
+      _diskBytesSpilled += bytesWritten
+      objectsWritten = 0
     }
 
-    def getNewWriter = new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize,
-      wrapForCompression, syncWrites)
-
-    var writer = getNewWriter
-    var objectsWritten = 0
     try {
       val it = currentMap.destructiveSortedIterator(comparator)
       while (it.hasNext) {
@@ -192,22 +176,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
         objectsWritten += 1
 
         if (objectsWritten == serializerBatchSize) {
-          writer.commit()
+          flush()
           writer.close()
-          _diskBytesSpilled += writer.bytesWritten
-          writer = getNewWriter
-          objectsWritten = 0
+          writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
         }
       }
-
-      if (objectsWritten > 0) writer.commit()
+      if (objectsWritten > 0) {
+        flush()
+      }
     } finally {
       // Partial failures cannot be tolerated; do not revert partial writes
       writer.close()
-      _diskBytesSpilled += writer.bytesWritten
     }
+
     currentMap = new SizeTrackingAppendOnlyMap[K, C]
-    spilledMaps.append(new DiskMapIterator(file, blockId))
+    spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
 
     // Reset the amount of shuffle memory used by this map in the global pool
     val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
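
The batchSizes recorded by flush() above are enough to locate every batch in the spill file. As a hypothetical illustration (this helper is not part of the patch), the cumulative sums of those sizes give each batch's starting byte offset:

// Hypothetical helper, not from the patch: the start offset of each batch is the
// running total of all batch sizes written before it.
def batchOffsets(batchSizes: Seq[Long]): Seq[Long] =
  batchSizes.scanLeft(0L)(_ + _).init

// For example, batchSizes of Seq(120L, 95L, 130L) yields offsets Seq(0, 120, 215).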
@@ -252,8 +235,9 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     }
 
     /**
-     * Fetch from the given iterator until a key of different hash is retrieved. In the
-     * event of key hash collisions, this ensures no pairs are hidden from being merged.
+     * Fetch from the given iterator until a key of different hash is retrieved.
+     *
+     * In the event of key hash collisions, this ensures no pairs are hidden from being merged.
      * Assume the given iterator is in sorted order.
      */
     def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
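
To make the doc comment above concrete, here is a standalone sketch (hypothetical name and signature, not the patch's getMorePairs) of reading from a hash-sorted iterator until the key hash changes, so colliding keys are never split across reads:

// Illustrative only: collect every pair whose key shares the first pair's hash,
// assuming the iterator yields pairs in key-hash order as the doc comment states.
def pairsWithSameLeadingHash[K, C](it: BufferedIterator[(K, C)]): Vector[(K, C)] = {
  if (!it.hasNext) return Vector.empty
  val firstHash = it.head._1.hashCode()
  val buf = Vector.newBuilder[(K, C)]
  while (it.hasNext && it.head._1.hashCode() == firstHash) {
    buf += it.next()
  }
  buf.result()
}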
@@ -293,7 +277,8 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty)
 
     /**
-     * Select a key with the minimum hash, then combine all values with the same key from all input streams.
+     * Select a key with the minimum hash, then combine all values with the same key from all
+     * input streams
      */
     override def next(): (K, C) = {
       // Select a key from the StreamBuffer that holds the lowest key hash
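
As a small illustration of the selection step (names are made up; the real code works over a mergeHeap of StreamBuffers), picking the buffer whose head pair has the smallest key hash can be sketched as:

// Illustrative only: assumes every buffer is hash-sorted and at least one is
// non-empty, as hasNext guarantees before next() runs.
def bufferWithMinHash[K, C](buffers: Seq[Seq[(K, C)]]): Seq[(K, C)] =
  buffers.filter(_.nonEmpty).minBy(_.head._1.hashCode())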
@@ -355,51 +340,66 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   /**
    * An iterator that returns (K, C) pairs in sorted order from an on-disk map
    */
-  private class DiskMapIterator(file: File, blockId: BlockId) extends Iterator[(K, C)] {
+  private class DiskMapIterator(file: File,
+                                blockId: BlockId,
+                                batchSizes: ArrayBuffer[Long]) extends Iterator[(K, C)] {
     val fileStream = new FileInputStream(file)
     val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize)
 
-    val shouldCompress = blockManager.shouldCompress(blockId)
-    val compressionCodec = new LZFCompressionCodec(sparkConf)
-    val compressedStream =
-      if (shouldCompress) {
-        compressionCodec.compressedInputStream(bufferedStream)
-      } else {
-        bufferedStream
-      }
-    var deserializeStream = ser.deserializeStream(compressedStream)
-    var objectsRead = 0
+    // An intermediate stream that holds all the bytes from exactly one batch
+    // This guards against pre-fetching and other arbitrary behavior of higher level streams
+    var batchStream = nextBatchStream(bufferedStream)
 
+    var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
+    var deserializeStream = ser.deserializeStream(compressedStream)
     var nextItem: (K, C) = null
     var eof = false
 
+    /**
+     * Construct a stream that contains all the bytes from the next batch
+     */
+    def nextBatchStream(stream: InputStream): ByteArrayInputStream = {
+      var batchBytes = Array[Byte]()
+      if (batchSizes.length > 0) {
+        val batchSize = batchSizes.remove(0)
+
+        // Read batchSize number of bytes into batchBytes
+        while (batchBytes.length < batchSize) {
+          val numBytesToRead = Math.min(8192, batchSize - batchBytes.length).toInt
+          val bytesRead = new Array[Byte](numBytesToRead)
+          stream.read(bytesRead, 0, numBytesToRead)
+          batchBytes ++= bytesRead
+        }
+      } else {
+        // No more batches left
+        eof = true
+      }
+      new ByteArrayInputStream(batchBytes)
+    }
+
+    /**
+     * Return the next (K, C) pair from the deserialization stream.
+     *
+     * If the underlying batch stream is drained, construct a new stream for the next batch
+     * (if there is one) and stream from it. If there are no more batches left, return null.
+     */
     def readNextItem(): (K, C) = {
-      if (!eof) {
-        try {
-          if (objectsRead == serializerBatchSize) {
-            val newInputStream = deserializeStream match {
-              case stream: KryoDeserializationStream =>
-                // Kryo's serializer stores an internal buffer that pre-fetches from the underlying
-                // stream. We need to capture this buffer and feed it to the new serialization
-                // stream so that the bytes are not lost.
-                val kryoInput = stream.input
-                val remainingBytes = kryoInput.limit() - kryoInput.position()
-                val extraBuf = kryoInput.readBytes(remainingBytes)
-                new SequenceInputStream(new ByteArrayInputStream(extraBuf), compressedStream)
-              case _ => compressedStream
-            }
-            deserializeStream = ser.deserializeStream(newInputStream)
-            objectsRead = 0
-          }
-          objectsRead += 1
-          return deserializeStream.readObject().asInstanceOf[(K, C)]
-        } catch {
-          case e: EOFException =>
-            eof = true
+      try {
+        deserializeStream.readObject().asInstanceOf[(K, C)]
+      } catch {
+        // End of current batch
+        case e: EOFException =>
+          batchStream = nextBatchStream(bufferedStream)
+          if (!eof) {
+            compressedStream = blockManager.wrapForCompression(blockId, batchStream)
+            deserializeStream = ser.deserializeStream(compressedStream)
+            readNextItem()
+          } else {
+            // No more batches left
             cleanup()
-        }
+            null
+          }
       }
-      null
     }
 
     override def hasNext: Boolean = {
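
One subtlety the new read path depends on: InputStream.read may return fewer bytes than requested, so rebuilding a batch of a known size generally means looping until the full count has been consumed. A hedged, standalone sketch (readFully is a hypothetical helper, not code from the patch):

import java.io.{EOFException, InputStream}

// Illustrative helper: read exactly `size` bytes, looping because a single
// InputStream.read call may return fewer bytes than asked for.
def readFully(in: InputStream, size: Int): Array[Byte] = {
  val buf = new Array[Byte](size)
  var off = 0
  while (off < size) {
    val n = in.read(buf, off, size - off)
    if (n == -1) throw new EOFException(s"expected $size bytes, got only $off")
    off += n
  }
  buf
}

// A batch read this way can then be wrapped in a java.io.ByteArrayInputStream,
// mirroring how nextBatchStream hands one batch at a time to the higher-level streams.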