Commit 3be31e8

[CORE] Fix regressions in decommissioning
The DecommissionWorkerSuite started becoming flaky, and it revealed a real regression. Recent PRs (apache#28085 and apache#29211) necessitate a small reworking of the decommissioning logic.

Before getting into that, let me describe the intended behavior of decommissioning: if a fetch failure happens and the source executor was decommissioned, we want to treat that as an eager signal to clear all shuffle state associated with that executor. In addition, if we know that the host was decommissioned, we want to forget the map statuses of all other executors on that decommissioned host. This is what the test "decommission workers ensure that fetch failures lead to rerun" is trying to verify. The invariant is important to ensure that decommissioning a host does not lead to multiple fetch failures that might fail the job.

- Per apache#29211, executors now exit eagerly on decommissioning, so the executor is lost before the fetch failure even happens. (I tested this by waiting some seconds before triggering the fetch failure.) When an executor is lost, we forget its decommissioning information. The fix is to keep the decommissioning information around for some time after removal, with some extra logic to finally purge it after a timeout.

- Per apache#28085, when the executor is lost, it forgets the shuffle state about just that executor and increments the shuffleFileLostEpoch. This increment precludes clearing the state of the entire host when the fetch failure happens. This PR changes that codepath only for the special case of decommissioning, without other side effects. The epoch bookkeeping is complex and has effectively not changed semantically since 2013, so the fix here is kept simple: ignore the shuffleFileLostEpoch when the shuffle status is being cleared due to a fetch failure resulting from host decommission.

These two fixes are local to decommissioning and don't change other behavior. Also added more tests to TaskSchedulerImplSuite to ensure that the decommissioning information is indeed purged after a timeout, and hardened DecommissionWorkerSuite to wait for successful job completion.
1 parent 1a4c8f7 commit 3be31e8
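For readers skimming the diffs below, the first fix boils down to a small TTL cache inside the task scheduler: when a decommissioned executor is removed, its decommission info is parked under an expiry timestamp and lazily purged on lookup rather than dropped immediately. Here is a minimal standalone sketch of that pattern; the class and member names are illustrative, not the actual TaskSchedulerImpl fields.

```scala
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.collection.mutable

// Illustrative sketch only: a "remember for a while after removal" cache.
// `rememberSeconds` and `nowMillis` stand in for the real conf value and Clock.
class DecomInfoCache(rememberSeconds: Long, nowMillis: () => Long) {
  private val pending = mutable.HashMap.empty[String, String]   // executorId -> decom info
  private val expiryToExecutors =                                // expiry second -> executor ids
    new java.util.TreeMap[Long, mutable.ArrayBuffer[String]]()

  def decommission(execId: String, info: String): Unit = pending(execId) = info

  // Called when the executor is removed: keep the info, but schedule it for purging.
  def executorRemoved(execId: String): Unit = {
    if (pending.contains(execId)) {
      val expiry = TimeUnit.MILLISECONDS.toSeconds(nowMillis()) + rememberSeconds
      expiryToExecutors.computeIfAbsent(expiry, _ => mutable.ArrayBuffer.empty[String]) += execId
    }
  }

  // Lookups first drop every entry whose expiry second has already passed.
  def get(execId: String): Option[String] = {
    val cutoff = TimeUnit.MILLISECONDS.toSeconds(nowMillis())
    val expired = expiryToExecutors.headMap(cutoff)              // view of keys < cutoff
    expired.values().asScala.flatten.foreach(pending -= _)
    expired.clear()                                              // also removes them from the TreeMap
    pending.get(execId)
  }
}
```

Driven by a manual clock, the assertions the new TaskSchedulerImplSuite test makes (info still present before the timeout, gone after it) fall out of this structure.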

File tree: 5 files changed (+131, -37 lines)

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

Lines changed: 6 additions & 2 deletions

@@ -294,10 +294,13 @@ private[spark] class CoarseGrainedExecutorBackend(
       override def run(): Unit = {
         var lastTaskRunningTime = System.nanoTime()
         val sleep_time = 1000 // 1s
-
+        val initialSleepMillis = env.conf.getInt(
+          "spark.executor.decommission.initial.sleep.millis", sleep_time)
+        if (initialSleepMillis > 0) {
+          Thread.sleep(initialSleepMillis)
+        }
         while (true) {
           logInfo("Checking to see if we can shutdown.")
-          Thread.sleep(sleep_time)
           if (executor == null || executor.numRunningTasks == 0) {
             if (env.conf.get(STORAGE_DECOMMISSION_ENABLED)) {
               logInfo("No running tasks, checking migrations")
@@ -323,6 +326,7 @@ private[spark] class CoarseGrainedExecutorBackend(
             // move forward.
             lastTaskRunningTime = System.nanoTime()
           }
+          Thread.sleep(sleep_time)
         }
       }
     }
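The sleep that used to sit at the top of the shutdown loop is now taken once, before the loop, and its length is configurable. The two new DecommissionWorkerSuite tests further down drive it to both extremes; as a hedged illustration, this is roughly what those settings look like on a SparkConf (only the values are test-specific choices):

```scala
import org.apache.spark.SparkConf

// Values mirror the two DecommissionWorkerSuite variants added in this commit.
val stalledExecutorConf = new SparkConf()
  .set("spark.executor.decommission.initial.sleep.millis", (3600 * 1000).toString) // stall: don't exit during the test
val eagerExecutorConf = new SparkConf()
  .set("spark.executor.decommission.initial.sleep.millis", "0") // eager: exit as soon as no tasks are running
```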

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 23 additions & 12 deletions

@@ -1846,7 +1846,8 @@ private[spark] class DAGScheduler(
             execId = bmAddress.executorId,
             fileLost = true,
             hostToUnregisterOutputs = hostToUnregisterOutputs,
-            maybeEpoch = Some(task.epoch))
+            maybeEpoch = Some(task.epoch),
+            ignoreShuffleVersion = isHostDecommissioned)
         }
       }

@@ -2012,7 +2013,8 @@ private[spark] class DAGScheduler(
       execId: String,
       fileLost: Boolean,
       hostToUnregisterOutputs: Option[String],
-      maybeEpoch: Option[Long] = None): Unit = {
+      maybeEpoch: Option[Long] = None,
+      ignoreShuffleVersion: Boolean = false): Unit = {
     val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
     logDebug(s"Considering removal of executor $execId; " +
       s"fileLost: $fileLost, currentEpoch: $currentEpoch")
@@ -2022,16 +2024,25 @@
       blockManagerMaster.removeExecutor(execId)
       clearCacheLocs()
     }
-    if (fileLost &&
-        (!shuffleFileLostEpoch.contains(execId) || shuffleFileLostEpoch(execId) < currentEpoch)) {
-      shuffleFileLostEpoch(execId) = currentEpoch
-      hostToUnregisterOutputs match {
-        case Some(host) =>
-          logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
-          mapOutputTracker.removeOutputsOnHost(host)
-        case None =>
-          logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
-          mapOutputTracker.removeOutputsOnExecutor(execId)
+    if (fileLost) {
+      val remove = if (ignoreShuffleVersion) {
+        true
+      } else if (!shuffleFileLostEpoch.contains(execId) ||
+          shuffleFileLostEpoch(execId) < currentEpoch) {
+        shuffleFileLostEpoch(execId) = currentEpoch
+        true
+      } else {
+        false
+      }
+      if (remove) {
+        hostToUnregisterOutputs match {
+          case Some(host) =>
+            logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
+            mapOutputTracker.removeOutputsOnHost(host)
+          case None =>
+            logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
+            mapOutputTracker.removeOutputsOnExecutor(execId)
+        }
       }
     }
   }
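The heart of this change is the `remove` decision above: the epoch check still guards the ordinary executor-lost path, but a fetch failure attributed to a decommissioned host bypasses it. Condensed into a hypothetical standalone helper (not the actual DAGScheduler method, which also unregisters the outputs):

```scala
import scala.collection.mutable

object ShuffleCleanupDecision {
  // Hypothetical condensation of the `remove` decision introduced in the hunk above.
  def shouldRemoveShuffleFiles(
      execId: String,
      currentEpoch: Long,
      ignoreShuffleVersion: Boolean,
      shuffleFileLostEpoch: mutable.HashMap[String, Long]): Boolean = {
    if (ignoreShuffleVersion) {
      // Fetch failure caused by host decommission: always clear the shuffle state.
      true
    } else if (!shuffleFileLostEpoch.contains(execId) ||
        shuffleFileLostEpoch(execId) < currentEpoch) {
      // First file loss seen at this epoch: record it and clear.
      shuffleFileLostEpoch(execId) = currentEpoch
      true
    } else {
      // Already handled at this (or a later) epoch: nothing more to clear.
      false
    }
  }
}
```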

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 21 additions & 4 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler

 import java.nio.ByteBuffer
+import java.util
 import java.util.{Timer, TimerTask}
 import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.concurrent.atomic.AtomicLong
@@ -136,7 +137,9 @@ private[spark] class TaskSchedulerImpl(
   // IDs of the tasks running on each executor
   private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

-  private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+  val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+  // map of second to list of executors to clear form the above map
+  val decommissioningExecutorsToGc = new util.TreeMap[Long, mutable.ArrayBuffer[String]]()

   def runningTasksByExecutors: Map[String, Int] = synchronized {
     executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
@@ -910,7 +913,7 @@ private[spark] class TaskSchedulerImpl(
         // if we heard isHostDecommissioned ever true, then we keep that one since it is
         // most likely coming from the cluster manager and thus authoritative
         val oldDecomInfo = executorsPendingDecommission.get(executorId)
-        if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) {
+        if (!oldDecomInfo.exists(_.isHostDecommissioned)) {
           executorsPendingDecommission(executorId) = decommissionInfo
         }
       }
@@ -921,7 +924,13 @@ private[spark] class TaskSchedulerImpl(

   override def getExecutorDecommissionInfo(executorId: String)
     : Option[ExecutorDecommissionInfo] = synchronized {
-      executorsPendingDecommission.get(executorId)
+      import scala.collection.JavaConverters._
+      // Garbage collect old decommissioning entries
+      val secondsToGcUptil = TimeUnit.MILLISECONDS.toSeconds(clock.getTimeMillis())
+      val headMap = decommissioningExecutorsToGc.headMap(secondsToGcUptil)
+      headMap.values().asScala.flatten.foreach(executorsPendingDecommission -= _)
+      headMap.clear()
+      executorsPendingDecommission.get(executorId)
   }

   override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
@@ -1027,7 +1036,15 @@ private[spark] class TaskSchedulerImpl(
       }
     }

-    executorsPendingDecommission -= executorId
+
+    val decomInfo = executorsPendingDecommission.get(executorId)
+    if (decomInfo.isDefined) {
+      val rememberSeconds =
+        conf.getInt("spark.decommissioningRememberAfterRemoval.seconds", 60)
+      val gcSecond = TimeUnit.MILLISECONDS.toSeconds(clock.getTimeMillis()) + rememberSeconds
+      decommissioningExecutorsToGc.computeIfAbsent(gcSecond, _ => mutable.ArrayBuffer.empty) +=
+        executorId
+    }

     if (reason != LossReasonPending) {
       executorIdToHost -= executorId
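The lazy purge in getExecutorDecommissionInfo relies on a property of java.util.TreeMap: headMap(toKey) is a live view of the entries strictly below toKey, so clearing the view also deletes those entries from the backing map. A small illustrative check of that behavior (not code from this commit):

```scala
import java.util.TreeMap
import scala.collection.JavaConverters._

val deadlines = new TreeMap[Long, String]()
deadlines.put(10L, "executor1")
deadlines.put(20L, "executor2")
deadlines.put(30L, "executor3")

val expired = deadlines.headMap(25L)        // live view: keys strictly less than 25
println(expired.keySet().asScala.toList)    // keys 10 and 20
expired.clear()                             // write-through: removes 10 and 20 from `deadlines` too
println(deadlines.keySet().asScala.toList)  // only key 30 remains
```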

core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala

Lines changed: 51 additions & 11 deletions

@@ -84,6 +84,19 @@ class DecommissionWorkerSuite
     }
   }

+  // Unlike TestUtils.withListener, it also waits for the job to be done
+  def withListener(sc: SparkContext, listener: RootStageAwareListener)
+      (body: SparkListener => Unit): Unit = {
+    sc.addSparkListener(listener)
+    try {
+      body(listener)
+      sc.listenerBus.waitUntilEmpty()
+      listener.waitForJobDone()
+    } finally {
+      sc.listenerBus.removeListener(listener)
+    }
+  }
+
   test("decommission workers should not result in job failure") {
     val maxTaskFailures = 2
     val numTimesToKillWorkers = maxTaskFailures + 1
@@ -109,7 +122,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
         Thread.sleep(5 * 1000L); 1
       }.count()
@@ -164,7 +177,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
         val sleepTimeSeconds = if (pid == 0) 1 else 10
         Thread.sleep(sleepTimeSeconds * 1000L)
@@ -190,10 +203,11 @@ class DecommissionWorkerSuite
     }
   }

-  test("decommission workers ensure that fetch failures lead to rerun") {
+  def testFetchFailures(initialSleepMillis: Int): Unit = {
     createWorkers(2)
     sc = createSparkContext(
       config.Tests.TEST_NO_STAGE_RETRY.key -> "false",
+      "spark.executor.decommission.initial.sleep.millis" -> initialSleepMillis.toString,
       config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE.key -> "true")
     val executorIdToWorkerInfo = getExecutorToWorkerAssignments
     val executorToDecom = executorIdToWorkerInfo.keysIterator.next
@@ -212,22 +226,27 @@ class DecommissionWorkerSuite
       override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
         val taskInfo = taskEnd.taskInfo
         if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
-          taskEnd.stageAttemptId == 0) {
+          taskEnd.stageAttemptId == 0 && taskEnd.stageId == 0) {
           decommissionWorkerOnMaster(workerToDecom,
             "decommission worker after task on it is done")
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
         val executorId = SparkEnv.get.executorId
-        val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
-        Thread.sleep(sleepTimeSeconds * 1000L)
+        val context = TaskContext.get()
+        if (context.attemptNumber() == 0 && context.stageAttemptNumber() == 0) {
+          val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
+          Thread.sleep(sleepTimeSeconds * 1000L)
+        }
         List(1).iterator
       }, preservesPartitioning = true)
         .repartition(1).mapPartitions(iter => {
         val context = TaskContext.get()
         if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
+          // Wait a bit for the decommissioning to be triggered in the listener
+          Thread.sleep(5000)
           // MapIndex is explicitly -1 to force the entire host to be decommissioned
           // However, this will cause both the tasks in the preceding stage since the host here is
           // "localhost" (shortcoming of this single-machine unit test in that all the workers
@@ -246,6 +265,14 @@ class DecommissionWorkerSuite
     assert(tasksSeen.size === 6, s"Expected 6 tasks but got $tasksSeen")
   }

+  test("decommission stalled workers ensure that fetch failures lead to rerun") {
+    testFetchFailures(3600 * 1000)
+  }
+
+  test("decommission eager workers ensure that fetch failures lead to rerun") {
+    testFetchFailures(0)
+  }
+
   private abstract class RootStageAwareListener extends SparkListener {
     private var rootStageId: Option[Int] = None
     private val tasksFinished = new ConcurrentLinkedQueue[String]()
@@ -265,23 +292,31 @@ class DecommissionWorkerSuite
     override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
       jobEnd.jobResult match {
         case JobSucceeded => jobDone.set(true)
+        case JobFailed(exception) => logError(s"Job failed", exception)
       }
     }

     protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}

     protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}

+    private def getSignature(taskInfo: TaskInfo, stageId: Int, stageAttemptId: Int):
+      String = {
+      s"${stageId}:${stageAttemptId}:" +
+        s"${taskInfo.index}:${taskInfo.attemptNumber}-${taskInfo.status}"
+    }
+
     override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+      val signature = getSignature(taskStart.taskInfo, taskStart.stageId, taskStart.stageAttemptId)
+      logInfo(s"Task started: $signature")
       if (isRootStageId(taskStart.stageId)) {
         rootTasksStarted.add(taskStart.taskInfo)
         handleRootTaskStart(taskStart)
       }
     }

     override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-      val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
-        s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
+      val taskSignature = getSignature(taskEnd.taskInfo, taskEnd.stageId, taskEnd.stageAttemptId)
       logInfo(s"Task End $taskSignature")
       tasksFinished.add(taskSignature)
       if (isRootStageId(taskEnd.stageId)) {
@@ -291,8 +326,13 @@ class DecommissionWorkerSuite
     }

     def getTasksFinished(): Seq[String] = {
-      assert(jobDone.get(), "Job isn't successfully done yet")
-      tasksFinished.asScala.toSeq
+      tasksFinished.asScala.toList
+    }
+
+    def waitForJobDone(): Unit = {
+      eventually(timeout(10.seconds), interval(100.milliseconds)) {
+        assert(jobDone.get(), "Job isn't successfully done yet")
+      }
     }
   }

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 30 additions & 8 deletions

@@ -34,7 +34,7 @@ import org.apache.spark.internal.config
 import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, TaskResourceRequests}
 import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.resource.TestResourceIDs._
-import org.apache.spark.util.ManualClock
+import org.apache.spark.util.{Clock, ManualClock, SystemClock}

 class FakeSchedulerBackend extends SchedulerBackend {
   def start(): Unit = {}
@@ -88,10 +88,15 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
   }

   def setupSchedulerWithMaster(master: String, confs: (String, String)*): TaskSchedulerImpl = {
+    setupSchedulerWithMasterAndClock(master, new SystemClock, confs: _*)
+  }
+
+  def setupSchedulerWithMasterAndClock(master: String, clock: Clock, confs: (String, String)*):
+      TaskSchedulerImpl = {
     val conf = new SparkConf().setMaster(master).setAppName("TaskSchedulerImplSuite")
     confs.foreach { case (k, v) => conf.set(k, v) }
     sc = new SparkContext(conf)
-    taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler = new TaskSchedulerImpl(sc, sc.conf.get(config.TASK_MAX_FAILURES), clock = clock)
     setupHelper()
   }

@@ -1802,9 +1807,10 @@
     assert(2 == taskDescriptions.head.resources(GPU).addresses.size)
   }

-  private def setupSchedulerForDecommissionTests(): TaskSchedulerImpl = {
-    val taskScheduler = setupSchedulerWithMaster(
+  private def setupSchedulerForDecommissionTests(clock: Clock): TaskSchedulerImpl = {
+    val taskScheduler = setupSchedulerWithMasterAndClock(
       s"local[2]",
+      clock,
       config.CPUS_PER_TASK.key -> 1.toString)
     taskScheduler.submitTasks(FakeTask.createTaskSet(2))
     val multiCoreWorkerOffers = IndexedSeq(WorkerOffer("executor0", "host0", 1),
@@ -1815,7 +1821,7 @@
   }

   test("scheduler should keep the decommission info where host was decommissioned") {
-    val scheduler = setupSchedulerForDecommissionTests()
+    val scheduler = setupSchedulerForDecommissionTests(new SystemClock)

     scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0", false))
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1", true))
@@ -1829,8 +1835,9 @@
     assert(scheduler.getExecutorDecommissionInfo("executor2").isEmpty)
   }

-  test("scheduler should ignore decommissioning of removed executors") {
-    val scheduler = setupSchedulerForDecommissionTests()
+  test("scheduler should eventually purge removed and decommissioned executors") {
+    val clock = new ManualClock(10000L)
+    val scheduler = setupSchedulerForDecommissionTests(clock)

     // executor 0 is decommissioned after loosing
     assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
@@ -1839,14 +1846,29 @@
     scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("", false))
     assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)

+    assert(scheduler.executorsPendingDecommission.isEmpty)
+    clock.advance(5000)
+
     // executor 1 is decommissioned before loosing
     assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
     assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
+    clock.advance(2000)
     scheduler.executorLost("executor1", ExecutorExited(0, false, "normal"))
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    assert(scheduler.decommissioningExecutorsToGc.size === 1)
+    assert(scheduler.executorsPendingDecommission.size === 1)
+    clock.advance(2000)
+    // It hasn't been 60 seconds yet before removal
+    assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
+    clock.advance(2000)
+    assert(scheduler.decommissioningExecutorsToGc.size === 1)
+    assert(scheduler.executorsPendingDecommission.size === 1)
+    assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
+    clock.advance(61000)
     assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    assert(scheduler.decommissioningExecutorsToGc.isEmpty)
+    assert(scheduler.executorsPendingDecommission.isEmpty)
   }

   /**
