
Commit 00e78b2

Add mapTaskId in FetchFailedException, address some comments
1 parent 8b51720 commit 00e78b2

17 files changed: +79 -64 lines changed


common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java

Lines changed: 1 addition & 0 deletions
@@ -337,6 +337,7 @@ public boolean hasNext() {
       // mapTaskIds.length must equal to reduceIds.length, and the passed in FetchShuffleBlocks
       // must have non-empty mapTaskIds and reduceIds, see the checking logic in
       // OneForOneBlockFetcher.
+      assert(mapTaskIds.length != 0 && mapTaskIds.length == reduceIds.length);
       return mapIdx < mapTaskIds.length && reduceIdx < reduceIds[mapIdx].length;
     }

core/src/main/scala/org/apache/spark/TaskEndReason.scala

Lines changed: 2 additions & 1 deletion
@@ -83,14 +83,15 @@ case object Resubmitted extends TaskFailedReason {
 case class FetchFailed(
     bmAddress: BlockManagerId, // Note that bmAddress can be null
     shuffleId: Int,
+    mapTaskId: Long,
     mapIndex: Int,
     reduceId: Int,
     message: String)
   extends TaskFailedReason {
   override def toErrorString: String = {
     val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString
     s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapIndex=$mapIndex, " +
-      s"reduceId=$reduceId, message=\n$message\n)"
+      s"mapTaskId=$mapTaskId, reduceId=$reduceId, message=\n$message\n)"
   }

 /**
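
For reference, a minimal sketch (all values hypothetical) of what the widened case class looks like to callers: constructors and positional extractors now carry an extra mapTaskId slot between shuffleId and mapIndex.

import org.apache.spark.FetchFailed

// Hypothetical values; bmAddress may be null, as the comment in the diff notes.
val reason = FetchFailed(
  bmAddress = null,
  shuffleId = 0,
  mapTaskId = 42L, // new field introduced by this commit
  mapIndex = 3,
  reduceId = 7,
  message = "simulated fetch failure")

// Positional extractors gain the extra mapTaskId slot as well.
val FetchFailed(_, shuffleId, mapTaskId, mapIndex, reduceId, _) = reason
println(s"shuffleId=$shuffleId, mapTaskId=$mapTaskId, mapIndex=$mapIndex, reduceId=$reduceId")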

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 1 addition & 1 deletion
@@ -1509,7 +1509,7 @@ private[spark] class DAGScheduler(
           }
         }

-      case FetchFailed(bmAddress, shuffleId, mapIndex, _, failureMessage) =>
+      case FetchFailed(bmAddress, shuffleId, _, mapIndex, _, failureMessage) =>
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleIdToMapStage(shuffleId)

core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala

Lines changed: 13 additions & 13 deletions
@@ -108,19 +108,19 @@ private[spark] object MapStatus {
  *
  * @param loc location where the task is being executed.
  * @param compressedSizes size of the blocks, indexed by reduce partition id.
- * @param mapTId unique task id for the task
+ * @param _mapTaskId unique task id for the task
  */
 private[spark] class CompressedMapStatus(
     private[this] var loc: BlockManagerId,
     private[this] var compressedSizes: Array[Byte],
-    private[this] var mapTId: Long)
+    private[this] var _mapTaskId: Long)
   extends MapStatus with Externalizable {

   // For deserialization only
   protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1)

-  def this(loc: BlockManagerId, uncompressedSizes: Array[Long], taskAttemptId: Long) {
-    this(loc, uncompressedSizes.map(MapStatus.compressSize), taskAttemptId)
+  def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long) {
+    this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId)
   }

   override def location: BlockManagerId = loc
@@ -129,21 +129,21 @@ private[spark] class CompressedMapStatus(
     MapStatus.decompressSize(compressedSizes(reduceId))
   }

-  override def mapTaskId: Long = mapTId
+  override def mapTaskId: Long = _mapTaskId

   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     loc.writeExternal(out)
     out.writeInt(compressedSizes.length)
     out.write(compressedSizes)
-    out.writeLong(mapTId)
+    out.writeLong(_mapTaskId)
   }

   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     loc = BlockManagerId(in)
     val len = in.readInt()
     compressedSizes = new Array[Byte](len)
     in.readFully(compressedSizes)
-    mapTId = in.readLong()
+    _mapTaskId = in.readLong()
   }
 }

@@ -157,20 +157,20 @@ private[spark] class CompressedMapStatus(
  * @param emptyBlocks a bitmap tracking which blocks are empty
  * @param avgSize average size of the non-empty and non-huge blocks
  * @param hugeBlockSizes sizes of huge blocks by their reduceId.
- * @param mapTId unique task id for the task
+ * @param _mapTaskId unique task id for the task
  */
 private[spark] class HighlyCompressedMapStatus private (
     private[this] var loc: BlockManagerId,
     private[this] var numNonEmptyBlocks: Int,
     private[this] var emptyBlocks: RoaringBitmap,
     private[this] var avgSize: Long,
     private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte],
-    private[this] var mapTId: Long)
+    private[this] var _mapTaskId: Long)
   extends MapStatus with Externalizable {

   // loc could be null when the default constructor is called during deserialization
   require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0
-    || numNonEmptyBlocks == 0 || mapTId > 0,
+    || numNonEmptyBlocks == 0 || _mapTaskId > 0,
     "Average size can only be zero for map stages that produced no output")

   protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only
@@ -189,7 +189,7 @@ private[spark] class HighlyCompressedMapStatus private (
     }
   }

-  override def mapTaskId: Long = mapTId
+  override def mapTaskId: Long = _mapTaskId

   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     loc.writeExternal(out)
@@ -200,7 +200,7 @@ private[spark] class HighlyCompressedMapStatus private (
       out.writeInt(kv._1)
       out.writeByte(kv._2)
     }
-    out.writeLong(mapTId)
+    out.writeLong(_mapTaskId)
   }

   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -216,7 +216,7 @@ private[spark] class HighlyCompressedMapStatus private (
       hugeBlockSizesImpl(block) = size
     }
     hugeBlockSizes = hugeBlockSizesImpl
-    mapTId = in.readLong()
+    _mapTaskId = in.readLong()
   }
 }
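
The rename above does not change the wire format: the map task id is still written last by writeExternal and read back from the same position. Below is a self-contained sketch of that Externalizable contract, using an illustrative stand-in class rather than Spark's own (which needs BlockManagerId and friends); only the write/read pattern mirrors the diff.

import java.io._

// Illustrative stand-in, not Spark code: it appends a trailing Long the same way
// CompressedMapStatus/HighlyCompressedMapStatus write _mapTaskId.
class TinyStatus(private var sizes: Array[Byte], private var mapTaskId: Long)
    extends Externalizable {

  def this() = this(null, -1L) // for deserialization only, as in MapStatus

  override def writeExternal(out: ObjectOutput): Unit = {
    out.writeInt(sizes.length)
    out.write(sizes)
    out.writeLong(mapTaskId) // written last, mirroring out.writeLong(_mapTaskId)
  }

  override def readExternal(in: ObjectInput): Unit = {
    val len = in.readInt()
    sizes = new Array[Byte](len)
    in.readFully(sizes)
    mapTaskId = in.readLong() // read back from the same trailing position
  }

  override def toString: String = s"TinyStatus(sizes=${sizes.length}, mapTaskId=$mapTaskId)"
}

object TinyStatusRoundTrip {
  def main(args: Array[String]): Unit = {
    // Round trip through Java serialization to show the field survives.
    val buf = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buf)
    out.writeObject(new TinyStatus(Array[Byte](1, 2, 3), 42L))
    out.close()

    val in = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
    println(in.readObject()) // TinyStatus(sizes=3, mapTaskId=42)
  }
}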

core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala

Lines changed: 6 additions & 4 deletions
@@ -35,6 +35,7 @@ import org.apache.spark.util.Utils
 private[spark] class FetchFailedException(
     bmAddress: BlockManagerId,
     shuffleId: Int,
+    mapTaskId: Long,
     mapIndex: Int,
     reduceId: Int,
     message: String,
@@ -44,10 +45,11 @@ private[spark] class FetchFailedException(
   def this(
       bmAddress: BlockManagerId,
       shuffleId: Int,
+      mapTaskId: Long,
       mapIndex: Int,
       reduceId: Int,
       cause: Throwable) {
-    this(bmAddress, shuffleId, mapIndex, reduceId, cause.getMessage, cause)
+    this(bmAddress, shuffleId, mapTaskId, mapIndex, reduceId, cause.getMessage, cause)
   }

   // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
@@ -56,8 +58,8 @@ private[spark] class FetchFailedException(
   // because the TaskContext is not defined in some test cases.
   Option(TaskContext.get()).map(_.setFetchFailed(this))

-  def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapIndex, reduceId,
-    Utils.exceptionString(this))
+  def toTaskFailedReason: TaskFailedReason = FetchFailed(
+    bmAddress, shuffleId, mapTaskId, mapIndex, reduceId, Utils.exceptionString(this))
 }

 /**
@@ -67,4 +69,4 @@ private[spark] class MetadataFetchFailedException(
     shuffleId: Int,
     reduceId: Int,
     message: String)
-  extends FetchFailedException(null, shuffleId, -1, reduceId, message)
+  extends FetchFailedException(null, shuffleId, -1L, -1, reduceId, message)
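
A minimal sketch (hypothetical values; placed under org.apache.spark so it can see the private[spark] class, as Spark's own tests do) of how the new parameter flows from the exception into the FetchFailed reason:

package org.apache.spark.shuffle

import org.apache.spark.storage.BlockManagerId

object FetchFailedSketch {
  def main(args: Array[String]): Unit = {
    // Outside a running task TaskContext.get() is null, so the constructor's
    // setFetchFailed hook is a no-op here.
    val e = new FetchFailedException(
      BlockManagerId("exec-1", "hostA", 1234),
      0,   // shuffleId
      42L, // mapTaskId: new parameter added by this commit
      3,   // mapIndex
      7,   // reduceId
      new java.io.IOException("fake"))

    // The error string now includes mapTaskId alongside mapIndex and reduceId.
    println(e.toTaskFailedReason.toErrorString)
  }
}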

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 2 additions & 2 deletions
@@ -578,8 +578,8 @@ final class ShuffleBlockFetcherIterator(
       address: BlockManagerId,
       e: Throwable) = {
     blockId match {
-      case ShuffleBlockId(shufId, _, reduceId) =>
-        throw new FetchFailedException(address, shufId, mapIndex, reduceId, e)
+      case ShuffleBlockId(shufId, mapTaskId, reduceId) =>
+        throw new FetchFailedException(address, shufId, mapTaskId, mapIndex, reduceId, e)
       case _ =>
         throw new SparkException(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)

core/src/main/scala/org/apache/spark/util/JsonProtocol.scala

Lines changed: 6 additions & 4 deletions
@@ -419,7 +419,8 @@ private[spark] object JsonProtocol {
         map(blockManagerIdToJson).getOrElse(JNothing)
       ("Block Manager Address" -> blockManagerAddress) ~
       ("Shuffle ID" -> fetchFailed.shuffleId) ~
-      ("Map ID" -> fetchFailed.mapIndex) ~
+      ("Map Task ID" -> fetchFailed.mapTaskId) ~
+      ("Map Index" -> fetchFailed.mapIndex) ~
       ("Reduce ID" -> fetchFailed.reduceId) ~
       ("Message" -> fetchFailed.message)
     case exceptionFailure: ExceptionFailure =>
@@ -974,11 +975,12 @@ private[spark] object JsonProtocol {
     case `fetchFailed` =>
       val blockManagerAddress = blockManagerIdFromJson(json \ "Block Manager Address")
       val shuffleId = (json \ "Shuffle ID").extract[Int]
-      val mapId = (json \ "Map ID").extract[Int]
+      val mapTaskId = (json \ "Map Task ID").extract[Long]
+      val mapIndex = (json \ "Map Index").extract[Int]
       val reduceId = (json \ "Reduce ID").extract[Int]
       val message = jsonOption(json \ "Message").map(_.extract[String])
-      new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId,
-        message.getOrElse("Unknown reason"))
+      new FetchFailed(blockManagerAddress, shuffleId, mapTaskId, mapIndex,
+        reduceId, message.getOrElse("Unknown reason"))
     case `exceptionFailure` =>
       val className = (json \ "Class Name").extract[String]
       val description = (json \ "Description").extract[String]
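
For reference, a rough sketch (hypothetical values) of the event-log JSON shape these two changes write and read. The real protocol emits "Block Manager Address" as the nested object built by blockManagerIdToJson; it is simplified to a plain string here.

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

object FetchFailedJsonSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical values for a single fetch failure.
    val json =
      ("Block Manager Address" -> "BlockManagerId(exec-1, hostA, 1234)") ~
      ("Shuffle ID" -> 0) ~
      ("Map Task ID" -> 42L) ~ // new field written by this commit
      ("Map Index" -> 3) ~     // replaces the old "Map ID" key
      ("Reduce ID" -> 7) ~
      ("Message" -> "simulated fetch failure")

    println(compact(render(json)))
  }
}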

core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala

Lines changed: 1 addition & 0 deletions
@@ -142,6 +142,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
             sid,
             taskContext.partitionId(),
             taskContext.partitionId(),
+            taskContext.partitionId(),
             "simulated fetch failure")
         } else {
           iter

core/src/test/scala/org/apache/spark/SparkContextSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -705,7 +705,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
         if (context.stageAttemptNumber == 0) {
           if (context.partitionId == 0) {
             // Make the first task in the first stage attempt fail.
-            throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 0, 0, 0,
+            throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 0, 0L, 0, 0,
               new java.io.IOException("fake"))
           } else {
             // Make the second task in the first stage attempt sleep to generate a zombie task

core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala

Lines changed: 1 addition & 0 deletions
@@ -528,6 +528,7 @@ class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
         throw new FetchFailedException(
           bmAddress = BlockManagerId("1", "hostA", 1234),
           shuffleId = 0,
+          mapTaskId = 0L,
           mapIndex = 0,
           reduceId = 0,
           message = "fake fetch failure"