Commit 6fcb70e

Ngone51 authored and cloud-fan committed
[SPARK-32055][CORE][SQL] Unify getReader and getReaderForRange in ShuffleManager
### What changes were proposed in this pull request?

This PR unifies the methods `getReader` and `getReaderForRange` in `ShuffleManager`.

### Why are the changes needed?

Reduce duplicated code, simplify the implementation, and improve maintainability.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Covered by existing tests.

Closes #28895 from Ngone51/unify-getreader.

Authored-by: yi.wu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
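To make the resulting API concrete, here is a minimal usage sketch of the unified read path (call shapes taken from the diff below; the wrapper names `readAllMaps`/`readMapRange` are illustrative and not part of the PR, and since `ShuffleManager` is `private[spark]` this assumes Spark-internal visibility):

```scala
import org.apache.spark.{ShuffleDependency, SparkEnv, TaskContext}
import org.apache.spark.shuffle.{ShuffleReader, ShuffleReadMetricsReporter}

// Common case: read one reduce partition from every map output. The new final
// convenience overload delegates with the full map range [0, Int.MaxValue).
def readAllMaps[K, C](
    dep: ShuffleDependency[K, _, C],
    partitionId: Int,
    context: TaskContext,
    metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
  SparkEnv.get.shuffleManager.getReader(
    dep.shuffleHandle, partitionId, partitionId + 1, context, metrics)
}

// Range case (what callers of the old getReaderForRange now use): read reduce
// partitions [startPartition, endPartition) produced only by mappers
// [startMapIndex, endMapIndex).
def readMapRange[K, C](
    dep: ShuffleDependency[K, _, C],
    startMapIndex: Int,
    endMapIndex: Int,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext,
    metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
  SparkEnv.get.shuffleManager.getReader(
    dep.shuffleHandle, startMapIndex, endMapIndex, startPartition, endPartition,
    context, metrics)
}
```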
1 parent 4204a63 commit 6fcb70e

File tree

6 files changed: +40 -87 lines changed


core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 15 additions & 62 deletions

@@ -322,36 +322,22 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   // For testing
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
       : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
+    getMapSizesByExecutorId(shuffleId, 0, Int.MaxValue, reduceId, reduceId + 1)
   }
 
   /**
    * Called from executors to get the server URIs and output sizes for each shuffle block that
    * needs to be read from a given range of map output partitions (startPartition is included but
-   * endPartition is excluded from the range).
+   * endPartition is excluded from the range) within a range of mappers (startMapIndex is included
+   * but endMapIndex is excluded). If endMapIndex=Int.MaxValue, the actual endMapIndex will be
+   * changed to the length of total map outputs.
    *
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
    *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
    *         tuples describing the shuffle blocks that are stored at that block manager.
+   *         Note that zero-sized blocks are excluded in the result.
    */
   def getMapSizesByExecutorId(
-      shuffleId: Int,
-      startPartition: Int,
-      endPartition: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
-
-  /**
-   * Called from executors to get the server URIs and output sizes for each shuffle block that
-   * needs to be read from a given range of map output partitions (startPartition is included but
-   * endPartition is excluded from the range) and is produced by
-   * a range of mappers (startMapIndex, endMapIndex, startMapIndex is included and
-   * the endMapIndex is excluded).
-   *
-   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
-   *         tuples describing the shuffle blocks that are stored at that block manager.
-   */
-  def getMapSizesByRange(
       shuffleId: Int,
       startMapIndex: Int,
       endMapIndex: Int,
@@ -734,38 +720,22 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
-  // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
   // This method is only called in local-mode.
   def getMapSizesByExecutorId(
-      shuffleId: Int,
-      startPartition: Int,
-      endPartition: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
-    shuffleStatuses.get(shuffleId) match {
-      case Some (shuffleStatus) =>
-        shuffleStatus.withMapStatuses { statuses =>
-          MapOutputTracker.convertMapStatuses(
-            shuffleId, startPartition, endPartition, statuses, 0, shuffleStatus.mapStatuses.length)
-        }
-      case None =>
-        Iterator.empty
-    }
-  }
-
-  override def getMapSizesByRange(
       shuffleId: Int,
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, mappers $startMapIndex-$endMapIndex" +
-      s"partitions $startPartition-$endPartition")
+    logDebug(s"Fetching outputs for shuffle $shuffleId")
     shuffleStatuses.get(shuffleId) match {
       case Some(shuffleStatus) =>
        shuffleStatus.withMapStatuses { statuses =>
+          val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+          logDebug(s"Convert map statuses for shuffle $shuffleId, " +
+            s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
           MapOutputTracker.convertMapStatuses(
-            shuffleId, startPartition, endPartition, statuses, startMapIndex, endMapIndex)
+            shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
         }
       case None =>
         Iterator.empty
@@ -798,37 +768,20 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
    */
   private val fetchingLock = new KeyLock[Int]
 
-  // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
   override def getMapSizesByExecutorId(
-      shuffleId: Int,
-      startPartition: Int,
-      endPartition: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
-    val statuses = getStatuses(shuffleId, conf)
-    try {
-      MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, 0, statuses.length)
-    } catch {
-      case e: MetadataFetchFailedException =>
-        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
-        mapStatuses.clear()
-        throw e
-    }
-  }
-
-  override def getMapSizesByRange(
       shuffleId: Int,
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, mappers $startMapIndex-$endMapIndex" +
-      s"partitions $startPartition-$endPartition")
+    logDebug(s"Fetching outputs for shuffle $shuffleId")
     val statuses = getStatuses(shuffleId, conf)
     try {
+      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+      logDebug(s"Convert map statuses for shuffle $shuffleId, " +
+        s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, endMapIndex)
+        shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
     } catch {
      case e: MetadataFetchFailedException =>
        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
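For illustration, a minimal executor-side call sketch against the unified lookup (the shuffle and partition numbers are arbitrary; per the doc comment above, an `endMapIndex` of `Int.MaxValue` is capped to the actual number of map outputs inside the tracker):

```scala
import org.apache.spark.SparkEnv

// All map outputs for reduce partition 3 of shuffle 0: pass the full map range.
val allMaps = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
  shuffleId = 0, startMapIndex = 0, endMapIndex = Int.MaxValue,
  startPartition = 3, endPartition = 4)

// Only the outputs of mappers [2, 5) for the same reduce partition.
val someMaps = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
  shuffleId = 0, startMapIndex = 2, endMapIndex = 5,
  startPartition = 3, endPartition = 4)

// Each element is (BlockManagerId, Seq[(BlockId, blockSize, mapIndex)]);
// zero-sized blocks are excluded from the result.
```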

core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala

Lines changed: 13 additions & 5 deletions

@@ -43,23 +43,31 @@ private[spark] trait ShuffleManager {
       context: TaskContext,
       metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]
 
+
   /**
-   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
+   * read from all map outputs of the shuffle.
+   *
    * Called on executors by reduce tasks.
    */
-  def getReader[K, C](
+  final def getReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
-      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    getReader(handle, 0, Int.MaxValue, startPartition, endPartition, context, metrics)
+  }
 
   /**
    * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
-   * read from map output (startMapIndex to endMapIndex - 1, inclusive).
+   * read from a range of map outputs(startMapIndex to endMapIndex-1, inclusive).
+   * If endMapIndex=Int.MaxValue, the actual endMapIndex will be changed to the length of total map
+   * outputs of the shuffle in `getMapSizesByExecutorId`.
+   *
    * Called on executors by reduce tasks.
    */
-  def getReaderForRange[K, C](
+  def getReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala

Lines changed: 7 additions & 15 deletions

@@ -23,6 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents}
 import org.apache.spark.util.Utils
@@ -115,31 +116,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
   }
 
   /**
-   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
+   * read from a range of map outputs(startMapIndex to endMapIndex-1, inclusive).
+   * If endMapIndex=Int.MaxValue, the actual endMapIndex will be changed to the length of total map
+   * outputs of the shuffle in `getMapSizesByExecutorId`.
+   *
    * Called on executors by reduce tasks.
    */
   override def getReader[K, C](
-      handle: ShuffleHandle,
-      startPartition: Int,
-      endPartition: Int,
-      context: TaskContext,
-      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
-    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
-      handle.shuffleId, startPartition, endPartition)
-    new BlockStoreShuffleReader(
-      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
-      shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
-  }
-
-  override def getReaderForRange[K, C](
       handle: ShuffleHandle,
       startMapIndex: Int,
       endMapIndex: Int,
       startPartition: Int,
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
-    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange(
+    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
       handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
     new BlockStoreShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,

core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -317,7 +317,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
       Array(size10000, size0, size1000, size0), 6))
     assert(tracker.containsShuffle(10))
-    assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq ===
+    assert(tracker.getMapSizesByExecutorId(10, 0, 2, 0, 4).toSeq ===
       Seq(
         (BlockManagerId("a", "hostA", 1000),
           Seq((ShuffleBlockId(10, 5, 1), size1000, 0),

core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -104,7 +104,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     // shuffle data to read.
     val mapOutputTracker = mock(classOf[MapOutputTracker])
     when(mapOutputTracker.getMapSizesByExecutorId(
-      shuffleId, reduceId, reduceId + 1)).thenReturn {
+      shuffleId, 0, numMaps, reduceId, reduceId + 1)).thenReturn {
       // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
       // for the code to read data over the network.
       val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
@@ -132,7 +132,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     val taskContext = TaskContext.empty()
     val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
     val blocksByAddress = mapOutputTracker.getMapSizesByExecutorId(
-      shuffleId, reduceId, reduceId + 1)
+      shuffleId, 0, numMaps, reduceId, reduceId + 1)
     val shuffleReader = new BlockStoreShuffleReader(
       shuffleHandle,
       blocksByAddress,

sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala

Lines changed: 2 additions & 2 deletions

@@ -191,7 +191,7 @@ class ShuffledRowRDD(
           sqlMetricsReporter)
 
       case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) =>
-        SparkEnv.get.shuffleManager.getReaderForRange(
+        SparkEnv.get.shuffleManager.getReader(
           dependency.shuffleHandle,
           startMapIndex,
           endMapIndex,
@@ -201,7 +201,7 @@ class ShuffledRowRDD(
           sqlMetricsReporter)
 
       case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) =>
-        SparkEnv.get.shuffleManager.getReaderForRange(
+        SparkEnv.get.shuffleManager.getReader(
           dependency.shuffleHandle,
           mapIndex,
           mapIndex + 1,
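For context, an illustrative sketch of how the two AQE partition specs above map onto the unified `getReader` argument order (startMapIndex, endMapIndex, startPartition, endPartition). The helper `readRangesFor` is hypothetical, covers only the two specs touched by this diff, and assumes the spec classes defined alongside `ShuffledRowRDD` are in scope:

```scala
// Hypothetical helper: returns (startMapIndex, endMapIndex, startPartition, endPartition)
// for the two specs rewritten in this commit; other spec types are omitted here.
def readRangesFor(spec: ShufflePartitionSpec): Option[(Int, Int, Int, Int)] = spec match {
  // Skew handling: one reduce partition, read from a sub-range of mappers.
  case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) =>
    Some((startMapIndex, endMapIndex, reducerIndex, reducerIndex + 1))
  // Local shuffle read: one mapper, read a range of reduce partitions.
  case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) =>
    Some((mapIndex, mapIndex + 1, startReducerIndex, endReducerIndex))
  case _ => None
}
```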
