Commit 164489d

Relax assumptions on compressors and serializers when batching
This commit introduces an intermediate input-stream layer at the batch level. It guards against interference from higher-level streams (i.e. compression and deserialization streams), especially pre-fetching, without specifically targeting particular libraries (Kryo) or forcing shuffle spill compression to use LZF.
1 parent 0386f42 commit 164489d
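
At its core, the new approach buffers each batch's bytes before any higher-level stream sees them. The snippet below is a minimal standalone sketch of that idea; the helper name boundedBatchStream and the 8 KB copy buffer are illustrative assumptions, not code from this commit. Because the compression and deserialization streams are layered on top of an in-memory stream containing exactly one batch, they can pre-fetch as aggressively as they like without ever consuming bytes that belong to the next batch.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}

// Hypothetical helper (not code from this commit): expose exactly `batchSize` bytes of
// `in` as a standalone in-memory stream. Streams layered on top of the returned stream
// can pre-fetch freely, but they can never consume bytes from the next batch.
object BatchBoundedStreams {
  def boundedBatchStream(in: InputStream, batchSize: Long): ByteArrayInputStream = {
    val out = new ByteArrayOutputStream()
    val buffer = new Array[Byte](8192)
    var remaining = batchSize
    while (remaining > 0) {
      val read = in.read(buffer, 0, math.min(buffer.length.toLong, remaining).toInt)
      if (read == -1) {
        remaining = 0                // underlying stream ended early; stop copying
      } else {
        out.write(buffer, 0, read)
        remaining -= read
      }
    }
    new ByteArrayInputStream(out.toByteArray)
  }
}

DiskMapIterator.nextBatchStream in the diff below plays this role, using the batch sizes recorded by the write path to know where each batch ends.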

2 files changed: +90 -84 lines

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 7 additions & 1 deletion
@@ -66,6 +66,11 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
    * Cumulative time spent performing blocking writes, in ns.
    */
   def timeWriting(): Long
+
+  /**
+   * Number of bytes written so far
+   */
+  def bytesWritten: Long
 }
 
 /** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
@@ -183,7 +188,8 @@ private[spark] class DiskBlockObjectWriter(
   // Only valid if called after close()
   override def timeWriting() = _timeWriting
 
-  def bytesWritten: Long = {
+  // Only valid if called after commit()
+  override def bytesWritten: Long = {
     lastValidPosition - initialPosition
   }
 }
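
The new bytesWritten accessor is only meaningful after commit(), when lastValidPosition marks the end of the just-committed batch. The sketch below shows the intended calling pattern; the object and the recordBatch name are hypothetical and only mirror the flush() helper added to ExternalAppendOnlyMap below.

package org.apache.spark.storage

import scala.collection.mutable.ArrayBuffer

// Hypothetical usage sketch (not part of the patch): commit a batch, then read
// bytesWritten to record how many bytes that batch occupies on disk.
object BatchRecordingExample {
  val batchSizes = new ArrayBuffer[Long]

  def recordBatch(writer: BlockObjectWriter): Unit = {
    writer.commit()                        // flush this batch's buffered output to disk
    val batchLength = writer.bytesWritten  // only valid after commit()
    batchSizes.append(batchLength)
  }
}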

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 83 additions & 83 deletions
@@ -26,9 +26,8 @@ import scala.collection.mutable.ArrayBuffer
 import it.unimi.dsi.fastutil.io.FastBufferedInputStream
 
 import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.io.LZFCompressionCodec
-import org.apache.spark.serializer.{KryoDeserializationStream, Serializer}
-import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockObjectWriter}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{BlockId, BlockManager}
 
 /**
  * An append-only map that spills sorted content to disk when there is insufficient space for it
@@ -84,12 +83,14 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
   private val trackMemoryThreshold = 1000
 
-  // Size of object batches when reading/writing from serializers. Objects are written in
-  // batches, with each batch using its own serialization stream. This cuts down on the size
-  // of reference-tracking maps constructed when deserializing a stream.
-  //
-  // NOTE: Setting this too low can cause excess copying when serializing, since some serializers
-  // grow internal data structures by growing + copying every time the number of objects doubles.
+  /* Size of object batches when reading/writing from serializers.
+   *
+   * Objects are written in batches, with each batch using its own serialization stream. This
+   * cuts down on the size of reference-tracking maps constructed when deserializing a stream.
+   *
+   * NOTE: Setting this too low can cause excess copying when serializing, since some serializers
+   * grow internal data structures by growing + copying every time the number of objects doubles.
+   */
   private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
 
   // How many times we have spilled so far
@@ -100,7 +101,6 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   private var _diskBytesSpilled = 0L
 
   private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
-  private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false)
   private val comparator = new KCComparator[K, C]
   private val ser = serializer.newInstance()
 
@@ -153,37 +153,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
       .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
     val (blockId, file) = diskBlockManager.createTempBlock()
+    var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+    var objectsWritten = 0
 
-    /* IMPORTANT NOTE: To avoid having to keep large object graphs in memory, this approach
-     * closes and re-opens serialization and compression streams within each file. This makes some
-     * assumptions about the way that serialization and compression streams work, specifically:
-     *
-     * 1) The serializer input streams do not pre-fetch data from the underlying stream.
-     *
-     * 2) Several compression streams can be opened, written to, and flushed on the write path
-     *    while only one compression input stream is created on the read path
-     *
-     * In practice (1) is only true for Java, so we add a special fix below to make it work for
-     * Kryo. (2) is only true for LZF and not Snappy, so we coerce this to use LZF.
-     *
-     * To avoid making these assumptions we should create an intermediate stream that batches
-     * objects and sends an EOF to the higher layer streams to make sure they never prefetch data.
-     * This is a bit tricky because, within each segment, you'd need to track the total number
-     * of bytes written and then re-wind and write it at the beginning of the segment. This will
-     * most likely require using the file channel API.
-     */
+    // List of batch sizes (bytes) in the order they are written to disk
+    val batchSizes = new ArrayBuffer[Long]
 
-    val shouldCompress = blockManager.shouldCompress(blockId)
-    val compressionCodec = new LZFCompressionCodec(sparkConf)
-    def wrapForCompression(outputStream: OutputStream) = {
-      if (shouldCompress) compressionCodec.compressedOutputStream(outputStream) else outputStream
+    // Flush the disk writer's contents to disk, and update relevant variables
+    def flush() = {
+      writer.commit()
+      val bytesWritten = writer.bytesWritten
+      batchSizes.append(bytesWritten)
+      _diskBytesSpilled += bytesWritten
+      objectsWritten = 0
     }
 
-    def getNewWriter = new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize,
-      wrapForCompression, syncWrites)
-
-    var writer = getNewWriter
-    var objectsWritten = 0
     try {
       val it = currentMap.destructiveSortedIterator(comparator)
       while (it.hasNext) {
@@ -192,22 +176,21 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
         objectsWritten += 1
 
         if (objectsWritten == serializerBatchSize) {
-          writer.commit()
+          flush()
           writer.close()
-          _diskBytesSpilled += writer.bytesWritten
-          writer = getNewWriter
-          objectsWritten = 0
+          writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
         }
       }
-
-      if (objectsWritten > 0) writer.commit()
+      if (objectsWritten > 0) {
+        flush()
+      }
     } finally {
       // Partial failures cannot be tolerated; do not revert partial writes
       writer.close()
-      _diskBytesSpilled += writer.bytesWritten
     }
+
     currentMap = new SizeTrackingAppendOnlyMap[K, C]
-    spilledMaps.append(new DiskMapIterator(file, blockId))
+    spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
 
     // Reset the amount of shuffle memory used by this map in the global pool
     val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
@@ -252,8 +235,9 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     }
 
     /**
-     * Fetch from the given iterator until a key of different hash is retrieved. In the
-     * event of key hash collisions, this ensures no pairs are hidden from being merged.
+     * Fetch from the given iterator until a key of different hash is retrieved.
+     *
+     * In the event of key hash collisions, this ensures no pairs are hidden from being merged.
      * Assume the given iterator is in sorted order.
      */
     def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
@@ -293,7 +277,8 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty)
 
     /**
-     * Select a key with the minimum hash, then combine all values with the same key from all input streams.
+     * Select a key with the minimum hash, then combine all values with the same key from all
+     * input streams
      */
     override def next(): (K, C) = {
       // Select a key from the StreamBuffer that holds the lowest key hash
@@ -355,51 +340,66 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   /**
    * An iterator that returns (K, C) pairs in sorted order from an on-disk map
    */
-  private class DiskMapIterator(file: File, blockId: BlockId) extends Iterator[(K, C)] {
+  private class DiskMapIterator(file: File,
+                                blockId: BlockId,
+                                batchSizes: ArrayBuffer[Long]) extends Iterator[(K, C)] {
     val fileStream = new FileInputStream(file)
     val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize)
 
-    val shouldCompress = blockManager.shouldCompress(blockId)
-    val compressionCodec = new LZFCompressionCodec(sparkConf)
-    val compressedStream =
-      if (shouldCompress) {
-        compressionCodec.compressedInputStream(bufferedStream)
-      } else {
-        bufferedStream
-      }
-    var deserializeStream = ser.deserializeStream(compressedStream)
-    var objectsRead = 0
+    // An intermediate stream that holds all the bytes from exactly one batch
+    // This guards against pre-fetching and other arbitrary behavior of higher level streams
+    var batchStream = nextBatchStream(bufferedStream)
 
+    var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
+    var deserializeStream = ser.deserializeStream(compressedStream)
     var nextItem: (K, C) = null
     var eof = false
 
+    /**
+     * Construct a stream that contains all the bytes from the next batch
+     */
+    def nextBatchStream(stream: InputStream): ByteArrayInputStream = {
+      var batchBytes = Array[Byte]()
+      if (batchSizes.length > 0) {
+        val batchSize = batchSizes.remove(0)
+
+        // Read batchSize number of bytes into batchBytes
+        while (batchBytes.length < batchSize) {
+          val numBytesToRead = Math.min(8192, batchSize - batchBytes.length).toInt
+          val bytesRead = new Array[Byte](numBytesToRead)
+          stream.read(bytesRead, 0, numBytesToRead)
+          batchBytes ++= bytesRead
+        }
+      } else {
+        // No more batches left
+        eof = true
+      }
+      new ByteArrayInputStream(batchBytes)
+    }
+
+    /**
+     * Return the next (K, C) pair from the deserialization stream.
+     *
+     * If the underlying batch stream is drained, construct a new stream for the next batch
+     * (if there is one) and stream from it. If there are no more batches left, return null.
+     */
     def readNextItem(): (K, C) = {
-      if (!eof) {
-        try {
-          if (objectsRead == serializerBatchSize) {
-            val newInputStream = deserializeStream match {
-              case stream: KryoDeserializationStream =>
-                // Kryo's serializer stores an internal buffer that pre-fetches from the underlying
-                // stream. We need to capture this buffer and feed it to the new serialization
-                // stream so that the bytes are not lost.
-                val kryoInput = stream.input
-                val remainingBytes = kryoInput.limit() - kryoInput.position()
-                val extraBuf = kryoInput.readBytes(remainingBytes)
-                new SequenceInputStream(new ByteArrayInputStream(extraBuf), compressedStream)
-              case _ => compressedStream
-            }
-            deserializeStream = ser.deserializeStream(newInputStream)
-            objectsRead = 0
-          }
-          objectsRead += 1
-          return deserializeStream.readObject().asInstanceOf[(K, C)]
-        } catch {
-          case e: EOFException =>
-            eof = true
+      try {
+        deserializeStream.readObject().asInstanceOf[(K, C)]
+      } catch {
+        // End of current batch
+        case e: EOFException =>
+          batchStream = nextBatchStream(bufferedStream)
+          if (!eof) {
+            compressedStream = blockManager.wrapForCompression(blockId, batchStream)
+            deserializeStream = ser.deserializeStream(compressedStream)
+            readNextItem()
+          } else {
+            // No more batches left
             cleanup()
-        }
+            null
+          }
       }
-      null
     }
 
     override def hasNext: Boolean = {
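
To see the whole write/read cycle in one place, here is a standalone round trip using only JDK classes; GZIP stands in for Spark's compression codecs, and none of these names appear in the patch. Each batch is compressed with its own stream and its compressed length recorded, and the reader slices the file by those lengths so every decompression stream is confined to a single batch.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream, FileOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.ArrayBuffer

// Illustrative demo of the batching scheme, not code from this commit.
object BatchedSpillDemo {
  def main(args: Array[String]): Unit = {
    val file = File.createTempFile("spill-demo", ".bin")
    val batchSizes = new ArrayBuffer[Long]

    // Write path: compress each batch independently and record its compressed size.
    val fileOut = new FileOutputStream(file)
    for (batch <- Seq(Seq("a", "b"), Seq("c"))) {
      val buf = new ByteArrayOutputStream()
      val gz = new GZIPOutputStream(buf)
      batch.foreach(s => gz.write(s.getBytes("UTF-8")))
      gz.close()
      val bytes = buf.toByteArray
      batchSizes += bytes.length.toLong
      fileOut.write(bytes)
    }
    fileOut.close()

    // Read path: slice the file by the recorded sizes; each slice gets a fresh
    // decompression stream, so pre-fetching cannot cross a batch boundary.
    val fileIn = new FileInputStream(file)
    for (size <- batchSizes) {
      val compressed = new Array[Byte](size.toInt)
      var read = 0
      while (read < compressed.length) {
        read += fileIn.read(compressed, read, compressed.length - read)
      }
      val gzIn = new GZIPInputStream(new ByteArrayInputStream(compressed))
      val text = new ByteArrayOutputStream()
      val tmp = new Array[Byte](128)
      var n = gzIn.read(tmp)
      while (n != -1) { text.write(tmp, 0, n); n = gzIn.read(tmp) }
      println(new String(text.toByteArray, "UTF-8"))  // prints "ab", then "c"
    }
    fileIn.close()
  }
}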
