diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index 3e07e38036a01..ad2e699ebe765 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -37,8 +37,8 @@ private[streaming]
 class WriteAheadLogBackedBlockRDDPartition(
     val index: Int,
     val blockId: BlockId,
-    val segment: WriteAheadLogFileSegment
-  ) extends Partition
+    val segment: WriteAheadLogFileSegment)
+  extends Partition
 
 
 /**
@@ -59,11 +59,11 @@ private[streaming]
 class WriteAheadLogBackedBlockRDD[T: ClassTag](
     @transient sc: SparkContext,
     @transient hadoopConfig: Configuration,
-    @transient override val blockIds: Array[BlockId],
+    @transient blockIds: Array[BlockId],
     @transient val segments: Array[WriteAheadLogFileSegment],
     val storeInBlockManager: Boolean,
-    val storageLevel: StorageLevel
-  ) extends BlockRDD[T](sc, blockIds) {
+    val storageLevel: StorageLevel)
+  extends BlockRDD[T](sc, blockIds) {
 
   require(
     blockIds.length == segments.length,
@@ -76,7 +76,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
   override def getPartitions: Array[Partition] = {
     assertValid()
     Array.tabulate(blockIds.size) { i =>
-      new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), segments(i)) }
+      new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), segments(i))
+    }
   }
 
   /**
@@ -117,8 +118,14 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
   override def getPreferredLocations(split: Partition): Seq[String] = {
     val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition]
     val blockLocations = getBlockIdLocations().get(partition.blockId)
-    lazy val segmentLocations = HdfsUtils.getFileSegmentLocations(
-      partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig)
-    blockLocations.orElse(segmentLocations).getOrElse(Seq.empty)
+    getBlockIdLocations().get(partition.blockId) match {
+      case Some(locations) => locations // BlockManager has the data
+      case None =>
+        // Block Manager does not have data, so find the HDFS data nodes which have the data
+        // If we can't find the HDFS locations, just return empty
+        val segment = partition.segment
+        HdfsUtils.getFileSegmentLocations(
+          segment.path, segment.offset, segment.length, hadoopConfig).getOrElse(Seq.empty)
+    }
   }
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index 23aebb53811d9..15b7c5b1fd98c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -83,10 +83,9 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll {
   private def testRDD(
       numPartitionssInBM: Int,
       numPartitionsInWAL: Int,
-      testStoreInBM: Boolean = false
-    ) {
+      testStoreInBM: Boolean = false) {
     val numBlocks = numPartitionssInBM + numPartitionsInWAL
-    val data = Seq.tabulate(numBlocks) { _ => Seq.fill(10) { scala.util.Random.nextString(50) } }
+    val data = Seq.fill(numBlocks, 10)(scala.util.Random.nextString(50))
 
     // Put the necessary blocks in the block manager
     val blockIds = Array.fill(numBlocks)(StreamBlockId(Random.nextInt(), Random.nextInt()))
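
Review note, not part of the patch: the `getPreferredLocations` rewrite trades the `Option` combinator chain (`orElse`/`getOrElse`) for an explicit `match`, which makes the fallback order easier to read: prefer the BlockManager's locations, otherwise ask HDFS where the write-ahead-log segment lives, otherwise return no preference. A minimal standalone sketch of that pattern follows; `blockManagerLocations` and `hdfsSegmentLocations` are hypothetical stand-ins for illustration, not Spark APIs:

```scala
// Hypothetical lookups standing in for the BlockManager and HDFS metadata queries.
def blockManagerLocations(blockId: String): Option[Seq[String]] =
  if (blockId == "cached-block") Some(Seq("executor-1")) else None

def hdfsSegmentLocations(path: String): Option[Seq[String]] =
  Some(Seq("datanode-a", "datanode-b"))

// Combinator style, as in the removed lines: lazy so HDFS is only queried on a miss.
def preferredLocationsOld(blockId: String, segmentPath: String): Seq[String] = {
  val blockLocations = blockManagerLocations(blockId)
  lazy val segmentLocations = hdfsSegmentLocations(segmentPath)
  blockLocations.orElse(segmentLocations).getOrElse(Seq.empty)
}

// Explicit match, as in the added lines: the same fallback, spelled out per case.
def preferredLocationsNew(blockId: String, segmentPath: String): Seq[String] =
  blockManagerLocations(blockId) match {
    case Some(locations) => locations // BlockManager has the data
    case None => hdfsSegmentLocations(segmentPath).getOrElse(Seq.empty)
  }
```

Both versions behave identically; the `match` form just spares the reader from unpacking nested `Option` combinators.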
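Review note, not part of the patch: the test change relies on the two-dimensional overload of `Seq.fill`. `Seq.fill(n, m)(elem)` builds `n` inner sequences of `m` elements, re-evaluating `elem` for every slot, so it produces the same shape of random test data as the previous `Seq.tabulate`/`Seq.fill` nesting. A quick sketch under that assumption (variable names are illustrative):

```scala
import scala.util.Random

val numBlocks = 4

// Old form: tabulate over block indices, ignoring the index, filling each block.
val nested: Seq[Seq[String]] =
  Seq.tabulate(numBlocks) { _ => Seq.fill(10)(Random.nextString(50)) }

// New form: the 2-D Seq.fill overload expresses the same thing in one call.
val flat: Seq[Seq[String]] = Seq.fill(numBlocks, 10)(Random.nextString(50))

// Same shape either way: numBlocks blocks of 10 random 50-char strings.
assert(flat.length == numBlocks && flat.forall(_.length == 10))
```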