Skip to content

Commit 6642de3

Browse files
committed
Allow Spark to recompute all the shuffle blocks on a host if the external shuffle service is unavailable on that host
cr https://cr.amazon.com/r/6822886/
1 parent 3fada2f commit 6642de3

File tree

3 files changed

+161
-35
lines changed

3 files changed

+161
-35
lines changed

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 94 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@ class DAGScheduler(
172172

173173
// For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
174174
// every task. When we detect a node failing, we note the current epoch number and failed
175-
// executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results.
175+
// executor or host, increment it for new tasks, and use this to ignore stray
176+
// ShuffleMapTask results.
176177
//
177178
// TODO: Garbage collect information about failure epochs when we know there are no more
178179
// stray messages to detect.
@@ -1348,7 +1349,14 @@ class DAGScheduler(
13481349

13491350
// TODO: mark the executor as failed only if there were lots of fetch failures on it
13501351
if (bmAddress != null) {
1351-
handleExecutorLost(bmAddress.executorId, filesLost = true, Some(task.epoch))
1352+
if (env.blockManager.externalShuffleServiceEnabled) {
1353+
val currentEpoch = Some(task.epoch).getOrElse(mapOutputTracker.getEpoch)
1354+
removeExecutor(bmAddress.executorId, currentEpoch)
1355+
handleExternalShuffleFailure(bmAddress.host, currentEpoch)
1356+
}
1357+
else {
1358+
handleExecutorLost(bmAddress.executorId, filesLost = true, Some(task.epoch))
1359+
}
13521360
}
13531361
}
13541362

@@ -1368,6 +1376,30 @@ class DAGScheduler(
13681376
}
13691377
}
13701378

1379+
/**
 * Records an executor as failed for the given epoch and removes it from the block manager
 * master.
 *
 * A loss report is acted on only when no failure is recorded for the executor yet, or when the
 * recorded failure epoch is older than `currentEpoch`; any other report is treated as a
 * duplicate and only logged at debug level.
 *
 * @param execId id of the executor to be removed
 * @param currentEpoch epoch during which the executor failure was caught, used so that stray
 *                     failures cannot retrigger the detection of an executor as lost
 *
 * @return true if this call removed the executor, false if the report was a duplicate
 */
private[scheduler] def removeExecutor(execId: String, currentEpoch: Long): Boolean = {
  // A failure already recorded at this epoch or a later one makes this report a duplicate.
  val alreadyHandled = failedEpoch.get(execId).exists(_ >= currentEpoch)
  if (alreadyHandled) {
    logDebug("Additional executor lost message for " + execId +
      "(epoch " + currentEpoch + ")")
    false
  } else {
    failedEpoch(execId) = currentEpoch
    logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
    blockManagerMaster.removeExecutor(execId)
    true
  }
}
1402+
13711403
/**
13721404
* Responds to an executor being lost. This is called inside the event loop, so it assumes it can
13731405
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
@@ -1385,38 +1417,76 @@ class DAGScheduler(
13851417
filesLost: Boolean,
13861418
maybeEpoch: Option[Long] = None) {
13871419
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
1388-
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
1389-
failedEpoch(execId) = currentEpoch
1390-
logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
1391-
blockManagerMaster.removeExecutor(execId)
1420+
val executorRemoved = removeExecutor(execId, currentEpoch)
1421+
if (executorRemoved && (filesLost || !env.blockManager.externalShuffleServiceEnabled)) {
1422+
handleInternalShuffleFailure(execId, currentEpoch)
1423+
}
1424+
}
13921425

1393-
if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
1394-
logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
1395-
// TODO: This will be really slow if we keep accumulating shuffle map stages
1396-
for ((shuffleId, stage) <- shuffleIdToMapStage) {
1397-
stage.removeOutputsOnExecutor(execId)
1398-
mapOutputTracker.registerMapOutputs(
1399-
shuffleId,
1400-
stage.outputLocInMapOutputTrackerFormat(),
1401-
changeEpoch = true)
1402-
}
1403-
if (shuffleIdToMapStage.isEmpty) {
1404-
mapOutputTracker.incrementEpoch()
1405-
}
1406-
clearCacheLocs()
1407-
}
1426+
/**
 * Handles the loss of the shuffle files that belong to a single executor.
 *
 * Every shuffle output registered for the executor is treated as lost: the loss is logged and
 * the executor's outputs are removed from each shuffle map stage via `cleanShuffleOutputs`.
 *
 * @param execId id of the executor for which internal shuffle is unavailable
 * @param currentEpoch epoch during which the failure was caught (used for logging only here)
 */
