Skip to content

Commit c284d9f

Browse files
committed
fix a race condition in zipWithIndex
1 parent 825709a commit c284d9f

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
3939
private[spark]
4040
class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) {
4141

42-
override def getPartitions: Array[Partition] = {
42+
/** The start index of each partition. */
43+
@transient private val startIndices: Array[Long] = {
4344
val n = prev.partitions.size
44-
val startIndices: Array[Long] =
45-
if (n == 0) {
46-
Array[Long]()
47-
} else if (n == 1) {
48-
Array(0L)
49-
} else {
50-
prev.context.runJob(
51-
prev,
52-
Utils.getIteratorSize _,
53-
0 until n - 1, // do not need to count the last partition
54-
false
55-
).scanLeft(0L)(_ + _)
56-
}
45+
if (n == 0) {
46+
Array[Long]()
47+
} else if (n == 1) {
48+
Array(0L)
49+
} else {
50+
prev.context.runJob(
51+
prev,
52+
Utils.getIteratorSize _,
53+
0 until n - 1, // do not need to count the last partition
54+
allowLocal = false
55+
).scanLeft(0L)(_ + _)
56+
}
57+
}
58+
59+
override def getPartitions: Array[Partition] = {
5760
firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
5861
}
5962

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
739739
}
740740
}
741741

742+
test("zipWithIndex chained with other RDDs (SPARK-4433)") {
743+
val count = sc.parallelize(0 until 10, 2).zipWithIndex().repartition(4).count()
744+
assert(count === 10)
745+
}
746+
742747
test("zipWithUniqueId") {
743748
val n = 10
744749
val data = sc.parallelize(0 until n, 3)

0 commit comments

Comments (0)