
Commit ea90ea6

mccheah authored and Marcelo Vanzin committed

[SPARK-28571][CORE][SHUFFLE] Use the shuffle writer plugin for the SortShuffleWriter
## What changes were proposed in this pull request?

Use the shuffle writer APIs introduced in SPARK-28209 in the sort shuffle writer.

## How was this patch tested?

Existing unit tests were changed to use the plugin instead, and they used the local disk version to ensure that there were no regressions.

Closes #25342 from mccheah/shuffle-writer-refactor-sort-shuffle-writer.

Lead-authored-by: mcheah <[email protected]>
Co-authored-by: mccheah <[email protected]>
Signed-off-by: Marcelo Vanzin <[email protected]>
1 parent 92cabf6 commit ea90ea6
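For orientation, the core of the change is that `SortShuffleWriter.write` now drives the SPARK-28209 plugin API instead of writing the shuffle data and index files itself. The snippet below is condensed from the SortShuffleWriter and ExternalSorter diffs further down (names such as `dep`, `sorter`, and `shuffleExecutorComponents` are as they appear in SortShuffleWriter; setup and cleanup around the sorter are omitted):

```scala
// Condensed from the SortShuffleWriter/ExternalSorter diffs below.
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
  dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions)

// ExternalSorter serializes each reduce partition through a
// ShufflePartitionPairsWriter, which pushes bytes into the plugin's
// per-partition output streams rather than into a local data file.
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)

// Committing the map output yields the per-partition lengths needed for the MapStatus.
val partitionLengths = mapOutputWriter.commitAllPartitions()
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
```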

File tree

8 files changed: +265 -31 lines changed

core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala (new file)

Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle

import java.io.{Closeable, IOException, OutputStream}

import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.api.ShufflePartitionWriter
import org.apache.spark.storage.BlockId
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.PairsWriter

/**
 * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an
 * arbitrary partition writer instead of writing to local disk through the block manager.
 */
private[spark] class ShufflePartitionPairsWriter(
    partitionWriter: ShufflePartitionWriter,
    serializerManager: SerializerManager,
    serializerInstance: SerializerInstance,
    blockId: BlockId,
    writeMetrics: ShuffleWriteMetricsReporter)
  extends PairsWriter with Closeable {

  private var isClosed = false
  private var partitionStream: OutputStream = _
  private var wrappedStream: OutputStream = _
  private var objOut: SerializationStream = _
  private var numRecordsWritten = 0
  private var curNumBytesWritten = 0L

  override def write(key: Any, value: Any): Unit = {
    if (isClosed) {
      throw new IOException("Partition pairs writer is already closed.")
    }
    if (objOut == null) {
      open()
    }
    objOut.writeKey(key)
    objOut.writeValue(value)
    recordWritten()
  }

  private def open(): Unit = {
    try {
      partitionStream = partitionWriter.openStream
      wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
      objOut = serializerInstance.serializeStream(wrappedStream)
    } catch {
      case e: Exception =>
        Utils.tryLogNonFatalError {
          close()
        }
        throw e
    }
  }

  override def close(): Unit = {
    if (!isClosed) {
      Utils.tryWithSafeFinally {
        Utils.tryWithSafeFinally {
          objOut = closeIfNonNull(objOut)
          // Setting these to null will prevent the underlying streams from being closed twice
          // just in case any stream's close() implementation is not idempotent.
          wrappedStream = null
          partitionStream = null
        } {
          // Normally closing objOut would close the inner streams as well, but just in case there
          // was an error in initialization etc. we make sure we clean the other streams up too.
          Utils.tryWithSafeFinally {
            wrappedStream = closeIfNonNull(wrappedStream)
            // Same as above - if wrappedStream closes then assume it closes underlying
            // partitionStream and don't close again in the finally
            partitionStream = null
          } {
            partitionStream = closeIfNonNull(partitionStream)
          }
        }
        updateBytesWritten()
      } {
        isClosed = true
      }
    }
  }

  private def closeIfNonNull[T <: Closeable](closeable: T): T = {
    if (closeable != null) {
      closeable.close()
    }
    null.asInstanceOf[T]
  }

  /**
   * Notify the writer that a record worth of bytes has been written with OutputStream#write.
   */
  private def recordWritten(): Unit = {
    numRecordsWritten += 1
    writeMetrics.incRecordsWritten(1)

    if (numRecordsWritten % 16384 == 0) {
      updateBytesWritten()
    }
  }

  private def updateBytesWritten(): Unit = {
    val numBytesWritten = partitionWriter.getNumBytesWritten
    val bytesWrittenDiff = numBytesWritten - curNumBytesWritten
    writeMetrics.incBytesWritten(bytesWrittenDiff)
    curNumBytesWritten = numBytesWritten
  }
}
```
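For context, a ShufflePartitionPairsWriter is created and closed once per reduce partition. The sketch below mirrors how `ExternalSorter.writePartitionedMapOutput` (later in this diff) drives it; `mapOutputWriter`, `shuffleId`, `mapId`, `reduceId`, and `records` are placeholders for values supplied by the calling task, and `serializerManager`, `serializerInstance`, and `context` come from the task's environment:

```scala
// Sketch mirroring ExternalSorter.writePartitionedMapOutput later in this commit.
// All lowercase identifiers here are placeholders supplied by the caller.
val partitionWriter = mapOutputWriter.getPartitionWriter(reduceId)
val pairsWriter = new ShufflePartitionPairsWriter(
  partitionWriter,
  serializerManager,
  serializerInstance,
  ShuffleBlockId(shuffleId, mapId, reduceId),
  context.taskMetrics().shuffleWriteMetrics)
try {
  records.foreach { case (key, value) => pairsWriter.write(key, value) }
} finally {
  // close() flushes and closes the serialization stream and reports the final
  // byte count to the write metrics via ShufflePartitionWriter.getNumBytesWritten.
  pairsWriter.close()
}
```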

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -157,7 +157,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           metrics,
           shuffleExecutorComponents)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+        new SortShuffleWriter(
+          shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
     }
   }
 
```

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 8 additions & 15 deletions
```diff
@@ -21,15 +21,15 @@ import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
-import org.apache.spark.storage.ShuffleBlockId
-import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
     shuffleBlockResolver: IndexShuffleBlockResolver,
     handle: BaseShuffleHandle[K, V, C],
     mapId: Int,
-    context: TaskContext)
+    context: TaskContext,
+    shuffleExecutorComponents: ShuffleExecutorComponents)
   extends ShuffleWriter[K, V] with Logging {
 
   private val dep = handle.dependency
@@ -64,18 +64,11 @@
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
     // (see SPARK-3570).
-    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
-    val tmp = Utils.tempFileWith(output)
-    try {
-      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
-      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
-      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
-    } finally {
-      if (tmp.exists() && !tmp.delete()) {
-        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
-      }
-    }
+    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
+      dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions)
+    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+    val partitionLengths = mapOutputWriter.commitAllPartitions()
+    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 
   /** Close this writer, passing along whether the map completed */
```

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.PairsWriter
 
 /**
  * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
@@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter(
     writeMetrics: ShuffleWriteMetricsReporter,
     val blockId: BlockId = null)
   extends OutputStream
-  with Logging {
+  with Logging
+  with PairsWriter {
 
   /**
    * Guards against close calls, e.g. from a wrapping stream.
@@ -232,7 +234,7 @@ private[spark] class DiskBlockObjectWriter(
   /**
   * Writes a key-value pair.
   */
-  def write(key: Any, value: Any) {
+  override def write(key: Any, value: Any) {
    if (!streamOpen) {
      open()
    }
```
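Because DiskBlockObjectWriter now mixes in the PairsWriter trait introduced later in this commit, code that spills or persists partitioned data can be written once against the trait and fed either a local-disk writer or a plugin-backed ShufflePartitionPairsWriter. A minimal, hypothetical illustration (the helper `copyAll` is not part of this commit):

```scala
// Hypothetical helper showing why writeNext now accepts a PairsWriter:
// the same loop works for a DiskBlockObjectWriter or a ShufflePartitionPairsWriter.
def copyAll(records: Iterator[(Any, Any)], writer: PairsWriter): Unit = {
  records.foreach { case (key, value) => writer.write(key, value) }
}
```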

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 80 additions & 8 deletions
```diff
@@ -23,13 +23,16 @@ import java.util.Comparator
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.common.io.ByteStreams
+import com.google.common.io.{ByteStreams, Closeables}
 
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
+import org.apache.spark.shuffle.ShufflePartitionPairsWriter
+import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
+import org.apache.spark.util.{Utils => TryUtils}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -670,11 +673,9 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   /**
-   * Write all the data added into this ExternalSorter into a file in the disk store. This is
-   * called by the SortShuffleWriter.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   * TODO(SPARK-28764): remove this, as this is only used by UnsafeRowSerializerSuite in the SQL
+   * project. We should figure out an alternative way to test that so that we can remove this
+   * otherwise unused code path.
    */
   def writePartitionedFile(
       blockId: BlockId,
@@ -718,6 +719,77 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
+  /**
+   * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
+   * to some arbitrary backing store. This is called by the SortShuffleWriter.
+   *
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   */
+  def writePartitionedMapOutput(
+      shuffleId: Int,
+      mapId: Int,
+      mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+    var nextPartitionId = 0
+    if (spills.isEmpty) {
+      // Case where we only have in-memory data
+      val collection = if (aggregator.isDefined) map else buffer
+      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+      while (it.hasNext()) {
+        val partitionId = it.nextPartition()
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        TryUtils.tryWithSafeFinally {
+          partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
+          val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          while (it.hasNext && it.nextPartition() == partitionId) {
+            it.writeNext(partitionPairsWriter)
+          }
+        } {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+        }
+        nextPartitionId = partitionId + 1
+      }
+    } else {
+      // We must perform merge-sort; get an iterator by partition and write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        val blockId = ShuffleBlockId(shuffleId, mapId, id)
+        var partitionWriter: ShufflePartitionWriter = null
+        var partitionPairsWriter: ShufflePartitionPairsWriter = null
+        TryUtils.tryWithSafeFinally {
+          partitionWriter = mapOutputWriter.getPartitionWriter(id)
+          partitionPairsWriter = new ShufflePartitionPairsWriter(
+            partitionWriter,
+            serializerManager,
+            serInstance,
+            blockId,
+            context.taskMetrics().shuffleWriteMetrics)
+          if (elements.hasNext) {
+            for (elem <- elements) {
+              partitionPairsWriter.write(elem._1, elem._2)
+            }
+          }
+        } {
+          if (partitionPairsWriter != null) {
+            partitionPairsWriter.close()
+          }
+        }
+        nextPartitionId = id + 1
+      }
+    }
+
+    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
@@ -781,7 +853,7 @@ private[spark] class ExternalSorter[K, V, C](
       val inMemoryIterator = new WritablePartitionedIterator {
         private[this] var cur = if (upstream.hasNext) upstream.next() else null
 
-        def writeNext(writer: DiskBlockObjectWriter): Unit = {
+        def writeNext(writer: PairsWriter): Unit = {
           writer.write(cur._1._2, cur._2)
           cur = if (upstream.hasNext) upstream.next() else null
         }
```

core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala (new file)

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util.collection

/**
 * An abstraction of a consumer of key-value pairs, primarily used when
 * persisting partitioned data, either through the shuffle writer plugins
 * or via DiskBlockObjectWriter.
 */
private[spark] trait PairsWriter {

  def write(key: Any, value: Any): Unit
}
```
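As a quick illustration of the contract (DiskBlockObjectWriter and ShufflePartitionPairsWriter are the implementations actually touched by this commit), a hypothetical PairsWriter can be as small as:

```scala
// Hypothetical example implementation, not part of this commit: counts the
// key-value pairs pushed through the PairsWriter interface.
private[spark] class CountingPairsWriter extends PairsWriter {
  private var pairs = 0L

  override def write(key: Any, value: Any): Unit = {
    pairs += 1
  }

  def numPairs: Long = pairs
}
```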

core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
     new WritablePartitionedIterator {
       private[this] var cur = if (it.hasNext) it.next() else null
 
-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
+      def writeNext(writer: PairsWriter): Unit = {
         writer.write(cur._1._2, cur._2)
         cur = if (it.hasNext) it.next() else null
       }
@@ -89,7 +89,7 @@ private[spark] object WritablePartitionedPairCollection {
 * has an associated partition.
 */
 private[spark] trait WritablePartitionedIterator {
-  def writeNext(writer: DiskBlockObjectWriter): Unit
+  def writeNext(writer: PairsWriter): Unit
 
   def hasNext(): Boolean
 
```