private[scheduler] def handleInternalShuffleFailure(execId: String, currentEpoch: Long): Unit = {
  logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
  val dropExecutorOutputs: ShuffleMapStage => Unit = _.removeOutputsOnExecutor(execId)
  cleanShuffleOutputs(dropExecutorOutputs)
}
1440+
1441+
/**
1442+
* Responds to an external shuffle service becoming unavailable on a host.
1443+
*
1444+
* We will assume that we've lost all the shuffle blocks on that host if FetchFailed occurred
1445+
* while external shuffle is being used.
1446+
*
1447+
* @param host address of the host on which external shuffle is unavailable
1448+
* @param currentEpoch epoch during which the failure was caught. This is passed to avoid
1449+
* allowing stray fetch failures from possibly retriggering the detection
1450+
* of external shuffle service becoming unavailable.
1451+
*/
1452+
private[scheduler] def handleExternalShuffleFailure(host: String, currentEpoch: Long): Unit = {
1453+
if (!failedEpoch.contains(host) || failedEpoch(host) < currentEpoch) {
1454+
failedEpoch(host) = currentEpoch
1455+
logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch))
1456+
cleanShuffleOutputs((stage: ShuffleMapStage) => {
1457+
stage.removeOutputsOnHost(host)
1458+
})
14081459
} else {
1409-
logDebug("Additional executor lost message for " + execId +
1410-
"(epoch " + currentEpoch + ")")
1460+
logDebug(("Additional Shuffle files " +
1461+
"lost message for host: %s (epoch %d)").format(host, currentEpoch))
1462+
}
1463+
}
1464+
1465+
/**
 * Applies `outputsCleaner` to every registered shuffle map stage, then re-registers each
 * stage's remaining output locations with the map output tracker and clears the cached
 * locations.
 *
 * The callback is typed `ShuffleMapStage => Unit` rather than with a wildcard result type
 * because its result is never used; it is invoked purely for its side effect on the stage.
 *
 * @param outputsCleaner side-effecting callback that removes the lost outputs from a stage
 */
private[scheduler] def cleanShuffleOutputs(outputsCleaner: ShuffleMapStage => Unit): Unit = {
  // TODO: This will be really slow if we keep accumulating shuffle map stages
  for ((shuffleId, stage) <- shuffleIdToMapStage) {
    outputsCleaner(stage)
    mapOutputTracker.registerMapOutputs(
      shuffleId,
      stage.outputLocInMapOutputTrackerFormat(),
      changeEpoch = true)
  }
  // With no stages the loop above registers nothing (its calls pass changeEpoch = true,
  // presumably advancing the epoch), so the epoch is incremented explicitly instead.
  if (shuffleIdToMapStage.isEmpty) {
    mapOutputTracker.incrementEpoch()
  }
  clearCacheLocs()
}
14131479

14141480
private[scheduler] def handleExecutorAdded(execId: String, host: String) {
14151481
// remove from failedEpoch(execId) ?
14161482
if (failedEpoch.contains(execId)) {
1417-
logInfo("Host added was in lost list earlier: " + host)
1483+
logInfo("Executor %s added was in lost list earlier.".format(execId))
14181484
failedEpoch -= execId
14191485
}
1486+
1487+
if (failedEpoch.contains(host)) {
1488+
failedEpoch -= host
1489+
}
14201490
}
14211491

