Commit bda37bb

Implement Mridul's ExternalAppendOnlyMap fixes in ExternalSorter too
Modified ExternalSorterSuite to also set a low object stream reset and batch size, and verified that it failed before the changes and succeeded after.
1 parent 0d6dad7 commit bda37bb

3 files changed: +120 -48 lines

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 2 additions & 1 deletion
@@ -404,7 +404,8 @@ class ExternalAppendOnlyMap[K, V, C](
    * An iterator that returns (K, C) pairs in sorted order from an on-disk map
    */
   private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
-    extends Iterator[(K, C)] {
+    extends Iterator[(K, C)]
+  {
     private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1
     assert(file.length() == batchOffsets(batchOffsets.length - 1))

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 89 additions & 29 deletions
@@ -26,7 +26,7 @@ import scala.collection.mutable
 import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.BlockId
 
 /**
@@ -273,13 +273,16 @@ private[spark] class ExternalSorter[K, V, C](
     // Flush the disk writer's contents to disk, and update relevant variables.
     // The writer is closed at the end of this process, and cannot be reused.
     def flush() = {
-      writer.commitAndClose()
-      val bytesWritten = writer.bytesWritten
+      val w = writer
+      writer = null
+      w.commitAndClose()
+      val bytesWritten = w.bytesWritten
       batchSizes.append(bytesWritten)
       _diskBytesSpilled += bytesWritten
       objectsWritten = 0
     }
 
+    var success = false
     try {
       val it = collection.destructiveSortedIterator(partitionKeyComparator)
       while (it.hasNext) {
@@ -299,13 +302,23 @@ private[spark] class ExternalSorter[K, V, C](
       }
       if (objectsWritten > 0) {
         flush()
+      } else if (writer != null) {
+        val w = writer
+        writer = null
+        w.revertPartialWritesAndClose()
+      }
+      success = true
+    } finally {
+      if (!success) {
+        // This code path only happens if an exception was thrown above before we set success;
+        // close our stuff and let the exception be thrown further
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+        }
+        if (file.exists()) {
+          file.delete()
+        }
       }
-      writer.close()
-    } catch {
-      case e: Exception =>
-        writer.close()
-        file.delete()
-        throw e
     }
 
     if (usingMap) {
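
Note on the two hunks above: flush() now detaches the writer from the enclosing field before committing it, and the write loop tracks a success flag so the finally block only reverts and deletes on failure, replacing the old catch-and-rethrow path that could close a writer twice. A minimal standalone sketch of that hand-off pattern follows; Writer, openWriter and spillFile are hypothetical stand-ins here, not the actual Spark types.

import java.io.File

object SpillCleanupSketch {
  // Hypothetical stand-in for the disk writer used by the real spill code
  trait Writer {
    def write(record: Any): Unit
    def commitAndClose(): Unit
    def revertPartialWritesAndClose(): Unit
  }

  def spill(records: Iterator[Any], spillFile: File, openWriter: () => Writer): Unit = {
    var writer: Writer = openWriter()

    def flush(): Unit = {
      val w = writer
      writer = null            // detach first, so the finally block below cannot touch it again
      w.commitAndClose()
    }

    var success = false
    try {
      records.foreach(writer.write)
      flush()
      success = true
    } finally {
      if (!success) {
        // Only reached if an exception was thrown above: undo any partial writes,
        // remove the half-written file, and let the exception propagate.
        if (writer != null) {
          writer.revertPartialWritesAndClose()
        }
        if (spillFile.exists()) {
          spillFile.delete()
        }
      }
    }
  }
}
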
@@ -472,36 +485,58 @@ private[spark] class ExternalSorter[K, V, C](
    * partitions to be requested in order.
    */
   private[this] class SpillReader(spill: SpilledFile) {
-    val fileStream = new FileInputStream(spill.file)
-    val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+    // Serializer batch offsets; size will be batchSize.length + 1
+    val batchOffsets = spill.serializerBatchSizes.scanLeft(0L)(_ + _)
 
     // Track which partition and which batch stream we're in. These will be the indices of
     // the next element we will read. We'll also store the last partition read so that
     // readNextPartition() can figure out what partition that was from.
     var partitionId = 0
     var indexInPartition = 0L
-    var batchStreamsRead = 0
+    var batchId = 0
     var indexInBatch = 0
     var lastPartitionId = 0
 
     skipToNextPartition()
 
-    // An intermediate stream that reads from exactly one batch
+
+    // Intermediate file and deserializer streams that read from exactly one batch
     // This guards against pre-fetching and other arbitrary behavior of higher level streams
-    var batchStream = nextBatchStream()
-    var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
-    var deserStream = serInstance.deserializeStream(compressedStream)
+    var fileStream: FileInputStream = null
+    var deserializeStream = nextBatchStream() // Also sets fileStream
+
     var nextItem: (K, C) = null
     var finished = false
 
     /** Construct a stream that only reads from the next batch */
-    def nextBatchStream(): InputStream = {
-      if (batchStreamsRead < spill.serializerBatchSizes.length) {
-        batchStreamsRead += 1
-        ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
+    def nextBatchStream(): DeserializationStream = {
+      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
+      // we're still in a valid batch.
+      if (batchId < batchOffsets.length - 1) {
+        if (deserializeStream != null) {
+          deserializeStream.close()
+          fileStream.close()
+          deserializeStream = null
+          fileStream = null
+        }
+
+        val start = batchOffsets(batchId)
+        fileStream = new FileInputStream(spill.file)
+        fileStream.getChannel.position(start)
+        batchId += 1
+
+        val end = batchOffsets(batchId)
+
+        assert(end >= start, "start = " + start + ", end = " + end +
+          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
+
+        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
+        val compressedStream = blockManager.wrapForCompression(spill.blockId, bufferedStream)
+        serInstance.deserializeStream(compressedStream)
       } else {
-        // No more batches left; give an empty stream
-        bufferedStream
+        // No more batches left
+        cleanup()
+        null
      }
     }
@@ -525,27 +560,27 @@ private[spark] class ExternalSorter[K, V, C](
      * If no more pairs are left, return null.
      */
     private def readNextItem(): (K, C) = {
-      if (finished) {
+      if (finished || deserializeStream == null) {
         return null
       }
-      val k = deserStream.readObject().asInstanceOf[K]
-      val c = deserStream.readObject().asInstanceOf[C]
+      val k = deserializeStream.readObject().asInstanceOf[K]
+      val c = deserializeStream.readObject().asInstanceOf[C]
       lastPartitionId = partitionId
       // Start reading the next batch if we're done with this one
       indexInBatch += 1
       if (indexInBatch == serializerBatchSize) {
-        batchStream = nextBatchStream()
-        compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
-        deserStream = serInstance.deserializeStream(compressedStream)
         indexInBatch = 0
+        deserializeStream = nextBatchStream()
       }
       // Update the partition location of the element we're reading
       indexInPartition += 1
       skipToNextPartition()
       // If we've finished reading the last partition, remember that we're done
       if (partitionId == numPartitions) {
         finished = true
-        deserStream.close()
+        if (deserializeStream != null) {
+          deserializeStream.close()
+        }
       }
       (k, c)
     }
@@ -578,6 +613,31 @@ private[spark] class ExternalSorter[K, V, C](
         item
       }
     }
+
+    // Clean up our open streams and put us in a state where we can't read any more data
+    def cleanup() {
+      batchId = batchOffsets.length // Prevent reading any other batch
+      val ds = deserializeStream
+      val fs = fileStream
+      deserializeStream = null
+      fileStream = null
+
+      if (ds != null) {
+        try {
+          ds.close()
+        } catch {
+          case e: IOException =>
+            // Make sure we at least close the file handle
+            if (fs != null) {
+              try { fs.close() } catch { case e2: IOException => }
+            }
+            throw e
+        }
+      }
+
+      // NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop().
+      // This should also be fixed in ExternalAppendOnlyMap.
+    }
   }
 
   /**

core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala

Lines changed: 29 additions & 18 deletions
@@ -25,6 +25,17 @@ import org.apache.spark._
 import org.apache.spark.SparkContext._
 
 class ExternalSorterSuite extends FunSuite with LocalSparkContext {
+  private def createSparkConf(loadDefaults: Boolean): SparkConf = {
+    val conf = new SparkConf(loadDefaults)
+    // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+    // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+    conf.set("spark.serializer.objectStreamReset", "0")
+    conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+    // Ensure that we actually have multiple batches per spill file
+    conf.set("spark.shuffle.spill.batchSize", "10")
+    conf
+  }
+
   test("empty data stream") {
     val conf = new SparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -60,7 +71,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("few elements per partition") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -102,7 +113,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("empty partitions with spilling") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -127,7 +138,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("spilling in local cluster") {
-    val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+    val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
@@ -198,7 +209,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("spilling in local cluster with many reduce tasks") {
-    val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+    val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
@@ -269,7 +280,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("cleanup of intermediate files in sorter") {
-    val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+    val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -290,7 +301,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("cleanup of intermediate files in sorter if there are errors") {
-    val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+    val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -311,7 +322,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("cleanup of intermediate files in shuffle") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -326,7 +337,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("cleanup of intermediate files in shuffle with errors") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -348,7 +359,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("no partial aggregation or sorting") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -363,7 +374,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("partial aggregation without spill") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -379,7 +390,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("partial aggregation with spill, no ordering") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -395,7 +406,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("partial aggregation with spill, with ordering") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -412,7 +423,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("sorting without aggregation, no spill") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -429,7 +440,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("sorting without aggregation, with spill") {
-    val conf = new SparkConf(false)
+    val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -446,7 +457,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("spilling with hash collisions") {
-    val conf = new SparkConf(true)
+    val conf = createSparkConf(true)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
 
@@ -503,7 +514,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("spilling with many hash collisions") {
-    val conf = new SparkConf(true)
+    val conf = createSparkConf(true)
     conf.set("spark.shuffle.memoryFraction", "0.0001")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
 
@@ -526,7 +537,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
   test("spilling with hash collisions using the Int.MaxValue key") {
-    val conf = new SparkConf(true)
+    val conf = createSparkConf(true)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
 
@@ -547,7 +558,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
   }
 
  test("spilling with null keys and values") {
-    val conf = new SparkConf(true)
+    val conf = createSparkConf(true)
    conf.set("spark.shuffle.memoryFraction", "0.001")
    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)