
Commit 6846e40

Add test for flatMapWith()
1 parent 6c124a9 commit 6846e40

File tree: 2 files changed (+17 −3 lines)

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 6 additions & 3 deletions
@@ -742,8 +742,9 @@ abstract class RDD[T: ClassTag](
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => U): RDD[U] = withScope {
     val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
+      val a = cleanA(index)
       iter.map(t => cleanF(t, a))
     }, preservesPartitioning)
   }
@@ -758,8 +759,9 @@ abstract class RDD[T: ClassTag](
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => Seq[U]): RDD[U] = withScope {
     val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex((index, iter) => {
-      val a = constructA(index)
+      val a = cleanA(index)
       iter.flatMap(t => cleanF(t, a))
     }, preservesPartitioning)
   }
@@ -772,8 +774,9 @@ abstract class RDD[T: ClassTag](
   @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
   def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope {
     val cleanF = sc.clean(f)
+    val cleanA = sc.clean(constructA)
     mapPartitionsWithIndex { (index, iter) =>
-      val a = constructA(index)
+      val a = cleanA(index)
       iter.map(t => {cleanF(t, a); t})
     }
   }

core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala

Lines changed: 11 additions & 0 deletions
@@ -92,6 +92,7 @@ class ClosureCleanerSuite extends FunSuite {
     expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapWith(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) }
@@ -260,6 +261,16 @@ private object TestUserClosuresActuallyCleaned {
   def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = {
     rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
   }
+  def testFlatMapWith(rdd: RDD[Int]): Unit = {
+    import java.util.Random
+    val randoms = rdd.flatMapWith(
+      (index: Int) => new Random(index + 42))
+      {(t: Int, prng: Random) =>
+        val random = prng.nextDouble()
+        Seq(random * t, random * t * 10)}.
+      count()
+    rdd.mapPartitionsWithIndex { (_, it) => return; it }.count()
+  }
   def testZipPartitions2(rdd: RDD[Int]): Unit = {
     rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count()
   }
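
As a usage note, the new testFlatMapWith mirrors how the (deprecated) flatMapWith operator is called: constructA builds one java.util.Random per partition, and the element function expands each value into a small Seq. A hedged, standalone sketch of the same call follows; it assumes a local Spark 1.x build where flatMapWith still exists, and the master URL and application name are made-up example values.

import java.util.Random
import org.apache.spark.{SparkConf, SparkContext}

object FlatMapWithExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("flatMapWith-example")
    val sc = new SparkContext(conf)
    val rdd = sc.parallelize(1 to 10, 2)

    // One Random per partition, seeded by the partition index; every element
    // expands into two doubles, as in testFlatMapWith above.
    val randoms = rdd.flatMapWith((index: Int) => new Random(index + 42)) {
      (t: Int, prng: Random) =>
        val r = prng.nextDouble()
        Seq(r * t, r * t * 10)
    }

    println(randoms.count())   // 20: two output values per input element
    sc.stop()
  }
}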
