@@ -199,6 +199,16 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
override def deserializeStream(s: InputStream): DeserializationStream = {
new KryoDeserializationStream(kryo, s)
}

/**
* Returns true if auto-reset is on. The only reason this would be false is if the user-supplied
* registrator explicitly turns auto-reset off.
*/
def getAutoReset(): Boolean = {
val field = classOf[Kryo].getDeclaredField("autoReset")
field.setAccessible(true)
field.get(kryo).asInstanceOf[Boolean]
}
}
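To illustrate the scenario described in the getAutoReset() comment above, here is a minimal sketch (not part of this patch; the registrator name is hypothetical) of a user-supplied KryoRegistrator that explicitly turns auto-reset off, which is exactly the case getAutoReset() is meant to detect:

import com.esotericsoftware.kryo.Kryo
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

// Hypothetical registrator that disables Kryo's auto-reset, so
// KryoSerializerInstance.getAutoReset() would return false.
class NoAutoResetRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    kryo.setAutoReset(false)
  }
}

// Wired up through the standard Spark configuration keys.
val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryo.registrator", classOf[NoAutoResetRegistrator].getName)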

core/src/main/scala/org/apache/spark/serializer/Serializer.scala (31 additions, 0 deletions)
@@ -101,7 +101,12 @@ abstract class SerializerInstance
*/
@DeveloperApi
abstract class SerializationStream {
/** The most general-purpose method to write an object. */
def writeObject[T: ClassTag](t: T): SerializationStream
/** Writes the object representing the key of a key-value pair. */
def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key)
/** Writes the object representing the value of a key-value pair. */
def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value)
def flush(): Unit
def close(): Unit

@@ -120,7 +125,12 @@ abstract class SerializationStream
*/
@DeveloperApi
abstract class DeserializationStream {
/** The most general-purpose method to read an object. */
def readObject[T: ClassTag](): T
/** Reads the object representing the key of a key-value pair. */
def readKey[T: ClassTag](): T = readObject[T]()
/** Reads the object representing the value of a key-value pair. */
def readValue[T: ClassTag](): T = readObject[T]()
def close(): Unit

@@ -141,4 +151,25 @@ abstract class DeserializationStream
DeserializationStream.this.close()
}
}

/**
* Read the elements of this stream through an iterator over key-value pairs. This can only be
* called once, as reading each element will consume data from the input source.
*/
def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] {
override protected def getNext() = {
try {
(readKey[Any](), readValue[Any]())
} catch {
case eof: EOFException => {
finished = true
null
}
}
}

override protected def close() {
DeserializationStream.this.close()
}
}
}
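For readers skimming the new API, a rough usage sketch (not part of this diff) of the pair-oriented stream methods added above; JavaSerializer is used here only because it inherits the default writeKey/writeValue/readKey/readValue implementations shown in the hunk:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

val instance = new JavaSerializer(new SparkConf()).newInstance()
val bytes = new ByteArrayOutputStream()

// Write two key-value pairs with the new pair-oriented methods.
val out = instance.serializeStream(bytes)
out.writeKey("a").writeValue(1)
out.writeKey("b").writeValue(2)
out.close()

// Read them back as (key, value) pairs; the iterator stops at EOF.
val in = instance.deserializeStream(new ByteArrayInputStream(bytes.toByteArray))
val pairs = in.asKeyValueIterator.toList // List((a,1), (b,2))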
@@ -63,7 +63,7 @@ private[spark] class HashShuffleWriter[K, V](

for (elem <- iter) {
val bucketId = dep.partitioner.getPartition(elem._1)
shuffle.writers(bucketId).write(elem)
shuffle.writers(bucketId).write(elem._1, elem._2)
}
}

@@ -33,7 +33,7 @@ import org.apache.spark.util.Utils
* This interface does not support concurrent writes. Also, once the writer has
* been opened, it cannot be reopened again.
*/
private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {

def open(): BlockObjectWriter

@@ -54,9 +54,14 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
def revertPartialWritesAndClose()

/**
* Writes an object.
* Writes a key-value pair.
*/
def write(value: Any)
def write(key: Any, value: Any)
Contributor comment: These changes all seem good - but FYI I think they overlap with another outstanding patch.


/**
* Notify the writer that a record worth of bytes has been written with writeBytes.
*/
Contributor comment: I think this comment is out of date; there isn't a writeBytes method in the general BlockObjectWriter interface, only in DiskBlockObjectWriter.
Contributor comment: Ah, I see that this has changed now that BlockObjectWriter extends OutputStream.
def recordWritten()

/**
* Returns the file segment of committed data that this Writer has written.
@@ -203,12 +208,32 @@ private[spark] class DiskBlockObjectWriter(
}
}

override def write(value: Any) {
override def write(key: Any, value: Any) {
if (!initialized) {
open()
}

objOut.writeKey(key)
objOut.writeValue(value)
numRecordsWritten += 1
Contributor comment: To avoid code duplication, I think we could have called recordWritten() here instead, which would handle updating the bytes written etc.

writeMetrics.incShuffleRecordsWritten(1)

if (numRecordsWritten % 32 == 0) {
updateBytesWritten()
}
}

override def write(b: Int): Unit = throw new UnsupportedOperationException()

override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
if (!initialized) {
open()
}

objOut.writeObject(value)
bs.write(kvBytes, offs, len)
}

override def recordWritten(): Unit = {
numRecordsWritten += 1
writeMetrics.incShuffleRecordsWritten(1)

@@ -238,7 +263,7 @@ private[spark] class DiskBlockObjectWriter(
}

// For testing
private[spark] def flush() {
private[spark] override def flush() {
objOut.flush()
bs.flush()
}
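A sketch of the refactor the reviewer suggests above (not part of this diff, and assuming recordWritten(), whose body is truncated here, also performs the periodic updateBytesWritten() call): write(key, value) could delegate its bookkeeping to recordWritten() instead of repeating it.

// Possible shape of write(key, value) if it reused recordWritten().
override def write(key: Any, value: Any) {
  if (!initialized) {
    open()
  }
  objOut.writeKey(key)
  objOut.writeValue(value)
  // Increments numRecordsWritten, bumps the shuffle write metrics, and
  // (per the review comment) would handle updating the bytes written.
  recordWritten()
}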
@@ -17,14 +17,12 @@

package org.apache.spark.storage

import java.io.{InputStream, IOException}
import java.util.concurrent.LinkedBlockingQueue

import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import scala.util.{Failure, Success, Try}
import scala.util.{Failure, Try}

import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.{SerializerInstance, Serializer}
@@ -301,7 +299,7 @@ final class ShuffleBlockFetcherIterator(
// the scheduler gets a FetchFailedException.
Try(buf.createInputStream()).map { is0 =>
val is = blockManager.wrapForCompression(blockId, is0)
val iter = serializerInstance.deserializeStream(is).asIterator
val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
CompletionIterator[Any, Iterator[Any]](iter, {
// Once the iterator is exhausted, release the buffer and set currentResult to null
// so we don't release it again in cleanup.
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection

import java.io.OutputStream

import scala.collection.mutable.ArrayBuffer

/**
* A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
* advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
* of memory and needing to copy the full contents. The disadvantage is that the contents don't
* occupy a contiguous segment of memory.
*/
private[spark] class ChainedBuffer(chunkSize: Int) {
private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt
assert(math.pow(2, chunkSizeLog2).toInt == chunkSize,
s"ChainedBuffer chunk size $chunkSize must be a power of two")
private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
private var _size: Int = _

/**
* Write bytes from this buffer into the given OutputStream (e.g. a BlockObjectWriter).
*
* @param pos Offset in the buffer to read from.
* @param os OutputStream to write the bytes to.
* @param len Number of bytes to read.
*/
def read(pos: Int, os: OutputStream, len: Int): Unit = {
if (pos + len > _size) {
throw new IndexOutOfBoundsException(
s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
}
var chunkIndex = pos >> chunkSizeLog2
var posInChunk = pos - (chunkIndex << chunkSizeLog2)
var written = 0
while (written < len) {
val toRead = math.min(len - written, chunkSize - posInChunk)
os.write(chunks(chunkIndex), posInChunk, toRead)
written += toRead
chunkIndex += 1
posInChunk = 0
}
}

/**
* Read bytes from this buffer into a byte array.
*
* @param pos Offset in the buffer to read from.
* @param bytes Byte array to read into.
* @param offs Offset in the byte array to read to.
* @param len Number of bytes to read.
*/
def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
if (pos + len > _size) {
throw new IndexOutOfBoundsException(
s"Read of $len bytes at position $pos would go past size of buffer")
}
var chunkIndex = pos >> chunkSizeLog2
var posInChunk = pos - (chunkIndex << chunkSizeLog2)
var written = 0
while (written < len) {
val toRead = math.min(len - written, chunkSize - posInChunk)
System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
written += toRead
chunkIndex += 1
posInChunk = 0
}
}

/**
* Write bytes from a byte array into this buffer.
*
* @param pos Offset in the buffer to write to.
* @param bytes Byte array to write from.
* @param offs Offset in the byte array to write from.
* @param len Number of bytes to write.
*/
def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
if (pos > _size) {
throw new IndexOutOfBoundsException(
s"Write at position $pos starts after end of buffer ${_size}")
}
// Grow if needed
val endChunkIndex = (pos + len - 1) >> chunkSizeLog2
while (endChunkIndex >= chunks.length) {
chunks += new Array[Byte](chunkSize)
}

var chunkIndex = pos >> chunkSizeLog2
var posInChunk = pos - (chunkIndex << chunkSizeLog2)
var written = 0
while (written < len) {
val toWrite = math.min(len - written, chunkSize - posInChunk)
System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
written += toWrite
chunkIndex += 1
posInChunk = 0
}

_size = math.max(_size, pos + len)
}

/**
* Total size of buffer that can be written to without allocating additional memory.
*/
def capacity: Int = chunks.size * chunkSize

/**
* Size of the logical buffer.
*/
def size: Int = _size
}

/**
* Output stream that writes to a ChainedBuffer.
*/
private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
private var pos = 0

override def write(b: Int): Unit = {
throw new UnsupportedOperationException()
}

override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
chainedBuffer.write(pos, bytes, offs, len)
pos += len
}
}
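A brief usage sketch of the new buffer (illustrative only; both classes are private[spark], so real callers live inside Spark). Bytes are appended sequentially through ChainedBufferOutputStream and can be read back from any offset:

// Chunk size must be a power of two.
val buffer = new ChainedBuffer(4096)
val out = new ChainedBufferOutputStream(buffer)

out.write("hello".getBytes("UTF-8"), 0, 5)
out.write("world".getBytes("UTF-8"), 0, 5)

// Copy the ten logical bytes back into a contiguous array.
val copy = new Array[Byte](buffer.size)
buffer.read(0, copy, 0, buffer.size)
assert(new String(copy, "UTF-8") == "helloworld")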
@@ -172,7 +172,7 @@ class ExternalAppendOnlyMap[K, V, C](
val it = currentMap.destructiveSortedIterator(keyComparator)
while (it.hasNext) {
val kv = it.next()
writer.write(kv)
writer.write(kv._1, kv._2)
objectsWritten += 1

if (objectsWritten == serializerBatchSize) {
@@ -433,7 +433,9 @@ class ExternalAppendOnlyMap[K, V, C](
*/
private def readNextItem(): (K, C) = {
try {
val item = deserializeStream.readObject().asInstanceOf[(K, C)]
val k = deserializeStream.readKey().asInstanceOf[K]
val c = deserializeStream.readValue().asInstanceOf[C]
val item = (k, c)
objectsRead += 1
if (objectsRead == serializerBatchSize) {
objectsRead = 0