Skip to content

Commit e4e4e2b

Browse files
attilapirosMarcelo Vanzin
authored andcommitted
[SPARK-26891][YARN] Fixing flaky test in YarnSchedulerBackendSuite
The test "RequestExecutors reflects node blacklist and is serializable" is flaky because of multi threaded access of the mock task scheduler. For details check [Mockito FAQ (occasional exceptions like: WrongTypeOfReturnValue)](https://github.com/mockito/mockito/wiki/FAQ#is-mockito-thread-safe). So instead of mocking the task scheduler in the test TaskSchedulerImpl is simply subclassed. This multithreaded access of the `nodeBlacklist()` method is coming from: 1) the unit test thread via calling of the method `prepareRequestExecutors()` 2) the `DriverEndpoint.onStart` which runs a periodic task that ends up calling this method Existing unittest. Closes #23801 from attilapiros/SPARK-26891. Authored-by: “attilapiros” <[email protected]> Signed-off-by: Marcelo Vanzin <[email protected]>
1 parent 885aa55 commit e4e4e2b

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.apache.spark.scheduler.cluster
1818

1919
import java.net.URL
20+
import java.util.concurrent.atomic.AtomicReference
2021
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
2122

2223
import scala.language.reflectiveCalls
@@ -32,15 +33,35 @@ import org.apache.spark.ui.TestFilter
3233

3334
class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with LocalSparkContext {
3435

36+
private var yarnSchedulerBackend: YarnSchedulerBackend = _
37+
38+
override def afterEach(): Unit = {
39+
try {
40+
if (yarnSchedulerBackend != null) {
41+
yarnSchedulerBackend.stop()
42+
}
43+
} finally {
44+
super.afterEach()
45+
}
46+
}
47+
3548
test("RequestExecutors reflects node blacklist and is serializable") {
3649
sc = new SparkContext("local", "YarnSchedulerBackendSuite")
37-
val sched = mock[TaskSchedulerImpl]
38-
when(sched.sc).thenReturn(sc)
39-
val yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) {
50+
// Subclassing the TaskSchedulerImpl here instead of using Mockito. For details see SPARK-26891.
51+
val sched = new TaskSchedulerImpl(sc) {
52+
val blacklistedNodes = new AtomicReference[Set[String]]()
53+
54+
def setNodeBlacklist(nodeBlacklist: Set[String]): Unit = blacklistedNodes.set(nodeBlacklist)
55+
56+
override def nodeBlacklist(): Set[String] = blacklistedNodes.get()
57+
}
58+
59+
val yarnSchedulerBackendExtended = new YarnSchedulerBackend(sched, sc) {
4060
def setHostToLocalTaskCount(hostToLocalTaskCount: Map[String, Int]): Unit = {
4161
this.hostToLocalTaskCount = hostToLocalTaskCount
4262
}
4363
}
64+
yarnSchedulerBackend = yarnSchedulerBackendExtended
4465
val ser = new JavaSerializer(sc.conf).newInstance()
4566
for {
4667
blacklist <- IndexedSeq(Set[String](), Set("a", "b", "c"))
@@ -50,9 +71,9 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc
5071
Map("a" -> 1, "b" -> 2)
5172
)
5273
} {
53-
yarnSchedulerBackend.setHostToLocalTaskCount(hostToLocalCount)
54-
when(sched.nodeBlacklist()).thenReturn(blacklist)
55-
val req = yarnSchedulerBackend.prepareRequestExecutors(numRequested)
74+
yarnSchedulerBackendExtended.setHostToLocalTaskCount(hostToLocalCount)
75+
sched.setNodeBlacklist(blacklist)
76+
val req = yarnSchedulerBackendExtended.prepareRequestExecutors(numRequested)
5677
assert(req.requestedTotal === numRequested)
5778
assert(req.nodeBlacklist === blacklist)
5879
assert(req.hostToLocalTaskCount.keySet.intersect(blacklist).isEmpty)
@@ -75,9 +96,9 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc
7596
// Before adding the "YARN" filter, should get the code from the filter in SparkConf.
7697
assert(TestUtils.httpResponseCode(url) === HttpServletResponse.SC_BAD_GATEWAY)
7798

78-
val backend = new YarnSchedulerBackend(sched, sc) { }
99+
yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) { }
79100

80-
backend.addWebUIFilter(classOf[TestFilter2].getName(),
101+
yarnSchedulerBackend.addWebUIFilter(classOf[TestFilter2].getName(),
81102
Map("responseCode" -> HttpServletResponse.SC_NOT_ACCEPTABLE.toString), "")
82103

83104
sc.ui.get.getHandlers.foreach { h =>

0 commit comments

Comments
 (0)