From a68171cea01ce90ee41665bc1ee5bc6eaca03e32 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Nov 2016 15:16:10 -0800 Subject: [PATCH] checksum --- .../buffer/FileSegmentManagedBuffer.java | 28 ++++++-- .../spark/network/buffer/ManagedBuffer.java | 2 +- .../network/buffer/NettyManagedBuffer.java | 17 ++++- .../network/buffer/NioManagedBuffer.java | 19 +++++- .../spark/network/TestManagedBuffer.java | 4 +- .../spark/network/sasl/SparkSaslSuite.java | 12 ++-- .../ExternalShuffleBlockResolverSuite.java | 13 ++-- .../sort/BypassMergeSortShuffleWriter.java | 2 +- .../shuffle/sort/ShuffleExternalSorter.java | 4 +- .../shuffle/sort/UnsafeShuffleWriter.java | 25 +++++-- .../spark/storage/ChecksumOutputStream.java | 66 +++++++++++++++++++ .../unsafe/sort/UnsafeSorterSpillWriter.java | 4 +- .../shuffle/BlockStoreShuffleReader.scala | 3 +- .../apache/spark/storage/BlockManager.scala | 6 +- .../spark/storage/DiskBlockObjectWriter.scala | 11 +++- .../storage/ShuffleBlockFetcherIterator.scala | 45 ++++++++++--- .../collection/ExternalAppendOnlyMap.scala | 2 +- .../util/collection/ExternalSorter.scala | 4 +- .../sort/UnsafeShuffleWriterSuite.java | 9 +-- .../map/AbstractBytesToBytesMapSuite.java | 7 +- .../sort/UnsafeExternalSorterSuite.java | 6 +- .../NettyBlockTransferSecuritySuite.scala | 2 +- .../BlockStoreShuffleReaderSuite.scala | 3 +- .../BypassMergeSortShuffleWriterSuite.scala | 6 +- .../ShuffleBlockFetcherIteratorSuite.scala | 19 +++--- 25 files changed, 246 insertions(+), 73 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/storage/ChecksumOutputStream.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c3460..92bdd129ddb61 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -17,13 +17,12 @@ package org.apache.spark.network.buffer; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.RandomAccessFile; +import java.io.*; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.util.zip.Adler32; +import java.util.zip.CheckedInputStream; +import java.util.zip.Checksum; import com.google.common.base.Objects; import com.google.common.io.ByteStreams; @@ -92,12 +91,27 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream createInputStream() throws IOException { + public InputStream createInputStream(boolean checksum) throws IOException { FileInputStream is = null; try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return new LimitedInputStream(is, length); + if (checksum) { + Checksum ck = new Adler32(); + DataInputStream din = new DataInputStream(new CheckedInputStream(is, ck)); + ByteStreams.skipFully(din, length - 8); + long sum = ck.getValue(); + long expected = din.readLong(); + if (sum != expected) { + throw new IOException("Checksum does not match " + sum + "!=" + expected); + } + is.close(); + is = new FileInputStream(file); + ByteStreams.skipFully(is, offset); + return new LimitedInputStream(is, length - 8); + } else { + return new LimitedInputStream(is, length); + } } catch (IOException e) { try { if (is != null) { diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java index 1861f8d7fd8f3..3b587d337b3e3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -51,7 +51,7 @@ public abstract class ManagedBuffer { * necessarily check for the length of bytes read, so the caller is responsible for making sure * it does not go over the limit. */ - public abstract InputStream createInputStream() throws IOException; + public abstract InputStream createInputStream(boolean checksum) throws IOException; /** * Increment the reference count by one if applicable. diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java index acc49d968c186..d8c78709b5420 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; +import java.util.zip.Adler32; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -46,7 +47,21 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream createInputStream() throws IOException { + public InputStream createInputStream(boolean checksum) throws IOException { + if (checksum) { + Adler32 adler = new Adler32(); + long size = size(); + buf.markReaderIndex(); + for (int i = 0; i < size - 8; i++) { + adler.update(buf.readByte()); + } + long sum = buf.readLong(); + if (adler.getValue() != sum) { + throw new IOException("Checksum does not match " + adler.getValue() + "!=" + sum); + } + buf.resetReaderIndex(); + buf.writerIndex(buf.writerIndex() - 8); + } return new ByteBufInputStream(buf); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java index 631d767715256..af7863b217a6c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; +import java.util.zip.Adler32; import com.google.common.base.Objects; import io.netty.buffer.ByteBufInputStream; @@ -46,7 +47,23 @@ public ByteBuffer nioByteBuffer() throws IOException { } @Override - public InputStream createInputStream() throws IOException { + public InputStream createInputStream(boolean checksum) throws IOException { + if (checksum) { + Adler32 adler = new Adler32(); + int position = buf.position(); + int limit = buf.limit() - 8; + buf.position(limit); + long sum = buf.getLong(); + buf.position(position); + // simplify this after drop Java 7 support + for (int i=buf.position(); i> records) throws IOException { final File file = tempShuffleBlockIdPlusFile._2(); final BlockId blockId = tempShuffleBlockIdPlusFile._1(); partitionWriters[i] = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); + blockManager.getDiskWriter(blockId, file, serInstance, 
fileBufferSize, writeMetrics, true); } // Creating the file to write to and creating a disk writer both involve interacting with // the disk, and can take a long time in aggregate when we open many files, so should be diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index c33d1e33f030f..aecbae17f5576 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -173,7 +173,9 @@ private void writeSortedFile(boolean isLastFile) throws IOException { final SerializerInstance ser = DummySerializerInstance.INSTANCE; final DiskBlockObjectWriter writer = - blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse, + // only generate checksum for only spill + isLastFile && spills.isEmpty()); int currentPartition = -1; while (sortedRecords.hasNext()) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index f235c434be7b1..eb942414016c7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -21,6 +21,7 @@ import java.io.*; import java.nio.channels.FileChannel; import java.util.Iterator; +import java.util.zip.Adler32; import scala.Option; import scala.Product2; @@ -35,7 +36,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.*; +import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; import org.apache.spark.annotation.Private; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; @@ -49,6 +53,7 @@ import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.ChecksumOutputStream; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; @@ -75,6 +80,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final SparkConf sparkConf; private final boolean transferToEnabled; private final int initialSortBufferSize; + private final boolean checksum; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -108,8 +114,8 @@ public UnsafeShuffleWriter( if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( "UnsafeShuffleWriter can only be used for shuffles with at most " + - SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + - " reduce partitions"); + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + + " reduce partitions"); } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; @@ -124,7 +130,9 @@ public UnsafeShuffleWriter( this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize", - DEFAULT_INITIAL_SORT_BUFFER_SIZE); + 
DEFAULT_INITIAL_SORT_BUFFER_SIZE); + this.checksum = sparkConf.getBoolean("spark.shuffle.checksum", true); + open(); } @@ -289,7 +297,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // Compression is disabled or we are using an IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. - if (transferToEnabled) { + if (transferToEnabled && !checksum) { logger.debug("Using transferTo-based fast merge"); partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); } else { @@ -346,8 +354,11 @@ private long[] mergeSpillsWithFileStream( } for (int partition = 0; partition < numPartitions; partition++) { final long initialFileLength = outputFile.length(); - mergedFileOutputStream = - new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + OutputStream fos = new FileOutputStream(outputFile, true); + if (checksum) { + fos = new ChecksumOutputStream(fos, new Adler32()); + } + mergedFileOutputStream = new TimeTrackingOutputStream(writeMetrics, fos); if (compressionCodec != null) { mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); } diff --git a/core/src/main/java/org/apache/spark/storage/ChecksumOutputStream.java b/core/src/main/java/org/apache/spark/storage/ChecksumOutputStream.java new file mode 100644 index 0000000000000..2d91884ce9ef5 --- /dev/null +++ b/core/src/main/java/org/apache/spark/storage/ChecksumOutputStream.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.zip.Checksum; + +/** + * A output stream that generate checksum for written data and write the checksum as long at + * the end of stream. 
+ */ +public class ChecksumOutputStream extends FilterOutputStream { + private Checksum cksum; + private boolean closed; + + public ChecksumOutputStream(OutputStream out, Checksum cksum) { + super(out); + cksum.reset(); + this.cksum = cksum; + this.closed = false; + } + + public void write(int b) throws IOException { + out.write(b); + cksum.update(b); + } + + public void write(byte[] b) throws IOException { + write(b, 0, b.length); + } + + public void write(byte[] b, int off, int len) throws IOException { + out.write(b, off, len); + cksum.update(b, off, len); + } + + public void close() throws IOException { + flush(); + if (!closed) { + closed = true; + ByteBuffer buffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN); + buffer.putLong(cksum.getValue()); + out.write(buffer.array()); + out.close(); + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 164b9d70b79d7..a240abeaeca47 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -20,11 +20,11 @@ import java.io.File; import java.io.IOException; -import org.apache.spark.serializer.SerializerManager; import scala.Tuple2; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; @@ -66,7 +66,7 @@ public UnsafeSorterSpillWriter( // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work // around this, we pass a dummy no-op serializer. 
writer = blockManager.getDiskWriter( - blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics, true); // Write the number of records writeIntToBuffer(numRecordsToWrite, 0); writer.write(writeBuffer, 0, 4); diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index b9d83495d29b6..8bbb3b55d618c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -49,7 +49,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, - SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.getBoolean("spark.shuffle.checksum", true)) // Wrap the streams for compression and encryption based on configuration val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 982b83324e0fc..852c7b2742dc5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -744,11 +744,13 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { + writeMetrics: ShuffleWriteMetrics, + checksum: Boolean): DiskBlockObjectWriter = { val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream, - syncWrites, writeMetrics, blockId) + syncWrites, writeMetrics, blockId, + checksum && conf.getBoolean("spark.shuffle.checksum", true)) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a499827ae1598..582ad7a04ed93 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream} import java.nio.channels.FileChannel +import java.util.zip.{Adler32, CRC32} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging @@ -44,7 +45,8 @@ private[spark] class DiskBlockObjectWriter( // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. 
writeMetrics: ShuffleWriteMetrics, - val blockId: BlockId = null) + val blockId: BlockId = null, + val checksum: Boolean = false) extends OutputStream with Logging { @@ -116,7 +118,12 @@ private[spark] class DiskBlockObjectWriter( initialized = true } - bs = wrapStream(mcs) + val withChecksum = if (checksum) { + new ChecksumOutputStream(mcs, new Adler32()) + } else { + mcs + } + bs = wrapStream(withChecksum) objOut = serializerInstance.serializeStream(bs) streamOpen = true this diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 4dc2f362329a0..c36d2fc23d89b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,10 +17,11 @@ package org.apache.spark.storage -import java.io.InputStream +import java.io.{InputStream, IOException} import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.control.NonFatal @@ -57,7 +58,8 @@ final class ShuffleBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], maxBytesInFlight: Long, - maxReqsInFlight: Int) + maxReqsInFlight: Int, + checksum: Boolean) extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -124,7 +126,9 @@ final class ShuffleBlockFetcherIterator( private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { - case SuccessFetchResult(_, _, _, buf, _) => buf.release() + case SuccessFetchResult(_, _, _, buf, in, _) => + buf.release() + in.close() case _ => } currentResult = null @@ -143,17 +147,20 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, address, _, buf, _) => + case SuccessFetchResult(_, address, _, buf, in, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() + in.close() case _ => } } } + val corruptedBlocks = mutable.HashSet[String]() + private[this] def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) @@ -161,6 +168,7 @@ final class ShuffleBlockFetcherIterator( reqsInFlight += 1 // so we can look up the size of each blockID + val blocks = req.blocks val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap val remainingBlocks = new HashSet[String]() ++= sizeMap.keys val blockIds = req.blocks.map(_._1.toString) @@ -171,6 +179,22 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { // Only add the buffer to results queue if the iterator is not zombie, // i.e. cleanup() has not been called yet. 
+ val in = try { + buf.createInputStream(checksum) + } catch { + case e: IOException => + ShuffleBlockFetcherIterator.this.synchronized { + if (corruptedBlocks.contains(blockId)) { + onBlockFetchFailure(blockId, e) + } else { + logWarning(s"got an corrupted block $blockId from $address, fetch again") + val index = blockIds.indexOf(blockId) + fetchRequests += new FetchRequest(address, blocks.slice(index, index + 1)) + } + } + return + } + assert(in != null, s"buf is $buf") ShuffleBlockFetcherIterator.this.synchronized { if (!isZombie) { // Increment the ref count because we need to pass this to a different thread. @@ -178,7 +202,7 @@ final class ShuffleBlockFetcherIterator( buf.retain() remainingBlocks -= blockId results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, - remainingBlocks.isEmpty)) + in, remainingBlocks.isEmpty)) logDebug("remainingBlocks: " + remainingBlocks) } } @@ -258,8 +282,9 @@ final class ShuffleBlockFetcherIterator( val buf = blockManager.getBlockData(blockId) shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) + val in = buf.createInputStream(checksum) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, in, false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -312,7 +337,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => + case SuccessFetchResult(_, address, size, buf, _, isNetworkReqDone) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) @@ -331,9 +356,9 @@ final class ShuffleBlockFetcherIterator( case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) - case SuccessFetchResult(blockId, address, _, buf, _) => + case SuccessFetchResult(blockId, address, _, _, in, _) => try { - (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) + (result.blockId, new BufferReleasingInputStream(in, this)) } catch { case NonFatal(t) => throwFetchFailedException(blockId, address, t) @@ -375,7 +400,6 @@ private class BufferReleasingInputStream( override def close(): Unit = { if (!closed) { - delegate.close() iterator.releaseCurrentResultBuffer() closed = true } @@ -431,6 +455,7 @@ object ShuffleBlockFetcherIterator { address: BlockManagerId, size: Long, buf: ManagedBuffer, + in: InputStream, isNetworkReqDone: Boolean) extends FetchResult { require(buf != null) require(size >= 0) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 948cc3b099b18..e08db4b97c015 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -206,7 +206,7 @@ class ExternalAppendOnlyMap[K, V, C]( private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)]) : DiskMapIterator = { val (blockId, file) = diskBlockManager.createTempLocalBlock() - val writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics) + val writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics, false) var 
objectsWritten = 0 // List of batch sizes (bytes) in the order they are written to disk diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d2..c68c364618870 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -273,7 +273,7 @@ private[spark] class ExternalSorter[K, V, C]( var objectsWritten: Long = 0 val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics val writer: DiskBlockObjectWriter = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics) + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics, false) // List of batch sizes (bytes) in the order they are written to disk val batchSizes = new ArrayBuffer[Long] @@ -687,7 +687,7 @@ private[spark] class ExternalSorter[K, V, C]( // Track location of each range in the output file val lengths = new Array[Long](numPartitions) val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, - context.taskMetrics().shuffleWriteMetrics) + context.taskMetrics().shuffleWriteMetrics, true) if (spills.isEmpty) { // Case where we only have in-memory data diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a96cd82382e2c..35defc73b7408 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -127,11 +127,11 @@ public void setUp() throws IOException { any(File.class), any(SerializerInstance.class), anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + any(ShuffleWriteMetrics.class), + anyBoolean())).thenAnswer(new Answer() { @Override public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { Object[] args = invocationOnMock.getArguments(); - return new DiskBlockObjectWriter( (File) args[1], (SerializerInstance) args[2], @@ -139,7 +139,8 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th new WrapStream(), false, (ShuffleWriteMetrics) args[4], - (BlockId) args[0] + (BlockId) args[0], + (boolean) args[5] && conf.getBoolean("spark.shuffle.checksum", true) ); } }); @@ -203,7 +204,7 @@ private List> readRecordsFromFile() throws IOException { if (partitionSize > 0) { InputStream in = new FileInputStream(mergedOutputFile); ByteStreams.skipFully(in, startOffset); - in = new LimitedInputStream(in, partitionSize); + in = new LimitedInputStream(in, partitionSize - 8); // ignore checksum if (conf.getBoolean("spark.shuffle.compress", true)) { in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 33709b454c4c9..996f15b681564 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -54,6 +54,7 @@ import static org.junit.Assert.assertFalse; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; 
import static org.mockito.Matchers.anyInt; import static org.mockito.Mockito.when; @@ -113,7 +114,8 @@ public Tuple2 answer(InvocationOnMock invocationOnMock) any(File.class), any(SerializerInstance.class), anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + any(ShuffleWriteMetrics.class), + anyBoolean())).thenAnswer(new Answer() { @Override public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { Object[] args = invocationOnMock.getArguments(); @@ -125,7 +127,8 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th new WrapStream(), false, (ShuffleWriteMetrics) args[4], - (BlockId) args[0] + (BlockId) args[0], + (boolean) args[5] ); } }); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a9cf8ff520ed4..6b770cf80e9bf 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -119,7 +119,8 @@ public Tuple2 answer(InvocationOnMock invocationOnMock) any(File.class), any(SerializerInstance.class), anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + any(ShuffleWriteMetrics.class), + anyBoolean())).thenAnswer(new Answer() { @Override public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { Object[] args = invocationOnMock.getArguments(); @@ -131,7 +132,8 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th new WrapStream(), false, (ShuffleWriteMetrics) args[4], - (BlockId) args[0] + (BlockId) args[0], + (boolean) args[5] ); } }); diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 022fe91edade9..1f3d26bf2c767 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -120,7 +120,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val result = fetchBlock(exec0, exec1, "1", blockId) match { case Success(buf) => val actualString = CharStreams.toString( - new InputStreamReader(buf.createInputStream(), StandardCharsets.UTF_8)) + new InputStreamReader(buf.createInputStream(false), StandardCharsets.UTF_8)) actualString should equal(blockString) buf.release() Success(()) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index dba1172d5fdbd..4d4ba81f05c0a 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -39,7 +39,8 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed override def size(): Long = underlyingBuffer.size() override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() - override def createInputStream(): InputStream = underlyingBuffer.createInputStream() + override def createInputStream(checksum: Boolean): InputStream = + underlyingBuffer.createInputStream(false) override def convertToNetty(): AnyRef = 
underlyingBuffer.convertToNetty() override def retain(): ManagedBuffer = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 442941685f1ae..f02e524f226be 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -86,7 +86,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte any[File], any[SerializerInstance], anyInt(), - any[ShuffleWriteMetrics] + any[ShuffleWriteMetrics], + anyBoolean() )).thenAnswer(new Answer[DiskBlockObjectWriter] { override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments @@ -97,7 +98,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte wrapStream = identity, syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], - blockId = args(0).asInstanceOf[BlockId] + blockId = args(0).asInstanceOf[BlockId], + args(5).asInstanceOf[Boolean] ) } }) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index e3ec99685f73c..4d898fb65fd76 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -23,7 +23,7 @@ import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future -import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Matchers.{any, anyBoolean, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -63,7 +63,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Create a mock managed buffer for testing def createMockManagedBuffer(): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) - when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) + when(mockManagedBuffer.createInputStream(anyBoolean())).thenReturn(mock(classOf[InputStream])) mockManagedBuffer } @@ -100,7 +100,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockManager, blocksByAddress, 48 * 1024 * 1024, - Int.MaxValue) + Int.MaxValue, + false) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -173,7 +174,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockManager, blocksByAddress, 48 * 1024 * 1024, - Int.MaxValue) + Int.MaxValue, + false) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -201,9 +203,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + 
ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer() ) // Semaphore to coordinate event sequence in two different threads. @@ -236,7 +238,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockManager, blocksByAddress, 48 * 1024 * 1024, - Int.MaxValue) + Int.MaxValue, + false) // Continue only after the mock calls onBlockFetchFailure sem.acquire()
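
Note on the stream layout introduced by this patch: a checksummed shuffle stream carries an 8-byte, big-endian Adler-32 value appended at the end (ChecksumOutputStream writes it on close), and readers verify it when createInputStream(true) is called, recomputing the checksum over everything except the trailing 8 bytes and then hiding those bytes from the caller (FileSegmentManagedBuffer, for example, reopens the file behind a LimitedInputStream of length - 8). The sketch below is not code from the patch; it only illustrates that on-disk/on-wire layout with plain JDK classes, and the names TrailingChecksumDemo, writeWithChecksum and readAndVerify are invented for illustration.

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.zip.Adler32;

// Illustrative only: mirrors the layout used by ChecksumOutputStream,
// i.e. <payload bytes><8-byte big-endian Adler-32 of the payload>.
public class TrailingChecksumDemo {

  static byte[] writeWithChecksum(byte[] payload) throws IOException {
    Adler32 adler = new Adler32();
    adler.update(payload, 0, payload.length);
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(bos);
    dos.write(payload);
    // DataOutputStream writes big-endian, matching the ByteOrder.BIG_ENDIAN buffer in the patch.
    dos.writeLong(adler.getValue());
    dos.close();
    return bos.toByteArray();
  }

  static byte[] readAndVerify(byte[] stored) throws IOException {
    if (stored.length < 8) {
      throw new IOException("Stream too short to contain a checksum");
    }
    byte[] payload = Arrays.copyOfRange(stored, 0, stored.length - 8);
    long expected = ByteBuffer.wrap(stored, stored.length - 8, 8).getLong();
    Adler32 adler = new Adler32();
    adler.update(payload, 0, payload.length);
    if (adler.getValue() != expected) {
      throw new IOException("Checksum does not match " + adler.getValue() + " != " + expected);
    }
    // The caller only ever sees the payload, never the trailing 8 bytes.
    return payload;
  }

  public static void main(String[] args) throws IOException {
    byte[] data = "shuffle partition bytes".getBytes("UTF-8");
    byte[] stored = writeWithChecksum(data);
    System.out.println(Arrays.equals(data, readAndVerify(stored)));  // true

    stored[3] ^= 0x1;  // simulate on-disk or in-flight corruption
    try {
      readAndVerify(stored);
    } catch (IOException e) {
      System.out.println("detected: " + e.getMessage());
    }
  }
}

Because of that trailing word, consumers that skip verification still have to account for the extra 8 bytes per stream, which is why UnsafeShuffleWriterSuite reads partitionSize - 8 from the merged output and why UnsafeShuffleWriter falls back from the transferTo-based fast merge whenever checksumming is enabled.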
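
The feature is gated by a single configuration key, spark.shuffle.checksum, which defaults to true and is consulted on the write path (BlockManager.getDiskWriter, UnsafeShuffleWriter, ShuffleExternalSorter) and on the read path (BlockStoreShuffleReader passes it through to ShuffleBlockFetcherIterator). A minimal sketch of toggling it, assuming an ordinary SparkConf; the class name ChecksumConfExample is invented for illustration.

import org.apache.spark.SparkConf;

public class ChecksumConfExample {
  public static void main(String[] args) {
    // Disable the end-to-end shuffle checksum introduced by this patch,
    // e.g. to restore the transferTo-based fast merge in UnsafeShuffleWriter.
    SparkConf conf = new SparkConf()
      .setAppName("checksum-demo")
      .set("spark.shuffle.checksum", "false");

    // Reading it back mirrors what the patch does internally.
    boolean checksumEnabled = conf.getBoolean("spark.shuffle.checksum", true);
    System.out.println("shuffle checksum enabled: " + checksumEnabled);
  }
}

With the flag set to false, disk writers skip the ChecksumOutputStream wrapper and remote fetches call createInputStream(false), so no verification (and hence no corruption-triggered refetch) takes place.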