Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.spark.scheduler.cluster

import java.net.URL
import java.util.concurrent.atomic.AtomicReference
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import scala.language.reflectiveCalls
Expand All @@ -32,15 +33,35 @@ import org.apache.spark.ui.TestFilter

class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with LocalSparkContext {

private var yarnSchedulerBackend: YarnSchedulerBackend = _

override def afterEach() {
try {
if (yarnSchedulerBackend != null) {
yarnSchedulerBackend.stop()
}
} finally {
super.afterEach()
}
}

test("RequestExecutors reflects node blacklist and is serializable") {
sc = new SparkContext("local", "YarnSchedulerBackendSuite")
val sched = mock[TaskSchedulerImpl]
when(sched.sc).thenReturn(sc)
val yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) {
// Subclassing the TaskSchedulerImpl here instead of using Mockito. For details see SPARK-26891.
val sched = new TaskSchedulerImpl(sc) {
val blacklistedNodes = new AtomicReference[Set[String]]()

def setNodeBlacklist(nodeBlacklist: Set[String]): Unit = blacklistedNodes.set(nodeBlacklist)

override def nodeBlacklist(): Set[String] = blacklistedNodes.get()
}

val yarnSchedulerBackendExtended = new YarnSchedulerBackend(sched, sc) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need the extra variable here, vs assigning to yarnSchedulerBackend? I don't see that they are used separately.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is needed because of the different type: the yarnSchedulerBackend type is YarnSchedulerBackend but yarnSchedulerBackendExtended type is an anonim subclass of YarnSchedulerBackend with the extra def setNodeBlacklist. On yarnSchedulerBackend I cannot call this extra method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so then how is it assigned in the next line? a subclass of YarnSchedulerBackend is still assignable to YarnSchedulerBackend. I might be missing something obvious here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is assignable as yarnSchedulerBackendExtended is an instance of YarnSchedulerBackend too, although not a direct one.

def setHostToLocalTaskCount(hostToLocalTaskCount: Map[String, Int]): Unit = {
this.hostToLocalTaskCount = hostToLocalTaskCount
}
}
yarnSchedulerBackend = yarnSchedulerBackendExtended
val ser = new JavaSerializer(sc.conf).newInstance()
for {
blacklist <- IndexedSeq(Set[String](), Set("a", "b", "c"))
Expand All @@ -50,9 +71,9 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc
Map("a" -> 1, "b" -> 2)
)
} {
yarnSchedulerBackend.setHostToLocalTaskCount(hostToLocalCount)
when(sched.nodeBlacklist()).thenReturn(blacklist)
val req = yarnSchedulerBackend.prepareRequestExecutors(numRequested)
yarnSchedulerBackendExtended.setHostToLocalTaskCount(hostToLocalCount)
sched.setNodeBlacklist(blacklist)
val req = yarnSchedulerBackendExtended.prepareRequestExecutors(numRequested)
assert(req.requestedTotal === numRequested)
assert(req.nodeBlacklist === blacklist)
assert(req.hostToLocalTaskCount.keySet.intersect(blacklist).isEmpty)
Expand All @@ -75,9 +96,9 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc
// Before adding the "YARN" filter, should get the code from the filter in SparkConf.
assert(TestUtils.httpResponseCode(url) === HttpServletResponse.SC_BAD_GATEWAY)

val backend = new YarnSchedulerBackend(sched, sc) { }
yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) { }

backend.addWebUIFilter(classOf[TestFilter2].getName(),
yarnSchedulerBackend.addWebUIFilter(classOf[TestFilter2].getName(),
Map("responseCode" -> HttpServletResponse.SC_NOT_ACCEPTABLE.toString), "")

sc.ui.get.getHandlers.foreach { h =>
Expand Down