 import java.nio.channels.FileChannel;
 import java.util.Iterator;

-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import scala.Option;
 import scala.Product2;
 import scala.collection.JavaConversions;
 import scala.reflect.ClassTag;
 import scala.reflect.ClassTag$;

 import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 import org.apache.spark.*;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockManager;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;

 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {

+  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
   private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();

@@ -63,6 +70,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final int mapId;
   private final TaskContext taskContext;
   private final SparkConf sparkConf;
+  private final boolean transferToEnabled;

   private MapStatus mapStatus = null;

@@ -95,6 +103,7 @@ public UnsafeShuffleWriter(
     taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
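+    // When merging spills, copy with NIO transferTo() unless it has been explicitly
+    // disabled (see the kernel bug discussion in SPARK-3948).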
+    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
   }

   public void write(Iterator<Product2<K, V>> records) {
@@ -116,6 +125,10 @@ private void freeMemory() {
     // TODO
   }

+  private void deleteSpills() {
+    // TODO
+  }
+
   private SpillInfo[] insertRecordsIntoSorter(
       scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
     final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter(
@@ -154,55 +167,127 @@ private SpillInfo[] insertRecordsIntoSorter(

   private long[] mergeSpills(SpillInfo[] spills) throws IOException {
     final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
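+    // Fast paths: zero spills yield an empty output file, and a single spill can simply be
+    // renamed into place; only multiple spills require an actual merge.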
+    try {
+      if (spills.length == 0) {
+        new FileOutputStream(outputFile).close(); // Create an empty file
+        return new long[partitioner.numPartitions()];
+      } else if (spills.length == 1) {
+        // Note: we'll have to watch out for corner-cases in this code path when working on shuffle
+        // metrics integration, since any metrics updates that are performed during the merge will
+        // also have to be done here. In this branch, the shuffle technically didn't need to spill
+        // because we're only trying to merge one file, so we may need to ensure that metrics that
+        // would otherwise be counted as spill metrics are actually counted as regular write
+        // metrics.
+        Files.move(spills[0].file, outputFile);
+        return spills[0].partitionLengths;
+      } else {
+        // Need to merge multiple spills.
+        if (transferToEnabled) {
+          return mergeSpillsWithTransferTo(spills, outputFile);
+        } else {
+          return mergeSpillsWithFileStream(spills, outputFile);
+        }
+      }
+    } catch (IOException e) {
+      if (outputFile.exists() && !outputFile.delete()) {
+        logger.error("Unable to delete output file {}", outputFile.getPath());
+      }
+      throw e;
+    }
+  }
+
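+  /**
+   * Merges spill files using Java FileStreams. This merge path copies each partition's bytes
+   * through user space and is used when transferTo has been disabled via spark.file.transferTo.
+   */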
+  private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) throws IOException {
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
+    final FileInputStream[] spillInputStreams = new FileInputStream[spills.length];
+    FileOutputStream mergedFileOutputStream = null;
+
+    try {
+      for (int i = 0; i < spills.length; i++) {
+        spillInputStreams[i] = new FileInputStream(spills[i].file);
+      }
+      mergedFileOutputStream = new FileOutputStream(outputFile);

-    if (spills.length == 0) {
-      new FileOutputStream(outputFile).close();
-      return partitionLengths;
+      for (int partition = 0; partition < numPartitions; partition++) {
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          final FileInputStream spillInputStream = spillInputStreams[i];
+          ByteStreams.copy(
+            new LimitedInputStream(spillInputStream, partitionLengthInSpill),
+            mergedFileOutputStream);
+          partitionLengths[partition] += partitionLengthInSpill;
+        }
+      }
+    } finally {
+      for (int i = 0; i < spills.length; i++) {
+        if (spillInputStreams[i] != null) {
+          spillInputStreams[i].close();
+        }
+      }
+      if (mergedFileOutputStream != null) {
+        mergedFileOutputStream.close();
+      }
     }
+    return partitionLengths;
+  }

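+  /**
+   * Merges spill files by concatenating each partition's bytes with NIO's transferTo(),
+   * avoiding an extra copy through user space.
+   */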
+  private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+    final int numPartitions = partitioner.numPartitions();
+    final long[] partitionLengths = new long[numPartitions];
     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
     final long[] spillInputChannelPositions = new long[spills.length];
+    FileChannel mergedFileOutputChannel = null;

-    // TODO: We need to add an option to bypass transferTo here since older Linux kernels are
-    // affected by a bug here that can lead to data truncation; see the comments in Utils.scala,
-    // in the copyStream() method. I didn't use copyStream() here because we only want to copy
-    // a limited number of bytes from the stream and I didn't want to modify / extend that method
-    // to accept a length.
-
-    // TODO: special case optimization for the case where we only write one file (non-spill case).
-
-    for (int i = 0; i < spills.length; i++) {
-      spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
-    }
-
-    final FileChannel mergedFileOutputChannel = new FileOutputStream(outputFile).getChannel();
-
-    for (int partition = 0; partition < numPartitions; partition++) {
+    try {
       for (int i = 0; i < spills.length; i++) {
-        final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-        long bytesToTransfer = partitionLengthInSpill;
-        final FileChannel spillInputChannel = spillInputChannels[i];
-        while (bytesToTransfer > 0) {
-          final long actualBytesTransferred = spillInputChannel.transferTo(
+        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
+      }
+      // This file needs to be opened in append mode in order to work around a Linux kernel bug
+      // that affects transferTo; see SPARK-3948 for more details.
+      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
+
+      long bytesWrittenToMergedFile = 0;
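+      // Track the total number of bytes we expect to write so that the truncation check
+      // after the transferTo loop can detect a short write.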
+      for (int partition = 0; partition < numPartitions; partition++) {
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          long bytesToTransfer = partitionLengthInSpill;
+          final FileChannel spillInputChannel = spillInputChannels[i];
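+          // transferTo() is not guaranteed to move the full requested range in a single call,
+          // so keep transferring until this spill's partition segment is exhausted.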
+          while (bytesToTransfer > 0) {
+            final long actualBytesTransferred = spillInputChannel.transferTo(
               spillInputChannelPositions[i],
               bytesToTransfer,
               mergedFileOutputChannel);
-          spillInputChannelPositions[i] += actualBytesTransferred;
-          bytesToTransfer -= actualBytesTransferred;
+            spillInputChannelPositions[i] += actualBytesTransferred;
+            bytesToTransfer -= actualBytesTransferred;
+          }
+          bytesWrittenToMergedFile += partitionLengthInSpill;
+          partitionLengths[partition] += partitionLengthInSpill;
         }
-        partitionLengths[partition] += partitionLengthInSpill;
+      }
+      // Check the channel position after the transferTo loop and raise an exception if it does
+      // not match the number of bytes we expect to have written. On kernel version 2.6.32 the
+      // position is not advanced to the expected length after calling transferTo; this issue is
+      // described at https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+        throw new IOException(
+          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+          "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+          " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+          "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+          "to disable this NIO feature."
+        );
+      }
+    } finally {
+      for (int i = 0; i < spills.length; i++) {
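+        // Each spill file should have been consumed in its entirety by the merge above.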
+        assert(spillInputChannelPositions[i] == spills[i].file.length());
+        if (spillInputChannels[i] != null) {
+          spillInputChannels[i].close();
+        }
+      }
+      if (mergedFileOutputChannel != null) {
+        mergedFileOutputChannel.close();
       }
     }
-
-    // TODO: should this be in a finally block?
-    for (int i = 0; i < spills.length; i++) {
-      assert(spillInputChannelPositions[i] == spills[i].file.length());
-      spillInputChannels[i].close();
-    }
-    mergedFileOutputChannel.close();
-
     return partitionLengths;
   }

@@ -215,6 +300,9 @@ public Option<MapStatus> stop(boolean success) {
       stopping = true;
       freeMemory();
       if (success) {
+        if (mapStatus == null) {
+          throw new IllegalStateException("Cannot call stop(true) without having called write()");
+        }
         return Option.apply(mapStatus);
       } else {
         // The map task failed, so delete our output data.