
Commit d49912b

zsxwing authored and nemccarthy committed
[SPARK-7989] [CORE] [TESTS] Fix flaky tests in ExternalShuffleServiceSuite and SparkListenerWithClusterSuite
The flaky tests in ExternalShuffleServiceSuite and SparkListenerWithClusterSuite will fail if there are not enough executors up before running the jobs. This PR adds `JobProgressListener.waitUntilExecutorsUp`. The tests for the cluster mode can use it to wait until the expected executors are up.

Author: zsxwing <[email protected]>

Closes apache#6546 from zsxwing/SPARK-7989 and squashes the following commits:

5560e09 [zsxwing] Fix a typo
3b69840 [zsxwing] Fix flaky tests in ExternalShuffleServiceSuite and SparkListenerWithClusterSuite
1 parent e9dfeb4 commit d49912b
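
As a quick orientation before the diffs: a cluster-mode test now calls the new helper to block until the expected executors have registered, and only then submits jobs. The sketch below is illustrative only; the `local-cluster[2, 1, 512]` master, the 10-second timeout, and the sample job mirror the test changes in this commit, and both `jobProgressListener` and `waitUntilExecutorsUp` are `private[spark]`, so real callers must live under the `org.apache.spark` package as the tests do.

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative sketch, not part of the commit: start a two-executor local
// cluster and wait for both executors before running a distributed job.
val sc = new SparkContext("local-cluster[2, 1, 512]", "test", new SparkConf())
try {
  // Returns once blockManagerIds.size >= 2 + 1 (two executors plus the driver's
  // BlockManager), or throws java.util.concurrent.TimeoutException after ~10 seconds.
  sc.jobProgressListener.waitUntilExecutorsUp(2, 10000)
  val counts = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _)
  println(counts.count())
} finally {
  sc.stop()
}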

File tree: 4 files changed, +46 −12 lines

core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala

Lines changed: 30 additions & 0 deletions
@@ -17,8 +17,12 @@

 package org.apache.spark.ui.jobs

+import java.util.concurrent.TimeoutException
+
 import scala.collection.mutable.{HashMap, HashSet, ListBuffer}

+import com.google.common.annotations.VisibleForTesting
+
 import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
@@ -526,4 +530,30 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   override def onApplicationStart(appStarted: SparkListenerApplicationStart) {
     startTime = appStarted.time
   }
+
+  /**
+   * For testing only. Wait until at least `numExecutors` executors are up, or throw
+   * `TimeoutException` if the waiting time elapses before `numExecutors` executors are up.
+   *
+   * @param numExecutors the minimum number of executors to wait for
+   * @param timeout time to wait in milliseconds
+   */
+  @VisibleForTesting
+  private[spark] def waitUntilExecutorsUp(numExecutors: Int, timeout: Long): Unit = {
+    val finishTime = System.currentTimeMillis() + timeout
+    while (System.currentTimeMillis() < finishTime) {
+      val numBlockManagers = synchronized {
+        blockManagerIds.size
+      }
+      if (numBlockManagers >= numExecutors + 1) {
+        // Need to count the block manager in the driver
+        return
+      }
+      // Sleep rather than using wait/notify, because this is used only for testing and wait/notify
+      // add overhead in the general case.
+      Thread.sleep(10)
+    }
+    throw new TimeoutException(
+      s"Can't find $numExecutors executors before $timeout milliseconds elapsed")
+  }
 }
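
The `numExecutors + 1` threshold above exists because the driver registers a BlockManager of its own, so `waitUntilExecutorsUp(2, 10000)` returns once three BlockManagers are known and otherwise throws after roughly ten seconds. Below is a hedged sketch, assuming a `SparkContext` named `sc` and a suite under `org.apache.spark` that extends `SparkFunSuite`, of how a caller might surface that timeout with a clearer message; the `fail(...)` wrapper comes from ScalaTest's Assertions and is not part of this commit.

import java.util.concurrent.TimeoutException

try {
  sc.jobProgressListener.waitUntilExecutorsUp(2, 10000)
} catch {
  case e: TimeoutException =>
    // The helper's message reads "Can't find 2 executors before 10000 milliseconds elapsed".
    fail(s"Cluster did not come up in time: ${e.getMessage}", e)
}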

core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala

Lines changed: 8 additions & 0 deletions
@@ -55,6 +55,14 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
     sc.env.blockManager.externalShuffleServiceEnabled should equal(true)
     sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient])

+    // On a slow machine, one slave may register hundreds of milliseconds ahead of the other one.
+    // If we don't wait for all slaves, it's possible that only one executor runs all jobs. Then
+    // all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will directly fetch
+    // local blocks from the local BlockManager and won't send requests to ExternalShuffleService.
+    // In this case, we won't receive FetchFailed, and that will make this test fail.
+    // Therefore, we should wait until all slaves are up.
+    sc.jobProgressListener.waitUntilExecutorsUp(2, 10000)
+
     val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _)

     rdd.count()

core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala

Lines changed: 1 addition & 9 deletions
@@ -17,11 +17,9 @@

 package org.apache.spark.broadcast

-import scala.concurrent.duration._
 import scala.util.Random

 import org.scalatest.Assertions
-import org.scalatest.concurrent.Eventually._

 import org.apache.spark._
 import org.apache.spark.io.SnappyCompressionCodec
@@ -312,13 +310,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext {
       val _sc =
         new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
       // Wait until all slaves are up
-      eventually(timeout(10.seconds), interval(10.milliseconds)) {
-        _sc.jobProgressListener.synchronized {
-          val numBlockManagers = _sc.jobProgressListener.blockManagerIds.size
-          assert(numBlockManagers == numSlaves + 1,
-            s"Expect ${numSlaves + 1} block managers, but was ${numBlockManagers}")
-        }
-      }
+      _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000)
       _sc
     } else {
       new SparkContext("local", "test", broadcastConf)

core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala

Lines changed: 7 additions & 3 deletions
@@ -17,12 +17,12 @@

 package org.apache.spark.scheduler

-import org.apache.spark.scheduler.cluster.ExecutorInfo
-import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
+import scala.collection.mutable

 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

-import scala.collection.mutable
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
+import org.apache.spark.scheduler.cluster.ExecutorInfo

 /**
  * Unit tests for SparkListener that require a local cluster.
@@ -41,6 +41,10 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext
     val listener = new SaveExecutorInfo
     sc.addSparkListener(listener)

+    // This test will check if the number of executors received by "SparkListener" is the same as
+    // the number of all executors, so we need to wait until all executors are up.
+    sc.jobProgressListener.waitUntilExecutorsUp(2, 10000)
+
     val rdd1 = sc.parallelize(1 to 100, 4)
     val rdd2 = rdd1.map(_.toString)
     rdd2.setName("Target RDD")
