diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2503ae0856dc7..6b376cdadc66b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1821,10 +1821,19 @@ private[spark] class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled && - unRegisterOutputOnHostOnFetchFailure) { - // We had a fetch failure with the external shuffle service, so we - // assume all shuffle data on the node is bad. + val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled + val isHostDecommissioned = taskScheduler + .getExecutorDecommissionInfo(bmAddress.executorId) + .exists(_.isHostDecommissioned) + + // Shuffle output of all executors on host `bmAddress.host` may be lost if: + // - External shuffle service is enabled, so we assume that all shuffle data on node is + // bad. + // - Host is decommissioned, thus all executors on that host will die. + val shuffleOutputOfEntireHostLost = externalShuffleServiceEnabled || + isHostDecommissioned + val hostToUnregisterOutputs = if (shuffleOutputOfEntireHostLost + && unRegisterOutputOnHostOnFetchFailure) { Some(bmAddress.host) } else { // Unregister shuffle data just for one executor (we don't have any @@ -2339,7 +2348,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case ExecutorLost(execId, reason) => val workerLost = reason match { - case ExecutorProcessLost(_, true) => true + case ExecutorProcessLost(_, true, _) => true case _ => false } dagScheduler.handleExecutorLost(execId, workerLost) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 4141ed799a4e0..671dedaa5a6e8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -54,9 +54,14 @@ private [spark] object LossReasonPending extends ExecutorLossReason("Pending los /** * @param _message human readable loss reason * @param workerLost whether the worker is confirmed lost too (i.e. including shuffle service) + * @param causedByApp whether the loss of the executor is the fault of the running app. 
+ * (assumed true by default unless known explicitly otherwise) */ private[spark] -case class ExecutorProcessLost(_message: String = "Worker lost", workerLost: Boolean = false) +case class ExecutorProcessLost( + _message: String = "Executor Process Lost", + workerLost: Boolean = false, + causedByApp: Boolean = true) extends ExecutorLossReason(_message) /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index b29458c481413..1101d0616d2bf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -103,6 +103,11 @@ private[spark] trait TaskScheduler { */ def executorDecommission(executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit + /** + * If an executor is decommissioned, return its corresponding decommission info + */ + def getExecutorDecommissionInfo(executorId: String): Option[ExecutorDecommissionInfo] + /** * Process a lost executor */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 510318afcb8df..b734d9f72944a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -136,6 +136,8 @@ private[spark] class TaskSchedulerImpl( // IDs of the tasks running on each executor private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]] + private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo] + def runningTasksByExecutors: Map[String, Int] = synchronized { executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap } @@ -939,12 +941,43 @@ private[spark] class TaskSchedulerImpl( override def executorDecommission( executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = { + synchronized { + // Don't bother noting decommissioning for executors that we don't know about + if (executorIdToHost.contains(executorId)) { + // The scheduler can get multiple decommission updates from multiple sources, + // and some of those can have isHostDecommissioned false. 
We merge them such that + // if we heard isHostDecommissioned ever true, then we keep that one since it is + // most likely coming from the cluster manager and thus authoritative + val oldDecomInfo = executorsPendingDecommission.get(executorId) + if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) { + executorsPendingDecommission(executorId) = decommissionInfo + } + } + } rootPool.executorDecommission(executorId) backend.reviveOffers() } - override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = { + override def getExecutorDecommissionInfo(executorId: String) + : Option[ExecutorDecommissionInfo] = synchronized { + executorsPendingDecommission.get(executorId) + } + + override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = { var failedExecutor: Option[String] = None + val reason = givenReason match { + // Handle executor process loss due to decommissioning + case ExecutorProcessLost(message, origWorkerLost, origCausedByApp) => + val executorDecommissionInfo = getExecutorDecommissionInfo(executorId) + ExecutorProcessLost( + message, + // Also mark the worker lost if we know that the host was decommissioned + origWorkerLost || executorDecommissionInfo.exists(_.isHostDecommissioned), + // Executor loss is certainly not caused by app if we knew that this executor is being + // decommissioned + causedByApp = executorDecommissionInfo.isEmpty && origCausedByApp) + case e => e + } synchronized { if (executorIdToRunningTaskIds.contains(executorId)) { @@ -1033,6 +1066,8 @@ private[spark] class TaskSchedulerImpl( } } + executorsPendingDecommission -= executorId + if (reason != LossReasonPending) { executorIdToHost -= executorId rootPool.executorLost(executorId, host, reason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 4b31ff0c790da..d69f358cd19de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -985,6 +985,7 @@ private[spark] class TaskSetManager( val exitCausedByApp: Boolean = reason match { case exited: ExecutorExited => exited.exitCausedByApp case ExecutorKilled => false + case ExecutorProcessLost(_, _, false) => false case _ => true } handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, diff --git a/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala new file mode 100644 index 0000000000000..ee9a6be03868f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.concurrent.duration._ + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState, WorkerDecommission} +import org.apache.spark.deploy.master.{ApplicationInfo, Master, WorkerInfo} +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.ExternalBlockHandler +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.scheduler._ +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils + +class DecommissionWorkerSuite + extends SparkFunSuite + with Logging + with LocalSparkContext + with BeforeAndAfterEach { + + private var masterAndWorkerConf: SparkConf = null + private var masterAndWorkerSecurityManager: SecurityManager = null + private var masterRpcEnv: RpcEnv = null + private var master: Master = null + private var workerIdToRpcEnvs: mutable.HashMap[String, RpcEnv] = null + private var workers: mutable.ArrayBuffer[Worker] = null + + override def beforeEach(): Unit = { + super.beforeEach() + masterAndWorkerConf = new SparkConf() + .set(config.Worker.WORKER_DECOMMISSION_ENABLED, true) + masterAndWorkerSecurityManager = new SecurityManager(masterAndWorkerConf) + masterRpcEnv = RpcEnv.create( + Master.SYSTEM_NAME, + "localhost", + 0, + masterAndWorkerConf, + masterAndWorkerSecurityManager) + master = makeMaster() + workerIdToRpcEnvs = mutable.HashMap.empty + workers = mutable.ArrayBuffer.empty + } + + override def afterEach(): Unit = { + try { + masterRpcEnv.shutdown() + workerIdToRpcEnvs.values.foreach(_.shutdown()) + workerIdToRpcEnvs.clear() + master.stop() + workers.foreach(_.stop()) + workers.clear() + masterRpcEnv = null + } finally { + super.afterEach() + } + } + + test("decommission workers should not result in job failure") { + val maxTaskFailures = 2 + val numTimesToKillWorkers = maxTaskFailures + 1 + val numWorkers = numTimesToKillWorkers + 1 + createWorkers(numWorkers) + + // Here we will have a single task job and we will keep decommissioning (and killing) the + // worker running that task K times. Where K is more than the maxTaskFailures. Since the worker + // is notified of the decommissioning, the task failures can be ignored and not fail + // the job. 
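Illustrative aside (not part of the patch): the reason these provoked task failures do not count against spark.task.maxFailures is the TaskSetManager change earlier in this diff. A minimal Scala sketch of that decision, reusing the names from that hunk:

  // Only failures with exitCausedByApp = true count towards spark.task.maxFailures.
  val exitCausedByApp: Boolean = reason match {
    case exited: ExecutorExited => exited.exitCausedByApp
    case ExecutorKilled => false
    // Process lost with causedByApp = false, e.g. because the executor was decommissioned.
    case ExecutorProcessLost(_, _, false) => false
    case _ => true
  }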
+ + sc = createSparkContext(config.TASK_MAX_FAILURES.key -> maxTaskFailures.toString) + val executorIdToWorkerInfo = getExecutorToWorkerAssignments + val taskIdsKilled = new ConcurrentHashMap[Long, Boolean] + val listener = new RootStageAwareListener { + override def handleRootTaskStart(taskStart: SparkListenerTaskStart): Unit = { + val taskInfo = taskStart.taskInfo + if (taskIdsKilled.size() < numTimesToKillWorkers) { + val workerInfo = executorIdToWorkerInfo(taskInfo.executorId) + decommissionWorkerOnMaster(workerInfo, "partition 0 must die") + killWorkerAfterTimeout(workerInfo, 1) + taskIdsKilled.put(taskInfo.taskId, true) + } + } + } + TestUtils.withListener(sc, listener) { _ => + val jobResult = sc.parallelize(1 to 1, 1).map { _ => + Thread.sleep(5 * 1000L); 1 + }.count() + assert(jobResult === 1) + } + // single task job that gets to run numTimesToKillWorkers + 1 times. + assert(listener.getTasksFinished().size === numTimesToKillWorkers + 1) + listener.rootTasksStarted.asScala.foreach { taskInfo => + assert(taskInfo.index == 0, s"Unknown task index ${taskInfo.index}") + } + listener.rootTasksEnded.asScala.foreach { taskInfo => + assert(taskInfo.index === 0, s"Expected task index ${taskInfo.index} to be 0") + // If a task has been killed then it shouldn't be successful + val taskSuccessExpected = !taskIdsKilled.getOrDefault(taskInfo.taskId, false) + val taskSuccessActual = taskInfo.successful + assert(taskSuccessActual === taskSuccessExpected, + s"Expected task success $taskSuccessActual == $taskSuccessExpected") + } + } + + test("decommission workers ensure that shuffle output is regenerated even with shuffle service") { + createWorkers(2) + val ss = new ExternalShuffleServiceHolder() + + sc = createSparkContext( + config.Tests.TEST_NO_STAGE_RETRY.key -> "true", + config.SHUFFLE_MANAGER.key -> "sort", + config.SHUFFLE_SERVICE_ENABLED.key -> "true", + config.SHUFFLE_SERVICE_PORT.key -> ss.getPort.toString + ) + + // Here we will create a 2 stage job: The first stage will have two tasks and the second stage + // will have one task. The two tasks in the first stage will be long and short. We decommission + // and kill the worker after the short task is done. Eventually the driver should get the + // executor lost signal for the short task executor. This should trigger regenerating + // the shuffle output since we cleanly decommissioned the executor, despite running with an + // external shuffle service. 
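Illustrative aside (not part of the patch): the behaviour this test depends on is the rewrite in TaskSchedulerImpl.executorLost shown earlier; a condensed sketch with the same names (decomInfo is a shorthand introduced here), so the flow of the test is easier to follow:

  // A decommissioned host upgrades the loss to a worker loss and clears causedByApp, so the
  // DAGScheduler treats the shuffle files served from that host as lost and recomputes them.
  val reason = givenReason match {
    case ExecutorProcessLost(message, origWorkerLost, origCausedByApp) =>
      val decomInfo = getExecutorDecommissionInfo(executorId)
      ExecutorProcessLost(
        message,
        workerLost = origWorkerLost || decomInfo.exists(_.isHostDecommissioned),
        causedByApp = decomInfo.isEmpty && origCausedByApp)
    case other => other
  }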
+ try { + val executorIdToWorkerInfo = getExecutorToWorkerAssignments + val workerForTask0Decommissioned = new AtomicBoolean(false) + // single task job + val listener = new RootStageAwareListener { + override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val taskInfo = taskEnd.taskInfo + if (taskInfo.index == 0) { + if (workerForTask0Decommissioned.compareAndSet(false, true)) { + val workerInfo = executorIdToWorkerInfo(taskInfo.executorId) + decommissionWorkerOnMaster(workerInfo, "Kill early done map worker") + killWorkerAfterTimeout(workerInfo, 0) + logInfo(s"Killed the node ${workerInfo.hostPort} that was running the early task") + } + } + } + } + TestUtils.withListener(sc, listener) { _ => + val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => { + val sleepTimeSeconds = if (pid == 0) 1 else 10 + Thread.sleep(sleepTimeSeconds * 1000L) + List(1).iterator + }, preservesPartitioning = true).repartition(1).sum() + assert(jobResult === 2) + } + val tasksSeen = listener.getTasksFinished() + // 4 tasks: 2 from first stage, one retry due to decom, one more from the second stage. + assert(tasksSeen.size === 4, s"Expected 4 tasks but got $tasksSeen") + listener.rootTasksStarted.asScala.foreach { taskInfo => + assert(taskInfo.index <= 1, s"Expected ${taskInfo.index} <= 1") + assert(taskInfo.successful, s"Task ${taskInfo.index} should be successful") + } + val tasksEnded = listener.rootTasksEnded.asScala + tasksEnded.filter(_.index != 0).foreach { taskInfo => + assert(taskInfo.attemptNumber === 0, "2nd task should succeed on 1st attempt") + } + val firstTaskAttempts = tasksEnded.filter(_.index == 0) + assert(firstTaskAttempts.size > 1, s"Task 0 should have multiple attempts") + } finally { + ss.close() + } + } + + test("decommission workers ensure that fetch failures lead to rerun") { + createWorkers(2) + sc = createSparkContext( + config.Tests.TEST_NO_STAGE_RETRY.key -> "false", + config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE.key -> "true") + val executorIdToWorkerInfo = getExecutorToWorkerAssignments + val executorToDecom = executorIdToWorkerInfo.keysIterator.next + + // The task code below cannot call executorIdToWorkerInfo, so we need to pre-compute + // the worker to decom to force it to be serialized into the task. + val workerToDecom = executorIdToWorkerInfo(executorToDecom) + + // The setup of this job is similar to the one above: 2 stage job with first stage having + // long and short tasks. Except that we want the shuffle output to be regenerated on a + // fetch failure instead of an executor lost. Since it is hard to "trigger a fetch failure", + // we manually raise the FetchFailed exception when the 2nd stage's task runs and require that + // fetch failure to trigger a recomputation. 
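Illustrative aside (not part of the patch): the fetch-failure handling this test exercises is the DAGScheduler condition from the first hunk of this diff; a compact sketch, assuming the same variable names:

  // On a fetch failure from bmAddress, unregister map output for the whole host when either the
  // external shuffle service is enabled or the host is known to be decommissioned.
  val isHostDecommissioned = taskScheduler
    .getExecutorDecommissionInfo(bmAddress.executorId)
    .exists(_.isHostDecommissioned)
  val shuffleOutputOfEntireHostLost =
    env.blockManager.externalShuffleServiceEnabled || isHostDecommissioned
  val hostToUnregisterOutputs =
    if (shuffleOutputOfEntireHostLost && unRegisterOutputOnHostOnFetchFailure) {
      Some(bmAddress.host)
    } else {
      None // unregister shuffle data just for the one failed executor
    }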
+ logInfo(s"Will try to decommission the task running on executor $executorToDecom") + val listener = new RootStageAwareListener { + override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val taskInfo = taskEnd.taskInfo + if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 && + taskEnd.stageAttemptId == 0) { + decommissionWorkerOnMaster(workerToDecom, + "decommission worker after task on it is done") + } + } + } + TestUtils.withListener(sc, listener) { _ => + val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => { + val executorId = SparkEnv.get.executorId + val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1 + Thread.sleep(sleepTimeSeconds * 1000L) + List(1).iterator + }, preservesPartitioning = true) + .repartition(1).mapPartitions(iter => { + val context = TaskContext.get() + if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) { + // MapIndex is explicitly -1 to force the entire host to be decommissioned + // However, this will cause both the tasks in the preceding stage since the host here is + // "localhost" (shortcoming of this single-machine unit test in that all the workers + // are actually on the same host) + throw new FetchFailedException(BlockManagerId(executorToDecom, + workerToDecom.host, workerToDecom.port), 0, 0, -1, 0, "Forcing fetch failure") + } + val sumVal: List[Int] = List(iter.sum) + sumVal.iterator + }, preservesPartitioning = true) + .sum() + assert(jobResult === 2) + } + // 6 tasks: 2 from first stage, 2 rerun again from first stage, 2nd stage attempt 1 and 2. + val tasksSeen = listener.getTasksFinished() + assert(tasksSeen.size === 6, s"Expected 6 tasks but got $tasksSeen") + } + + private abstract class RootStageAwareListener extends SparkListener { + private var rootStageId: Option[Int] = None + private val tasksFinished = new ConcurrentLinkedQueue[String]() + private val jobDone = new AtomicBoolean(false) + val rootTasksStarted = new ConcurrentLinkedQueue[TaskInfo]() + val rootTasksEnded = new ConcurrentLinkedQueue[TaskInfo]() + + protected def isRootStageId(stageId: Int): Boolean = + (rootStageId.isDefined && rootStageId.get == stageId) + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + if (stageSubmitted.stageInfo.parentIds.isEmpty && rootStageId.isEmpty) { + rootStageId = Some(stageSubmitted.stageInfo.stageId) + } + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + jobEnd.jobResult match { + case JobSucceeded => jobDone.set(true) + } + } + + protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {} + + protected def handleRootTaskStart(start: SparkListenerTaskStart) = {} + + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + if (isRootStageId(taskStart.stageId)) { + rootTasksStarted.add(taskStart.taskInfo) + handleRootTaskStart(taskStart) + } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" + + s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}" + logInfo(s"Task End $taskSignature") + tasksFinished.add(taskSignature) + if (isRootStageId(taskEnd.stageId)) { + rootTasksEnded.add(taskEnd.taskInfo) + handleRootTaskEnd(taskEnd) + } + } + + def getTasksFinished(): Seq[String] = { + assert(jobDone.get(), "Job isn't successfully done yet") + tasksFinished.asScala.toSeq + } + } + + private def getExecutorToWorkerAssignments: Map[String, WorkerInfo] = { + val executorIdToWorkerInfo = 
mutable.HashMap[String, WorkerInfo]() + master.workers.foreach { wi => + assert(wi.executors.size <= 1, "There should be at most one executor per worker") + // Cast the executorId to string since the TaskInfo.executorId is a string + wi.executors.values.foreach { e => + val executorIdString = e.id.toString + val oldWorkerInfo = executorIdToWorkerInfo.put(executorIdString, wi) + assert(oldWorkerInfo.isEmpty, + s"Executor $executorIdString already present on another worker ${oldWorkerInfo}") + } + } + executorIdToWorkerInfo.toMap + } + + private def makeMaster(): Master = { + val master = new Master( + masterRpcEnv, + masterRpcEnv.address, + 0, + masterAndWorkerSecurityManager, + masterAndWorkerConf) + masterRpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + master + } + + private def createWorkers(numWorkers: Int, cores: Int = 1, memory: Int = 1024): Unit = { + val workerRpcEnvs = (0 until numWorkers).map { i => + RpcEnv.create( + Worker.SYSTEM_NAME + i, + "localhost", + 0, + masterAndWorkerConf, + masterAndWorkerSecurityManager) + } + workers.clear() + val rpcAddressToRpcEnv: mutable.HashMap[RpcAddress, RpcEnv] = mutable.HashMap.empty + workerRpcEnvs.foreach { rpcEnv => + val workDir = Utils.createTempDir(namePrefix = this.getClass.getSimpleName()).toString + val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), + Worker.ENDPOINT_NAME, workDir, masterAndWorkerConf, masterAndWorkerSecurityManager) + rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) + workers.append(worker) + val oldRpcEnv = rpcAddressToRpcEnv.put(rpcEnv.address, rpcEnv) + logInfo(s"Created a worker at ${rpcEnv.address} with workdir $workDir") + assert(oldRpcEnv.isEmpty, s"Detected duplicate rpcEnv ${oldRpcEnv} for ${rpcEnv.address}") + } + workerIdToRpcEnvs.clear() + // Wait until all workers register with master successfully + eventually(timeout(1.minute), interval(1.seconds)) { + val workersOnMaster = getMasterState.workers + val numWorkersCurrently = workersOnMaster.length + logInfo(s"Waiting for $numWorkers workers to come up: So far $numWorkersCurrently") + assert(numWorkersCurrently === numWorkers) + workersOnMaster.foreach { workerInfo => + val rpcAddress = RpcAddress(workerInfo.host, workerInfo.port) + val rpcEnv = rpcAddressToRpcEnv(rpcAddress) + assert(rpcEnv != null, s"Cannot find the worker for $rpcAddress") + val oldRpcEnv = workerIdToRpcEnvs.put(workerInfo.id, rpcEnv) + assert(oldRpcEnv.isEmpty, s"Detected duplicate rpcEnv ${oldRpcEnv} for worker " + + s"${workerInfo.id}") + } + } + logInfo(s"Created ${workers.size} workers") + } + + private def getMasterState: MasterStateResponse = { + master.self.askSync[MasterStateResponse](RequestMasterState) + } + + private def getApplications(): Seq[ApplicationInfo] = { + getMasterState.activeApps + } + + def decommissionWorkerOnMaster(workerInfo: WorkerInfo, reason: String): Unit = { + logInfo(s"Trying to decommission worker ${workerInfo.id} for reason `$reason`") + master.self.send(WorkerDecommission(workerInfo.id, workerInfo.endpoint)) + } + + def killWorkerAfterTimeout(workerInfo: WorkerInfo, secondsToWait: Int): Unit = { + val env = workerIdToRpcEnvs(workerInfo.id) + Thread.sleep(secondsToWait * 1000L) + env.shutdown() + env.awaitTermination() + } + + def createSparkContext(extraConfs: (String, String)*): SparkContext = { + val conf = new SparkConf() + .setMaster(masterRpcEnv.address.toSparkURL) + .setAppName("test") + .setAll(extraConfs) + sc = new SparkContext(conf) + val appId = sc.applicationId + eventually(timeout(1.minute), 
interval(1.seconds)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } + sc + } + + private class ExternalShuffleServiceHolder() { + // The external shuffle service can start with default configs and not get polluted by the + // other configs used in this test. + private val transportConf = SparkTransportConf.fromSparkConf(new SparkConf(), + "shuffle", numUsableCores = 2) + private val rpcHandler = new ExternalBlockHandler(transportConf, null) + private val transportContext = new TransportContext(transportConf, rpcHandler) + private val server = transportContext.createServer() + + def getPort: Int = server.getPort + + def close(): Unit = { + Utils.tryLogNonFatalError { + server.close() + } + Utils.tryLogNonFatalError { + rpcHandler.close() + } + Utils.tryLogNonFatalError { + transportContext.close() + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 45af0d086890f..c829006923c4f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -178,6 +178,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi override def executorDecommission( executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = {} + override def getExecutorDecommissionInfo( + executorId: String): Option[ExecutorDecommissionInfo] = None } /** @@ -785,6 +787,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi override def executorDecommission( executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = {} + override def getExecutorDecommissionInfo( + executorId: String): Option[ExecutorDecommissionInfo] = None } val noKillScheduler = new DAGScheduler( sc, diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index b2a5f77b4b04c..07d88672290fc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -101,4 +101,6 @@ private class DummyTaskScheduler extends TaskScheduler { override def executorDecommission( executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = {} + override def getExecutorDecommissionInfo( + executorId: String): Option[ExecutorDecommissionInfo] = None } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 9ca3ce9d43ca5..e5836458e7f91 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1802,6 +1802,53 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(2 == taskDescriptions.head.resources(GPU).addresses.size) } + private def setupSchedulerForDecommissionTests(): TaskSchedulerImpl = { + val taskScheduler = setupSchedulerWithMaster( + s"local[2]", + config.CPUS_PER_TASK.key -> 1.toString) + taskScheduler.submitTasks(FakeTask.createTaskSet(2)) + val multiCoreWorkerOffers = IndexedSeq(WorkerOffer("executor0", "host0", 1), + WorkerOffer("executor1", "host1", 1)) + 
val taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten + assert(taskDescriptions.map(_.executorId).sorted === Seq("executor0", "executor1")) + taskScheduler + } + + test("scheduler should keep the decommission info where host was decommissioned") { + val scheduler = setupSchedulerForDecommissionTests() + + scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0", false)) + scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1", true)) + scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0 new", false)) + scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1 new", false)) + + assert(scheduler.getExecutorDecommissionInfo("executor0") + === Some(ExecutorDecommissionInfo("0 new", false))) + assert(scheduler.getExecutorDecommissionInfo("executor1") + === Some(ExecutorDecommissionInfo("1", true))) + assert(scheduler.getExecutorDecommissionInfo("executor2").isEmpty) + } + + test("scheduler should ignore decommissioning of removed executors") { + val scheduler = setupSchedulerForDecommissionTests() + + // executor 0 is decommissioned after being lost + assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty) + scheduler.executorLost("executor0", ExecutorExited(0, false, "normal")) + assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty) + scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("", false)) + assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty) + + // executor 1 is decommissioned before being lost + assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty) + scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false)) + assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined) + scheduler.executorLost("executor1", ExecutorExited(0, false, "normal")) + assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty) + scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false)) + assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty) + } + /** * Used by tests to simulate a task failure. This calls the failure handler explicitly, to ensure * that all the state is updated when this method returns. Otherwise, there's no way to know when