
Commit 722849b

Add workaround for transferTo() bug in merging code; refactor tests.
1 parent 9883e30 commit 722849b

2 files changed: 225 additions, 94 deletions

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

Lines changed: 123 additions & 35 deletions
@@ -25,30 +25,37 @@
 import java.nio.channels.FileChannel;
 import java.util.Iterator;
 
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import scala.Option;
 import scala.Product2;
 import scala.collection.JavaConversions;
 import scala.reflect.ClassTag;
 import scala.reflect.ClassTag$;
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import org.apache.spark.*;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockManager;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
 
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
+  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
   private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
@@ -63,6 +70,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final int mapId;
   private final TaskContext taskContext;
   private final SparkConf sparkConf;
+  private final boolean transferToEnabled;
 
   private MapStatus mapStatus = null;
 
@@ -95,6 +103,7 @@ public UnsafeShuffleWriter(
     taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
+    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
   }
 
   public void write(Iterator<Product2<K, V>> records) {
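
Note: the constructor change above reads the new `spark.file.transferTo` flag (default `true`), which lets users opt out of the NIO `transferTo()` path on affected kernels. A minimal, hypothetical usage sketch (only the property key comes from this commit; the application name and setup are illustrative):

```java
import org.apache.spark.SparkConf;

public class DisableTransferToExample {
  public static void main(String[] args) {
    // Hypothetical configuration: fall back to the stream-based merge path
    // instead of FileChannel.transferTo(), e.g. on a kernel affected by SPARK-3948.
    SparkConf conf = new SparkConf()
      .setAppName("transferTo-workaround-example") // illustrative name
      .set("spark.file.transferTo", "false");      // property read in the constructor above
  }
}
```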
@@ -116,6 +125,10 @@ private void freeMemory() {
     // TODO
   }
 
+  private void deleteSpills() {
+    // TODO
+  }
+
   private SpillInfo[] insertRecordsIntoSorter(
       scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
     final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter(
@@ -154,55 +167,127 @@ private SpillInfo[] insertRecordsIntoSorter(
 
   private long[] mergeSpills(SpillInfo[] spills) throws IOException {
     final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
+    try {
+      if (spills.length == 0) {
+        new FileOutputStream(outputFile).close(); // Create an empty file
+        return new long[partitioner.numPartitions()];
+      } else if (spills.length == 1) {
+        // Note: we'll have to watch out for corner-cases in this code path when working on shuffle
+        // metrics integration, since any metrics updates that are performed during the merge will
+        // also have to be done here. In this branch, the shuffle technically didn't need to spill
+        // because we're only trying to merge one file, so we may need to ensure that metrics that
+        // would otherwise be counted as spill metrics are actually counted as regular write
+        // metrics.
+        Files.move(spills[0].file, outputFile);
+        return spills[0].partitionLengths;
+      } else {
+        // Need to merge multiple spills.
+        if (transferToEnabled) {
+          return mergeSpillsWithTransferTo(spills, outputFile);
+        } else {
+          return mergeSpillsWithFileStream(spills, outputFile);
+        }
+      }
+    } catch (IOException e) {
+      if (outputFile.exists() && !outputFile.delete()) {
+        logger.error("Unable to delete output file {}", outputFile.getPath());
+      }
+      throw e;
+    }
+  }
+
+  private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) throws IOException {
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
+    final FileInputStream[] spillInputStreams = new FileInputStream[spills.length];
+    FileOutputStream mergedFileOutputStream = null;
+
+    try {
+      for (int i = 0; i < spills.length; i++) {
+        spillInputStreams[i] = new FileInputStream(spills[i].file);
+      }
+      mergedFileOutputStream = new FileOutputStream(outputFile);
 
-    if (spills.length == 0) {
-      new FileOutputStream(outputFile).close();
-      return partitionLengths;
+      for (int partition = 0; partition < numPartitions; partition++) {
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          final FileInputStream spillInputStream = spillInputStreams[i];
+          ByteStreams.copy(
+            new LimitedInputStream(spillInputStream, partitionLengthInSpill),
+            mergedFileOutputStream);
+          partitionLengths[partition] += partitionLengthInSpill;
+        }
+      }
+    } finally {
+      for (int i = 0; i < spills.length; i++) {
+        if (spillInputStreams[i] != null) {
+          spillInputStreams[i].close();
+        }
+      }
+      if (mergedFileOutputStream != null) {
+        mergedFileOutputStream.close();
+      }
     }
+    return partitionLengths;
+  }
 
+  private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+    final int numPartitions = partitioner.numPartitions();
+    final long[] partitionLengths = new long[numPartitions];
     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
     final long[] spillInputChannelPositions = new long[spills.length];
+    FileChannel mergedFileOutputChannel = null;
 
-    // TODO: We need to add an option to bypass transferTo here since older Linux kernels are
-    // affected by a bug here that can lead to data truncation; see the comments Utils.scala,
-    // in the copyStream() method. I didn't use copyStream() here because we only want to copy
-    // a limited number of bytes from the stream and I didn't want to modify / extend that method
-    // to accept a length.
-
-    // TODO: special case optimization for case where we only write one file (non-spill case).
-
-    for (int i = 0; i < spills.length; i++) {
-      spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
-    }
-
-    final FileChannel mergedFileOutputChannel = new FileOutputStream(outputFile).getChannel();
-
-    for (int partition = 0; partition < numPartitions; partition++) {
+    try {
       for (int i = 0; i < spills.length; i++) {
-        final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-        long bytesToTransfer = partitionLengthInSpill;
-        final FileChannel spillInputChannel = spillInputChannels[i];
-        while (bytesToTransfer > 0) {
-          final long actualBytesTransferred = spillInputChannel.transferTo(
+        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
+      }
+      // This file needs to be opened in append mode in order to work around a Linux kernel bug
+      // that affects transferTo; see SPARK-3948 for more details.
+      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
+
+      long bytesWrittenToMergedFile = 0;
+      for (int partition = 0; partition < numPartitions; partition++) {
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          long bytesToTransfer = partitionLengthInSpill;
+          final FileChannel spillInputChannel = spillInputChannels[i];
+          while (bytesToTransfer > 0) {
+            final long actualBytesTransferred = spillInputChannel.transferTo(
              spillInputChannelPositions[i],
              bytesToTransfer,
              mergedFileOutputChannel);
-        spillInputChannelPositions[i] += actualBytesTransferred;
-        bytesToTransfer -= actualBytesTransferred;
+            spillInputChannelPositions[i] += actualBytesTransferred;
+            bytesToTransfer -= actualBytesTransferred;
+          }
+          bytesWrittenToMergedFile += partitionLengthInSpill;
+          partitionLengths[partition] += partitionLengthInSpill;
         }
-        partitionLengths[partition] += partitionLengthInSpill;
+      }
+      // Check the position after the transferTo loop to see if it is in the right position and
+      // raise an exception if it is incorrect. The position will not be increased to the expected
+      // length after calling transferTo in kernel version 2.6.32. This issue is described at
+      // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+        throw new IOException(
+          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+          "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+          " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+          "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+          "to disable this NIO feature."
+        );
+      }
+    } finally {
+      for (int i = 0; i < spills.length; i++) {
+        assert(spillInputChannelPositions[i] == spills[i].file.length());
+        if (spillInputChannels[i] != null) {
+          spillInputChannels[i].close();
+        }
+      }
+      if (mergedFileOutputChannel != null) {
+        mergedFileOutputChannel.close();
       }
     }
-
-    // TODO: should this be in a finally block?
-    for (int i = 0; i < spills.length; i++) {
-      assert(spillInputChannelPositions[i] == spills[i].file.length());
-      spillInputChannels[i].close();
-    }
-    mergedFileOutputChannel.close();
-
     return partitionLengths;
   }
 
@@ -215,6 +300,9 @@ public Option<MapStatus> stop(boolean success) {
     stopping = true;
     freeMemory();
     if (success) {
+      if (mapStatus == null) {
+        throw new IllegalStateException("Cannot call stop(true) without having called write()");
+      }
       return Option.apply(mapStatus);
     } else {
       // The map task failed, so delete our output data.
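
Note: the core of the workaround in `mergeSpillsWithTransferTo` above is to open the destination in append mode, loop `transferTo()` until the requested count has been copied, and then verify the destination channel's position, since the kernel bug tracked in SPARK-3948 / JDK-7052359 can otherwise truncate data silently. A standalone, hypothetical sketch of that pattern (the helper name and structure are illustrative, not part of the commit):

```java
import java.io.IOException;
import java.nio.channels.FileChannel;

final class TransferToHelper {
  /**
   * Copies `count` bytes from `in` (starting at `position`) into `out`, then checks that
   * `out`'s position advanced by exactly `count` bytes, as the commit above does for the
   * whole merge.
   */
  static void copyAndVerify(FileChannel in, long position, long count, FileChannel out)
      throws IOException {
    final long expectedPosition = out.position() + count;
    long transferred = 0;
    while (transferred < count) {
      // transferTo() may copy fewer bytes than requested, so keep looping.
      transferred += in.transferTo(position + transferred, count - transferred, out);
    }
    if (out.position() != expectedPosition) {
      // Mirrors the position check added in this commit; on affected kernels the position
      // is not advanced correctly and data can be silently lost.
      throw new IOException("transferTo() left position at " + out.position() +
        " instead of " + expectedPosition + "; consider setting spark.file.transferTo=false");
    }
  }
}
```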

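Note: the stream-based fallback (`mergeSpillsWithFileStream`) relies on copying exactly one partition's worth of bytes from each spill at a time, which is what `LimitedInputStream` provides. A hypothetical sketch of the same idea using Guava's `ByteStreams.limit` in place of Spark's `LimitedInputStream` (the helper is illustrative, not part of the commit):

```java
import com.google.common.io.ByteStreams;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

final class PartitionSegmentCopier {
  /**
   * Copies at most `length` bytes from `in` into `out` and returns the number copied,
   * mirroring what the merge loop above does once per (partition, spill) pair.
   */
  static long copySegment(InputStream in, OutputStream out, long length) throws IOException {
    // Wrap the input so the copy cannot read past this partition's bytes.
    return ByteStreams.copy(ByteStreams.limit(in, length), out);
  }
}
```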