
Commit 01a881e

sryza authored and nemccarthy committed
[SPARK-4550] In sort-based shuffle, store map outputs in serialized form
Refer to the JIRA for the design doc and some perf results. I wanted to call out some of the possibly more controversial changes up front:

* Map outputs are only stored in serialized form when Kryo is in use. I'm still unsure whether Java-serialized objects can be relocated. At the very least, Java serialization writes out a stream header, which causes problems with the current approach, so I decided to leave investigating this to future work.
* The shuffle now explicitly operates on key-value pairs instead of arbitrary objects. Data is written to shuffle files as alternating keys and values instead of key-value tuples. `BlockObjectWriter.write` now accepts a key argument and a value argument instead of a single object.
* The map output buffer can hold at most Integer.MAX_VALUE bytes, though this wouldn't be terribly difficult to change.
* When spilling occurs, the objects that are still in memory at merge time end up serialized and deserialized an extra time.

Author: Sandy Ryza <[email protected]>

Closes apache#4450 from sryza/sandy-spark-4550 and squashes the following commits:

8c70dd9 [Sandy Ryza] Fix serialization
9c16fe6 [Sandy Ryza] Fix a couple tests and move getAutoReset to KryoSerializerInstance
6c54e06 [Sandy Ryza] Fix scalastyle
d8462d8 [Sandy Ryza] SPARK-4550
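
A minimal sketch (not part of the patch) of the alternating key/value layout and of why byte-level relocation of serialized records matters. It uses the writeKey/writeValue/asKeyValueIterator API added in this commit and assumes the default Kryo setup with auto-reset on; the object name and sample records are made up for illustration.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// Illustrative sketch: records are written as alternating keys and values
// rather than (key, value) tuples, so each record is a self-contained run
// of bytes that can be moved around without deserializing it.
object AlternatingKeyValueSketch {
  def main(args: Array[String]): Unit = {
    val ser = new KryoSerializer(new SparkConf()).newInstance()

    val buf = new ByteArrayOutputStream()
    val out = ser.serializeStream(buf)
    out.writeKey("apple")
    out.writeValue(1)
    out.flush()
    val firstRecordLen = buf.size()   // byte boundary of the first record
    out.writeKey("banana")
    out.writeValue(2)
    out.close()

    // Relocate the serialized records byte-for-byte: put record 2 first.
    val bytes = buf.toByteArray
    val reordered = bytes.drop(firstRecordLen) ++ bytes.take(firstRecordLen)

    val in = ser.deserializeStream(new ByteArrayInputStream(reordered))
    in.asKeyValueIterator.foreach { case (k, v) => println(s"$k -> $v") }
    // Expected output (with Kryo auto-reset on): banana -> 2, then apple -> 1
  }
}
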
1 parent 8311b6f commit 01a881e

23 files changed: +1240 −190 lines

core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala

Lines changed: 10 additions & 0 deletions
@@ -200,6 +200,16 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
   override def deserializeStream(s: InputStream): DeserializationStream = {
     new KryoDeserializationStream(kryo, s)
   }
+
+  /**
+   * Returns true if auto-reset is on. The only reason this would be false is if the user-supplied
+   * registrator explicitly turns auto-reset off.
+   */
+  def getAutoReset(): Boolean = {
+    val field = classOf[Kryo].getDeclaredField("autoReset")
+    field.setAccessible(true)
+    field.get(kryo).asInstanceOf[Boolean]
+  }
 }

 /**
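
The reflection-based getAutoReset() matters because relocatability of serialized records hinges on Kryo's auto-reset behavior: with auto-reset on, each top-level write is self-contained rather than possibly referring back to earlier output. A rough way to see the property it guards, using only the public stream API (an illustrative sketch, not the commit's own test):

import java.io.ByteArrayOutputStream
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// Illustrative check: serialize the same object twice through one stream.
// With Kryo auto-reset on (the default unless a registrator disables it),
// the two halves are byte-identical, so serialized records can be copied
// and reordered as opaque byte ranges.
object AutoResetPropertySketch {
  def main(args: Array[String]): Unit = {
    val ser = new KryoSerializer(new SparkConf()).newInstance()

    val buf = new ByteArrayOutputStream()
    val out = ser.serializeStream(buf)
    out.writeObject("some record")
    out.writeObject("some record")
    out.close()

    val bytes = buf.toByteArray
    val half = bytes.length / 2
    val relocatable = bytes.take(half).sameElements(bytes.drop(half))
    println(s"records are position-independent: $relocatable")
  }
}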

core/src/main/scala/org/apache/spark/serializer/Serializer.scala

Lines changed: 31 additions & 0 deletions
@@ -101,7 +101,12 @@ abstract class SerializerInstance {
  */
 @DeveloperApi
 abstract class SerializationStream {
+  /** The most general-purpose method to write an object. */
   def writeObject[T: ClassTag](t: T): SerializationStream
+  /** Writes the object representing the key of a key-value pair. */
+  def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key)
+  /** Writes the object representing the value of a key-value pair. */
+  def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value)
   def flush(): Unit
   def close(): Unit

@@ -120,7 +125,12 @@ abstract class SerializationStream {
  */
 @DeveloperApi
 abstract class DeserializationStream {
+  /** The most general-purpose method to read an object. */
   def readObject[T: ClassTag](): T
+  /** Reads the object representing the key of a key-value pair. */
+  def readKey[T: ClassTag](): T = readObject[T]()
+  /** Reads the object representing the value of a key-value pair. */
+  def readValue[T: ClassTag](): T = readObject[T]()
   def close(): Unit

   /**
@@ -141,4 +151,25 @@ abstract class DeserializationStream {
       DeserializationStream.this.close()
     }
   }
+
+  /**
+   * Read the elements of this stream through an iterator over key-value pairs. This can only be
+   * called once, as reading each element will consume data from the input source.
+   */
+  def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] {
+    override protected def getNext() = {
+      try {
+        (readKey[Any](), readValue[Any]())
+      } catch {
+        case eof: EOFException => {
+          finished = true
+          null
+        }
+      }
+    }
+
+    override protected def close() {
+      DeserializationStream.this.close()
+    }
+  }
 }
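
Since writeKey/writeValue (and readKey/readValue) default to plain writeObject/readObject, existing serializers are unaffected, but a serializer that wants to treat keys and values differently can override them. A hypothetical wrapper showing the shape of such an override (the class name and counters are made up for illustration):

import scala.reflect.ClassTag
import org.apache.spark.serializer.SerializationStream

// Hypothetical wrapper (not in the patch): writeKey/writeValue are ordinary
// overridable hooks; by default they simply delegate to writeObject, so
// serializers that don't care about the key/value split keep working.
class CountingSerializationStream(underlying: SerializationStream)
  extends SerializationStream {

  var keysWritten = 0L
  var valuesWritten = 0L

  override def writeObject[T: ClassTag](t: T): SerializationStream = {
    underlying.writeObject(t)
    this
  }

  override def writeKey[T: ClassTag](key: T): SerializationStream = {
    keysWritten += 1
    writeObject(key)
  }

  override def writeValue[T: ClassTag](value: T): SerializationStream = {
    valuesWritten += 1
    writeObject(value)
  }

  override def flush(): Unit = underlying.flush()
  override def close(): Unit = underlying.close()
}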

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ private[spark] class HashShuffleWriter[K, V](

     for (elem <- iter) {
       val bucketId = dep.partitioner.getPartition(elem._1)
-      shuffle.writers(bucketId).write(elem)
+      shuffle.writers(bucketId).write(elem._1, elem._2)
     }
   }

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 31 additions & 6 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.util.Utils
  * This interface does not support concurrent writes. Also, once the writer has
  * been opened, it cannot be reopened again.
  */
-private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
+private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {

   def open(): BlockObjectWriter

@@ -54,9 +54,14 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
   def revertPartialWritesAndClose()

   /**
-   * Writes an object.
+   * Writes a key-value pair.
    */
-  def write(value: Any)
+  def write(key: Any, value: Any)
+
+  /**
+   * Notify the writer that a record worth of bytes has been written with writeBytes.
+   */
+  def recordWritten()

   /**
    * Returns the file segment of committed data that this Writer has written.
@@ -203,12 +208,32 @@ private[spark] class DiskBlockObjectWriter(
     }
   }

-  override def write(value: Any) {
+  override def write(key: Any, value: Any) {
+    if (!initialized) {
+      open()
+    }
+
+    objOut.writeKey(key)
+    objOut.writeValue(value)
+    numRecordsWritten += 1
+    writeMetrics.incShuffleRecordsWritten(1)
+
+    if (numRecordsWritten % 32 == 0) {
+      updateBytesWritten()
+    }
+  }
+
+  override def write(b: Int): Unit = throw new UnsupportedOperationException()
+
+  override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
     if (!initialized) {
       open()
     }

-    objOut.writeObject(value)
+    bs.write(kvBytes, offs, len)
+  }
+
+  override def recordWritten(): Unit = {
     numRecordsWritten += 1
     writeMetrics.incShuffleRecordsWritten(1)

@@ -238,7 +263,7 @@
   }

   // For testing
-  private[spark] def flush() {
+  private[spark] override def flush() {
     objOut.flush()
     bs.flush()
   }
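
The writer now exposes two paths: write(key, value) serializes objects through the serialization stream, while the inherited OutputStream write(bytes, offs, len) plus recordWritten() lets a caller that already holds serialized records (such as a serialized map-output buffer) stream raw bytes while keeping record counts and byte metrics consistent. A sketch of caller code follows; BlockObjectWriter is private[spark], so this only illustrates the calling convention, the writer itself is taken as a parameter, and finishing via commitAndClose() is assumed rather than shown from the patch.

import org.apache.spark.storage.BlockObjectWriter

// Illustrative sketch (assumes it compiles inside the org.apache.spark
// package, since BlockObjectWriter is private[spark]).
object SpillSketch {
  def spill(
      writer: BlockObjectWriter,
      liveRecords: Iterator[(Any, Any)],
      serializedRecords: Iterator[Array[Byte]]): Unit = {
    // Path 1: records still held as objects; the writer serializes them itself.
    liveRecords.foreach { case (k, v) => writer.write(k, v) }

    // Path 2: records already in serialized form are copied through as raw
    // bytes; recordWritten() keeps the shuffle record/byte metrics in step.
    serializedRecords.foreach { recordBytes =>
      writer.write(recordBytes, 0, recordBytes.length)
      writer.recordWritten()
    }

    writer.commitAndClose()
  }
}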

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 2 additions & 4 deletions
@@ -17,14 +17,12 @@

 package org.apache.spark.storage

-import java.io.{InputStream, IOException}
 import java.util.concurrent.LinkedBlockingQueue

 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Try}

 import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.BlockTransferService
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.serializer.{SerializerInstance, Serializer}
@@ -301,7 +299,7 @@ final class ShuffleBlockFetcherIterator(
         // the scheduler gets a FetchFailedException.
         Try(buf.createInputStream()).map { is0 =>
           val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = serializerInstance.deserializeStream(is).asIterator
+          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
           CompletionIterator[Any, Iterator[Any]](iter, {
             // Once the iterator is exhausted, release the buffer and set currentResult to null
             // so we don't release it again in cleanup.
core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util.collection

import java.io.OutputStream

import scala.collection.mutable.ArrayBuffer

/**
 * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
 * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
 * of memory and needing to copy the full contents. The disadvantage is that the contents don't
 * occupy a contiguous segment of memory.
 */
private[spark] class ChainedBuffer(chunkSize: Int) {
  private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt
  assert(math.pow(2, chunkSizeLog2).toInt == chunkSize,
    s"ChainedBuffer chunk size $chunkSize must be a power of two")
  private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
  private var _size: Int = _

  /**
   * Feed bytes from this buffer into a BlockObjectWriter.
   *
   * @param pos Offset in the buffer to read from.
   * @param os OutputStream to read into.
   * @param len Number of bytes to read.
   */
  def read(pos: Int, os: OutputStream, len: Int): Unit = {
    if (pos + len > _size) {
      throw new IndexOutOfBoundsException(
        s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
    }
    var chunkIndex = pos >> chunkSizeLog2
    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
    var written = 0
    while (written < len) {
      val toRead = math.min(len - written, chunkSize - posInChunk)
      os.write(chunks(chunkIndex), posInChunk, toRead)
      written += toRead
      chunkIndex += 1
      posInChunk = 0
    }
  }

  /**
   * Read bytes from this buffer into a byte array.
   *
   * @param pos Offset in the buffer to read from.
   * @param bytes Byte array to read into.
   * @param offs Offset in the byte array to read to.
   * @param len Number of bytes to read.
   */
  def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
    if (pos + len > _size) {
      throw new IndexOutOfBoundsException(
        s"Read of $len bytes at position $pos would go past size of buffer")
    }
    var chunkIndex = pos >> chunkSizeLog2
    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
    var written = 0
    while (written < len) {
      val toRead = math.min(len - written, chunkSize - posInChunk)
      System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
      written += toRead
      chunkIndex += 1
      posInChunk = 0
    }
  }

  /**
   * Write bytes from a byte array into this buffer.
   *
   * @param pos Offset in the buffer to write to.
   * @param bytes Byte array to write from.
   * @param offs Offset in the byte array to write from.
   * @param len Number of bytes to write.
   */
  def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
    if (pos > _size) {
      throw new IndexOutOfBoundsException(
        s"Write at position $pos starts after end of buffer ${_size}")
    }
    // Grow if needed
    val endChunkIndex = (pos + len - 1) >> chunkSizeLog2
    while (endChunkIndex >= chunks.length) {
      chunks += new Array[Byte](chunkSize)
    }

    var chunkIndex = pos >> chunkSizeLog2
    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
    var written = 0
    while (written < len) {
      val toWrite = math.min(len - written, chunkSize - posInChunk)
      System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
      written += toWrite
      chunkIndex += 1
      posInChunk = 0
    }

    _size = math.max(_size, pos + len)
  }

  /**
   * Total size of buffer that can be written to without allocating additional memory.
   */
  def capacity: Int = chunks.size * chunkSize

  /**
   * Size of the logical buffer.
   */
  def size: Int = _size
}

/**
 * Output stream that writes to a ChainedBuffer.
 */
private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
  private var pos = 0

  override def write(b: Int): Unit = {
    throw new UnsupportedOperationException()
  }

  override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
    chainedBuffer.write(pos, bytes, offs, len)
    pos += len
  }
}
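
A small usage sketch of ChainedBuffer (a REPL-style fragment; both classes are private[spark], so it assumes access to the org.apache.spark.util.collection package, and the 8-byte chunk size is only chosen to keep the example small): positions are explicit, chunks are allocated lazily as writes grow the logical size, and reads can target either a byte array or any OutputStream.

import java.io.ByteArrayOutputStream

// Illustrative only: real uses would pick a much larger power-of-two chunk size.
val buffer = new ChainedBuffer(8)
val out = new ChainedBufferOutputStream(buffer)

out.write("hello ".getBytes("UTF-8"), 0, 6)
out.write("world".getBytes("UTF-8"), 0, 5)   // this write spans two chunks

// Read a byte range back into an array...
val dst = new Array[Byte](5)
buffer.read(6, dst, 0, 5)
println(new String(dst, "UTF-8"))            // "world"

// ...or stream a range into any OutputStream (e.g. a disk writer when spilling).
val sink = new ByteArrayOutputStream()
buffer.read(0, sink, buffer.size)
println(sink.toString("UTF-8"))              // "hello world"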

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 4 additions & 2 deletions
@@ -174,7 +174,7 @@ class ExternalAppendOnlyMap[K, V, C](
       val it = currentMap.destructiveSortedIterator(keyComparator)
       while (it.hasNext) {
         val kv = it.next()
-        writer.write(kv)
+        writer.write(kv._1, kv._2)
         objectsWritten += 1

         if (objectsWritten == serializerBatchSize) {
@@ -435,7 +435,9 @@ class ExternalAppendOnlyMap[K, V, C](
      */
     private def readNextItem(): (K, C) = {
       try {
-        val item = deserializeStream.readObject().asInstanceOf[(K, C)]
+        val k = deserializeStream.readKey().asInstanceOf[K]
+        val c = deserializeStream.readValue().asInstanceOf[C]
+        val item = (k, c)
         objectsRead += 1
         if (objectsRead == serializerBatchSize) {
           objectsRead = 0
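
The spill-read path now pulls the key and the combiner value back separately instead of deserializing a (K, C) tuple. A small typed sketch of that pattern, terminating on end-of-stream the same way the existing NextIterator-based readers do (illustrative; the helper name and EOF handling are stand-ins, not code from the patch):

import java.io.EOFException
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.serializer.DeserializationStream

// Illustrative: read typed (K, C) pairs until the stream runs out, mirroring
// the readKey/readValue pattern used in readNextItem above.
def readAllPairs[K, C](stream: DeserializationStream): Seq[(K, C)] = {
  val pairs = ArrayBuffer[(K, C)]()
  try {
    while (true) {
      val k = stream.readKey().asInstanceOf[K]
      val c = stream.readValue().asInstanceOf[C]
      pairs += ((k, c))
    }
  } catch {
    case _: EOFException => // end of stream reached
  } finally {
    stream.close()
  }
  pairs
}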