14221492
private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]) {

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -132,25 +132,45 @@ private[spark] class ShuffleMapStage(
132132
outputLocs.map(_.headOption.orNull)
133133
}
134134

135-
/**
136-
* Removes all shuffle outputs associated with this executor. Note that this will also remove
137-
* outputs which are served by an external shuffle server (if one exists), as they are still
138-
* registered with this execId.
139-
*/
140-
def removeOutputsOnExecutor(execId: String): Unit = {
135+
private def removeOutputsHelper(locationChecker: BlockManagerId => Boolean): Boolean = {
141136
var becameUnavailable = false
142137
for (partition <- 0 until numPartitions) {
143138
val prevList = outputLocs(partition)
144-
val newList = prevList.filterNot(_.location.executorId == execId)
139+
val newList = prevList.filterNot(status => locationChecker(status.location))
145140
outputLocs(partition) = newList
146141
if (prevList != Nil && newList == Nil) {
147142
becameUnavailable = true
148143
_numAvailableOutputs -= 1
149144
}
150145
}
146+
becameUnavailable
147+
}
148+
149+
/**
 * Drops every shuffle output whose registered location belongs to the given executor.
 *
 * Outputs served by an external shuffle server (if one exists) are dropped as well, since they
 * are still registered under this executor id. A message is logged when at least one partition
 * is left without any output location.
 *
 * @param execId id of the executor whose outputs are removed
 */
def removeOutputsOnExecutor(execId: String): Unit = {
  val lostSomePartition = removeOutputsHelper(_.executorId == execId)
  if (lostSomePartition) {
    logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
      this, execId, _numAvailableOutputs, numPartitions, isAvailable))
  }
}
163+
164+
/**
 * Drops every shuffle output whose registered location is on the given host, i.e. all outputs
 * associated with the external shuffle service on that host.
 *
 * A message is logged when at least one partition is left without any output location.
 *
 * @param host address of the host whose outputs are removed
 */
def removeOutputsOnHost(host: String): Unit = {
  val lostSomePartition = removeOutputsHelper(_.host == host)
  if (lostSomePartition) {
    logInfo("%s is now unavailable on host %s (%d/%d, %s)".format(
      this, host, _numAvailableOutputs, numPartitions, isAvailable))
  }
}
156176
}

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,41 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
674674
}
675675
}
676676

677+
// Each case is (test name, whether the external shuffle service is enabled, map partitions
// expected to be resubmitted after the fetch failure).
private val shuffleFetchFailureTests = Seq(
  ("fetch failure with external shuffle service enabled", true, Set(0, 1, 2, 4)),
  ("fetch failure with internal shuffle service enabled", false, Set(0, 1)))

shuffleFetchFailureTests.foreach {
  case (eventDescription, shuffleServiceOn, expectedPartitionsLost) =>
    test(eventDescription) {
      // Re-initialise the suite with the shuffle-service setting under test.
      afterEach()
      val conf = new SparkConf()
      conf.set("spark.shuffle.service.enabled", shuffleServiceOn.toString)
      init(conf)
      assert(sc.env.blockManager.externalShuffleServiceEnabled == shuffleServiceOn)

      val shuffleMapRdd = new MyRDD(sc, 5, Nil)
      val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
      val shuffleId = shuffleDep.shuffleId
      val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
      submit(reduceRdd, Array(0, 1))

      // Map partitions 0, 1, 2 and 4 land on hostA (split across two executors); partition 3
      // lands on hostB.
      val numReducers = reduceRdd.partitions.length
      val mapOutputLocations = Seq(
        "hostA" -> "exec-hostA-1",
        "hostA" -> "exec-hostA-1",
        "hostA" -> "exec-hostA-2",
        "hostB" -> "exec-hostB-1",
        "hostA" -> "exec-hostA-2")
      complete(taskSets(0), mapOutputLocations.map { case (host, execId) =>
        (Success, makeMapStatus(host, numReducers, 5, Some(execId)))
      })

      // One reduce task succeeds; the other fails to fetch from exec-hostA-1 on hostA.
      val failedFetch =
        FetchFailed(makeBlockManagerId("hostA", Some("exec-hostA-1")), shuffleId, 0, 0, "ignored")
      complete(taskSets(1), Seq(
        (Success, 42),
        (failedFetch, null)))
      scheduler.resubmitFailedStages()
      assert(taskSets(2).tasks.map(_.partitionId).toSet === expectedPartitionsLost)
    }
}
711+
677712
// Helper function to validate state when creating tests for task failures
678713
private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet) {
679714
assert(stageAttempt.stageId === stageId)
@@ -2330,9 +2365,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
23302365
}
23312366

23322367
object DAGSchedulerSuite {
2333-
def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus =
2334-
MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes))
2368+
def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2,
2369+
execId: Option[String] = None): MapStatus =
2370+
MapStatus(makeBlockManagerId(host, execId), Array.fill[Long](reduces)(sizes))
23352371

2336-
def makeBlockManagerId(host: String): BlockManagerId =
2337-
BlockManagerId("exec-" + host, host, 12345)
2372+
def makeBlockManagerId(host: String, execId: Option[String] = None): BlockManagerId =
2373+
BlockManagerId(execId.getOrElse("exec-" + host), host, 12345)
23382374
}

0 commit comments

Comments
 (0)