113 changes: 43 additions & 70 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler

import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
@@ -28,8 +29,6 @@ import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import akka.actor._
import akka.actor.SupervisorStrategy.Stop
import akka.pattern.ask
import akka.util.Timeout

@@ -39,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat

/**
@@ -67,8 +66,6 @@ class DAGScheduler(
clock: Clock = SystemClock)
extends Logging {

import DAGScheduler._

def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
this(
sc,
@@ -112,42 +109,31 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]

private val dagSchedulerActorSupervisor =
env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))

// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()

private[scheduler] var eventProcessActor: ActorRef = _

/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)

/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)

private def initializeEventProcessActor() {
// blocking the thread until supervisor is started, which ensures eventProcessActor is
// not null before any job is submitted
implicit val timeout = Timeout(30 seconds)
val initEventActorReply =
dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
eventProcessActor = Await.result(initEventActorReply, timeout.duration).
asInstanceOf[ActorRef]
}
private val messageScheduler =
Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message"))

initializeEventProcessActor()
private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)

// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventProcessActor ! BeginEvent(task, taskInfo)
eventProcessLoop.post(BeginEvent(task, taskInfo))
}

// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(taskInfo: TaskInfo) {
eventProcessActor ! GettingResultEvent(taskInfo)
eventProcessLoop.post(GettingResultEvent(taskInfo))
}

// Called by TaskScheduler to report task completions or failures.
@@ -158,7 +144,8 @@
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
eventProcessLoop.post(
CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
}

/**
@@ -180,18 +167,18 @@

// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
eventProcessActor ! ExecutorLost(execId)
eventProcessLoop.post(ExecutorLost(execId))
}

// Called by TaskScheduler when a host is added
def executorAdded(execId: String, host: String) {
eventProcessActor ! ExecutorAdded(execId, host)
eventProcessLoop.post(ExecutorAdded(execId, host))
}

// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
def taskSetFailed(taskSet: TaskSet, reason: String) {
eventProcessActor ! TaskSetFailed(taskSet, reason)
eventProcessLoop.post(TaskSetFailed(taskSet, reason))
}

private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
@@ -496,8 +483,8 @@
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
eventProcessActor ! JobSubmitted(
jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties))
waiter
}

@@ -537,8 +524,8 @@
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
eventProcessActor ! JobSubmitted(
jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
eventProcessLoop.post(JobSubmitted(
jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}

@@ -547,19 +534,19 @@
*/
def cancelJob(jobId: Int) {
logInfo("Asked to cancel job " + jobId)
eventProcessActor ! JobCancelled(jobId)
eventProcessLoop.post(JobCancelled(jobId))
}

def cancelJobGroup(groupId: String) {
logInfo("Asked to cancel job group " + groupId)
eventProcessActor ! JobGroupCancelled(groupId)
eventProcessLoop.post(JobGroupCancelled(groupId))
}

/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs() {
eventProcessActor ! AllJobsCancelled
eventProcessLoop.post(AllJobsCancelled)
}

private[scheduler] def doCancelAllJobs() {
@@ -575,7 +562,7 @@
* Cancel all jobs associated with a running or scheduled stage.
*/
def cancelStage(stageId: Int) {
eventProcessActor ! StageCancelled(stageId)
eventProcessLoop.post(StageCancelled(stageId))
}

/**
Expand Down Expand Up @@ -1059,16 +1046,15 @@ class DAGScheduler(

if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
} else if (failedStages.isEmpty && eventProcessActor != null) {
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled. eventProcessActor may be
// null during unit tests.
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
import env.actorSystem.dispatcher
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
env.actorSystem.scheduler.scheduleOnce(
RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
@@ -1326,40 +1312,21 @@

def stop() {
logInfo("Stopping DAGScheduler")
dagSchedulerActorSupervisor ! PoisonPill
eventProcessLoop.stop()
taskScheduler.stop()
}
}

private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
extends Actor with Logging {

override val supervisorStrategy =
OneForOneStrategy() {
case x: Exception =>
logError("eventProcesserActor failed; shutting down SparkContext", x)
try {
dagScheduler.doCancelAllJobs()
} catch {
case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
}
dagScheduler.sc.stop()
Stop
}

def receive = {
case p: Props => sender ! context.actorOf(p)
case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor")
}
// Start the event thread at the end of the constructor
eventProcessLoop.start()
}

private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler)
extends Actor with Logging {
private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)
extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {

/**
* The main event loop of the DAG scheduler.
*/
def receive = {
override def onReceive(event: DAGSchedulerEvent): Unit = event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
listener, properties)
@@ -1398,7 +1365,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.resubmitFailedStages()
}

override def postStop() {
override def onError(e: Throwable): Unit = {
logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e)
try {
dagScheduler.doCancelAllJobs()
} catch {
case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
}
dagScheduler.sc.stop()
}

override def onStop() {
// Cancel any active jobs in the onStop hook
dagScheduler.cleanUpAfterSchedulerStop()
}
@@ -1408,9 +1385,5 @@ private[spark] object DAGScheduler {
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
val RESUBMIT_TIMEOUT = 200.milliseconds

// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
val POLL_TIMEOUT = 10L
val RESUBMIT_TIMEOUT = 200
}
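
The resubmission hunk above swaps Akka's `actorSystem.scheduler.scheduleOnce` for a single-threaded `ScheduledExecutorService` (`messageScheduler`) that posts `ResubmitFailedStages` back onto the event loop after `RESUBMIT_TIMEOUT` milliseconds. A minimal, self-contained sketch of that delayed-posting pattern follows; it is not part of the PR, and the `DelayedPoster` object, thread name, and `scheduleOnce` helper are illustrative names only.

```scala
import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}

// Hypothetical sketch of the pattern adopted above: one scheduler thread that
// re-posts a message after a fixed delay, replacing Akka's scheduleOnce.
object DelayedPoster {
  private val resubmitDelayMs = 200L // mirrors DAGScheduler.RESUBMIT_TIMEOUT

  private val threadFactory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "delayed-poster")
      t.setDaemon(true) // do not keep the JVM alive just for this scheduler
      t
    }
  }

  // One thread is enough: it only wakes up to hand a message back to the caller.
  private val scheduler = Executors.newScheduledThreadPool(1, threadFactory)

  /** Run `post` once after the delay, like messageScheduler.schedule in the diff. */
  def scheduleOnce(post: () => Unit): Unit = {
    scheduler.schedule(new Runnable {
      override def run(): Unit = post()
    }, resubmitDelayMs, TimeUnit.MILLISECONDS)
  }
}

// Usage (analogous to the diff): DelayedPoster.scheduleOnce(() => eventProcessLoop.post(ResubmitFailedStages))
```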
124 changes: 124 additions & 0 deletions core/src/main/scala/org/apache/spark/util/EventLoop.scala
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque}

import scala.util.control.NonFatal

import org.apache.spark.Logging

/**
* An event loop to receive events from the caller and process all events in the event thread. It
* will start an exclusive event thread to process all events.
*
* Note: The event queue will grow indefinitely, so subclasses should make sure `onReceive` can
* handle events in time to avoid a potential OOM.
*/
private[spark] abstract class EventLoop[E](name: String) extends Logging {
[Review comment, Contributor] mark as DeveloperAPI?

[Reply, Contributor] This is a private class. We only need to mark developer APIs for exposed classes.

private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]()

private val stopped = new AtomicBoolean(false)

private val eventThread = new Thread(name) {
setDaemon(true)

override def run(): Unit = {
try {
while (!stopped.get) {
val event = eventQueue.take()
try {
onReceive(event)
} catch {
case NonFatal(e) => {
try {
onError(e)
} catch {
case NonFatal(e) => logError("Unexpected error in " + name, e)
}
}
}
}
} catch {
case ie: InterruptedException => // exit even if eventQueue is not empty
case NonFatal(e) => logError("Unexpected error in " + name, e)
}
}

}

def start(): Unit = {
if (stopped.get) {
throw new IllegalStateException(name + " has already been stopped")
}
// Call onStart before starting the event thread to make sure it happens before onReceive
onStart()
eventThread.start()
}

def stop(): Unit = {
if (stopped.compareAndSet(false, true)) {
eventThread.interrupt()
eventThread.join()
// Call onStop after the event thread exits to make sure onReceive happens before onStop
onStop()
} else {
// Keep quiet to allow calling `stop` multiple times.
}
}

/**
* Put the event into the event queue. The event thread will process it later.
*/
def post(event: E): Unit = {
eventQueue.put(event)
}

/**
* Return if the event thread has already been started but not yet stopped.
*/
def isActive: Boolean = eventThread.isAlive

/**
* Invoked when `start()` is called but before the event thread starts.
*/
protected def onStart(): Unit = {}

/**
* Invoked when `stop()` is called and the event thread exits.
*/
protected def onStop(): Unit = {}

/**
* Invoked in the event thread when polling events from the event queue.
*
* Note: Avoid calling blocking actions in `onReceive`, or the event thread will be blocked and
* cannot process events in time. If you need to call blocking actions, run them in
* another thread.
*/
protected def onReceive(event: E): Unit

/**
* Invoked if `onReceive` throws any non-fatal error. Any non-fatal error thrown from `onError`
* will be ignored.
*/
protected def onError(e: Throwable): Unit

}
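
For context, here is a minimal sketch of how a subclass might use the `EventLoop` API introduced above (`start`, `post`, `stop`, `onReceive`, `onError`). It is not part of the PR: `CountingLoop` and its logging are hypothetical, and since `EventLoop` is `private[spark]`, a real subclass has to live under an `org.apache.spark` package, as `DAGSchedulerEventProcessLoop` does.

```scala
package org.apache.spark.util

// Hypothetical example, not part of this PR: a trivial EventLoop subclass that
// sums the Ints posted to it. It sits in org.apache.spark.util only because
// EventLoop is private[spark].
private[spark] class CountingLoop extends EventLoop[Int]("counting-loop") {
  private var total = 0 // touched only by the single event thread, so no locking needed

  // Runs on the dedicated event thread; should return quickly.
  override def onReceive(event: Int): Unit = {
    total += event
    logInfo(s"received $event, running total = $total")
  }

  // Called on the event thread if onReceive throws a non-fatal error.
  override def onError(e: Throwable): Unit = {
    logError("error while processing event", e)
  }
}

private[spark] object CountingLoopExample {
  def main(args: Array[String]): Unit = {
    val loop = new CountingLoop
    loop.start()                // runs onStart(), then launches the daemon event thread
    (1 to 3).foreach(loop.post) // events are queued and handled asynchronously
    Thread.sleep(100)           // give the event thread time to drain the queue
    loop.stop()                 // interrupts and joins the thread, then runs onStop()
  }
}
```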