Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.scheduler.cluster

import java.util.concurrent.atomic.AtomicReference

import scala.language.reflectiveCalls

import org.mockito.Mockito.when
Expand All @@ -27,15 +29,35 @@ import org.apache.spark.serializer.JavaSerializer

class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with LocalSparkContext {

private var yarnSchedulerBackend: YarnSchedulerBackend = _

override def afterEach(): Unit = {
try {
if (yarnSchedulerBackend != null) {
yarnSchedulerBackend.stop()
}
} finally {
super.afterEach()
}
}

test("RequestExecutors reflects node blacklist and is serializable") {
sc = new SparkContext("local", "YarnSchedulerBackendSuite")
val sched = mock[TaskSchedulerImpl]
when(sched.sc).thenReturn(sc)
val yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) {
// Subclassing the TaskSchedulerImpl here instead of using Mockito. For details see SPARK-26891.
val sched = new TaskSchedulerImpl(sc) {
val blacklistedNodes = new AtomicReference[Set[String]]()

def setNodeBlacklist(nodeBlacklist: Set[String]): Unit = blacklistedNodes.set(nodeBlacklist)

override def nodeBlacklist(): Set[String] = blacklistedNodes.get()
}

val yarnSchedulerBackendExtended = new YarnSchedulerBackend(sched, sc) {
def setHostToLocalTaskCount(hostToLocalTaskCount: Map[String, Int]): Unit = {
this.hostToLocalTaskCount = hostToLocalTaskCount
}
}
yarnSchedulerBackend = yarnSchedulerBackendExtended
val ser = new JavaSerializer(sc.conf).newInstance()
for {
blacklist <- IndexedSeq(Set[String](), Set("a", "b", "c"))
Expand All @@ -45,16 +67,15 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc
Map("a" -> 1, "b" -> 2)
)
} {
yarnSchedulerBackend.setHostToLocalTaskCount(hostToLocalCount)
when(sched.nodeBlacklist()).thenReturn(blacklist)
val req = yarnSchedulerBackend.prepareRequestExecutors(numRequested)
yarnSchedulerBackendExtended.setHostToLocalTaskCount(hostToLocalCount)
sched.setNodeBlacklist(blacklist)
val req = yarnSchedulerBackendExtended.prepareRequestExecutors(numRequested)
assert(req.requestedTotal === numRequested)
assert(req.nodeBlacklist === blacklist)
assert(req.hostToLocalTaskCount.keySet.intersect(blacklist).isEmpty)
// Serialize to make sure serialization doesn't throw an error
ser.serialize(req)
}
sc.stop()
}

}