@@ -17,13 +17,12 @@

package org.apache.spark.network.buffer;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.zip.Adler32;
Review comment (Contributor): It would be better to abstract out the checksum functions into some class, so it's easier to change in the future.

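Purely as an illustration of that suggestion (hypothetical, not part of this patch), such an abstraction could be a small helper that owns the choice of algorithm, so Adler32 could later be swapped without touching the buffer classes:

import java.io.InputStream;
import java.util.zip.Adler32;
import java.util.zip.CheckedInputStream;
import java.util.zip.Checksum;

// Hypothetical helper, sketched only to illustrate the review suggestion above;
// the class name and its placement are made up.
public final class ShuffleChecksums {
  private ShuffleChecksums() {}

  // Single place that decides which checksum algorithm shuffle data uses.
  public static Checksum newChecksum() {
    return new Adler32();
  }

  // Wraps a stream so the checksum is updated as bytes are read from it.
  public static InputStream checkedStream(InputStream in, Checksum checksum) {
    return new CheckedInputStream(in, checksum);
  }
}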
import java.util.zip.CheckedInputStream;
import java.util.zip.Checksum;

import com.google.common.base.Objects;
import com.google.common.io.ByteStreams;
@@ -92,12 +91,27 @@ public ByteBuffer nioByteBuffer() throws IOException {
}

@Override
public InputStream createInputStream() throws IOException {
public InputStream createInputStream(boolean checksum) throws IOException {
Review comment (Member): So, is this only for testing? Because it's otherwise very expensive to compute, requiring two passes over the file. If so, should it not be in some package-private method only, and not exposed?

Reply (Contributor Author): This is a good question. Actually, we already have a checksum along with compression; we could move the decompression a little earlier to detect the corruption in the block fetcher. Will try that soon.

FileInputStream is = null;
try {
is = new FileInputStream(file);
ByteStreams.skipFully(is, offset);
return new LimitedInputStream(is, length);
if (checksum) {
Checksum ck = new Adler32();
DataInputStream din = new DataInputStream(new CheckedInputStream(is, ck));
ByteStreams.skipFully(din, length - 8);
long sum = ck.getValue();
long expected = din.readLong();
if (sum != expected) {
throw new IOException("Checksum does not match " + sum + "!=" + expected);
}
is.close();
is = new FileInputStream(file);
ByteStreams.skipFully(is, offset);
return new LimitedInputStream(is, length - 8);
} else {
return new LimitedInputStream(is, length);
}
} catch (IOException e) {
try {
if (is != null) {
@@ -51,7 +51,7 @@ public abstract class ManagedBuffer {
* necessarily check for the length of bytes read, so the caller is responsible for making sure
* it does not go over the limit.
*/
public abstract InputStream createInputStream() throws IOException;
public abstract InputStream createInputStream(boolean checksum) throws IOException;

/**
* Increment the reference count by one if applicable.
@@ -20,6 +20,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.zip.Adler32;

import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
@@ -46,7 +47,21 @@ public ByteBuffer nioByteBuffer() throws IOException {
}

@Override
public InputStream createInputStream() throws IOException {
public InputStream createInputStream(boolean checksum) throws IOException {
if (checksum) {
Adler32 adler = new Adler32();
long size = size();
buf.markReaderIndex();
for (int i = 0; i < size - 8; i++) {
adler.update(buf.readByte());
}
long sum = buf.readLong();
if (adler.getValue() != sum) {
throw new IOException("Checksum does not match " + adler.getValue() + "!=" + sum);
}
buf.resetReaderIndex();
buf.writerIndex(buf.writerIndex() - 8);
}
return new ByteBufInputStream(buf);
}

@@ -20,6 +20,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.zip.Adler32;

import com.google.common.base.Objects;
import io.netty.buffer.ByteBufInputStream;
@@ -46,7 +47,23 @@ public ByteBuffer nioByteBuffer() throws IOException {
}

@Override
public InputStream createInputStream() throws IOException {
public InputStream createInputStream(boolean checksum) throws IOException {
if (checksum) {
Adler32 adler = new Adler32();
int position = buf.position();
int limit = buf.limit() - 8;
buf.position(limit);
long sum = buf.getLong();
buf.position(position);
// simplify this after dropping Java 7 support
for (int i = buf.position(); i < limit; i++) {
adler.update(buf.get(i));
}
if (sum != adler.getValue()) {
throw new IOException("Checksum does not match: " + adler.getValue() + "!=" + sum);
}
buf.limit(limit);
}
return new ByteBufInputStream(Unpooled.wrappedBuffer(buf));
}

@@ -59,8 +59,8 @@ public ByteBuffer nioByteBuffer() throws IOException {
}

@Override
public InputStream createInputStream() throws IOException {
return underlying.createInputStream();
public InputStream createInputStream(boolean checksum) throws IOException {
return underlying.createInputStream(checksum);
}

@Override
@@ -17,20 +17,17 @@

package org.apache.spark.network.sasl;

import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

import javax.security.sasl.SaslException;
import java.io.File;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.sasl.SaslException;

import com.google.common.collect.Lists;
import com.google.common.io.ByteStreams;
@@ -62,6 +59,9 @@
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

/**
* Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
*/
@@ -296,7 +296,7 @@ public Void answer(InvocationOnMock invocation) {
verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
verify(callback, never()).onFailure(anyInt(), any(Throwable.class));

byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
byte[] received = ByteStreams.toByteArray(response.get().createInputStream(false));
assertTrue(Arrays.equals(data, received));
} finally {
file.delete();
@@ -24,14 +24,15 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.io.CharStreams;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

import static org.junit.Assert.*;

public class ExternalShuffleBlockResolverSuite {
@@ -98,14 +99,14 @@ public void testSortShuffleBlocks() throws IOException {
dataContext.createExecutorInfo(SORT_MANAGER));

InputStream block0Stream =
resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream();
resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(false);
String block0 = CharStreams.toString(
new InputStreamReader(block0Stream, StandardCharsets.UTF_8));
block0Stream.close();
assertEquals(sortBlock0, block0);

InputStream block1Stream =
resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream();
resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(false);
String block1 = CharStreams.toString(
new InputStreamReader(block1Stream, StandardCharsets.UTF_8));
block1Stream.close();
@@ -139,7 +139,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
final File file = tempShuffleBlockIdPlusFile._2();
final BlockId blockId = tempShuffleBlockIdPlusFile._1();
partitionWriters[i] =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics, true);
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
@@ -173,7 +173,9 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
final SerializerInstance ser = DummySerializerInstance.INSTANCE;

final DiskBlockObjectWriter writer =
blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse,
// only generate the checksum when this is the only spill file
isLastFile && spills.isEmpty());

int currentPartition = -1;
while (sortedRecords.hasNext()) {
@@ -21,6 +21,7 @@
import java.io.*;
import java.nio.channels.FileChannel;
import java.util.Iterator;
import java.util.zip.Adler32;

import scala.Option;
import scala.Product2;
@@ -35,7 +36,10 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.*;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
@@ -49,6 +53,7 @@
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.ChecksumOutputStream;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.Utils;
@@ -75,6 +80,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final SparkConf sparkConf;
private final boolean transferToEnabled;
private final int initialSortBufferSize;
private final boolean checksum;

@Nullable private MapStatus mapStatus;
@Nullable private ShuffleExternalSorter sorter;
@@ -108,8 +114,8 @@ public UnsafeShuffleWriter(
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
"UnsafeShuffleWriter can only be used for shuffles with at most " +
SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() +
" reduce partitions");
SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() +
" reduce partitions");
}
this.blockManager = blockManager;
this.shuffleBlockResolver = shuffleBlockResolver;
@@ -124,7 +130,9 @@ public UnsafeShuffleWriter(
this.sparkConf = sparkConf;
this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize",
DEFAULT_INITIAL_SORT_BUFFER_SIZE);
DEFAULT_INITIAL_SORT_BUFFER_SIZE);
this.checksum = sparkConf.getBoolean("spark.shuffle.checksum", true);

open();
}

@@ -289,7 +297,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
// Compression is disabled or we are using an IO compression codec that supports
// decompression of concatenated compressed streams, so we can perform a fast spill merge
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled) {
if (transferToEnabled && !checksum) {
logger.debug("Using transferTo-based fast merge");
partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
} else {
@@ -346,8 +354,11 @@ private long[] mergeSpillsWithFileStream(
}
for (int partition = 0; partition < numPartitions; partition++) {
final long initialFileLength = outputFile.length();
mergedFileOutputStream =
new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
OutputStream fos = new FileOutputStream(outputFile, true);
if (checksum) {
fos = new ChecksumOutputStream(fos, new Adler32());
}
mergedFileOutputStream = new TimeTrackingOutputStream(writeMetrics, fos);
if (compressionCodec != null) {
mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
}
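The merge path above is gated on the "spark.shuffle.checksum" flag read in the UnsafeShuffleWriter constructor. A minimal, illustrative sketch of toggling it (only the config key defined in this patch is assumed):

import org.apache.spark.SparkConf;

// Illustrative only: disabling the checksum re-enables the transferTo-based fast merge,
// since mergeSpills() now takes that path only when transferToEnabled && !checksum.
SparkConf conf = new SparkConf().set("spark.shuffle.checksum", "false"); // defaults to true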
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.storage;

import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.zip.Checksum;

/**
* An output stream that generates a checksum for the written data and writes the checksum as a
* long at the end of the stream.
*/
public class ChecksumOutputStream extends FilterOutputStream {
private Checksum cksum;
private boolean closed;

public ChecksumOutputStream(OutputStream out, Checksum cksum) {
super(out);
cksum.reset();
this.cksum = cksum;
this.closed = false;
}

public void write(int b) throws IOException {
out.write(b);
cksum.update(b);
}

public void write(byte[] b) throws IOException {
write(b, 0, b.length);
}

public void write(byte[] b, int off, int len) throws IOException {
out.write(b, off, len);
cksum.update(b, off, len);
}

public void close() throws IOException {
flush();
if (!closed) {
closed = true;
ByteBuffer buffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN);
buffer.putLong(cksum.getValue());
out.write(buffer.array());
out.close();
}
}
}
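For context on how this trailer is consumed, here is a minimal roundtrip sketch (illustrative only; the temp file and payload are made up): it writes through ChecksumOutputStream and then verifies the trailing long the same way FileSegmentManagedBuffer.createInputStream(true) does above.

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.zip.Adler32;
import java.util.zip.CheckedInputStream;
import java.util.zip.Checksum;

import org.apache.spark.storage.ChecksumOutputStream;

// Illustrative roundtrip only, not part of the patch.
public class ChecksumRoundTrip {
  public static void main(String[] args) throws IOException {
    File file = File.createTempFile("shuffle", ".data");
    byte[] payload = "some shuffle bytes".getBytes(StandardCharsets.UTF_8);

    // Write: the payload, followed by an 8-byte big-endian Adler32 value appended on close().
    try (OutputStream out = new ChecksumOutputStream(new FileOutputStream(file), new Adler32())) {
      out.write(payload);
    }

    // Read: checksum everything except the trailing 8 bytes, then compare with the trailer.
    Checksum ck = new Adler32();
    try (DataInputStream in =
        new DataInputStream(new CheckedInputStream(new FileInputStream(file), ck))) {
      byte[] data = new byte[(int) (file.length() - 8)];
      in.readFully(data);
      long sum = ck.getValue();       // checksum over the payload only
      long expected = in.readLong();  // trailer written by ChecksumOutputStream
      if (sum != expected) {
        throw new IOException("Checksum does not match " + sum + "!=" + expected);
      }
    }
    file.delete();
  }
}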