
Commit 1e752f1

Author: Roman Pastukhov
Added unpersist method to Broadcast.
1 parent 9209287

File tree

8 files changed: +163 −38 lines changed


core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 6 additions & 1 deletion
@@ -613,8 +613,13 @@ class SparkContext(
    * Broadcast a read-only variable to the cluster, returning a
    * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
    * The variable will be sent to each cluster only once.
+   *
+   * If `registerBlocks` is true, workers will notify driver about blocks they create
+   * and these blocks will be dropped when `unpersist` method of the broadcast variable is called.
    */
-  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
+  def broadcast[T](value: T, registerBlocks: Boolean = false) = {
+    env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
+  }

   /**
    * Add a file to be downloaded with this Spark job on every node.
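
For orientation, here is a minimal driver-side sketch of how the new parameter composes with unpersist (the app name, data, and cluster string are illustrative, not from this commit):

    import org.apache.spark.SparkContext

    val sc = new SparkContext("local-cluster[2, 1, 512]", "unpersist-demo")
    // Opt in to block registration so unpersist can reach worker-side blocks later.
    val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"), registerBlocks = true)
    val tagged = sc.parallelize(1 to 2).map(x => lookup.value(x)).collect()
    lookup.unpersist()  // drops the cached broadcast blocks on the driver and workers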

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala

Lines changed: 11 additions & 2 deletions
@@ -53,6 +53,15 @@ import org.apache.spark._
 abstract class Broadcast[T](val id: Long) extends Serializable {
   def value: T

+  /**
+   * Removes all blocks of this broadcast from memory (and disk if removeSource is true).
+   *
+   * @param removeSource Whether to remove data from disk as well.
+   *                     Will cause errors if broadcast is accessed on workers afterwards
+   *                     (e.g. in case of RDD re-computation due to executor failure).
+   */
+  def unpersist(removeSource: Boolean = false)
+
   // We cannot have an abstract readObject here due to some weird issues with
   // readObject having to be 'private' in sub-classes.

@@ -91,8 +100,8 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging

   private val nextBroadcastId = new AtomicLong(0)

-  def newBroadcast[T](value_ : T, isLocal: Boolean) =
-    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+  def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) =
+    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks)

   def isDriver = _isDriver
 }
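
The scaladoc above implies two distinct modes; a hedged sketch of the difference, given a running SparkContext `sc` (names illustrative):

    val bc = sc.broadcast(List(1, 2, 3, 4), registerBlocks = true)
    sc.parallelize(1 to 4).map(_ + bc.value.sum).collect()

    bc.unpersist()                     // drop cached blocks only; the source data survives,
                                       // so a re-computed task can still fetch bc.value
    bc.unpersist(removeSource = true)  // also delete the source; any later worker access
                                       // (e.g. recomputation after executor failure) fails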

core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala

Lines changed: 1 addition & 1 deletion
@@ -27,6 +27,6 @@ import org.apache.spark.SparkConf
  */
 trait BroadcastFactory {
   def initialize(isDriver: Boolean, conf: SparkConf): Unit
-  def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+  def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T]
   def stop(): Unit
 }

core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala

Lines changed: 33 additions & 12 deletions
@@ -29,11 +29,20 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}

-private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
   extends Broadcast[T](id) with Logging with Serializable {

   def value = value_

+  def unpersist(removeSource: Boolean) {
+    SparkEnv.get.blockManager.master.removeBlock(blockId)
+    SparkEnv.get.blockManager.removeBlock(blockId)
+
+    if (removeSource) {
+      HttpBroadcast.cleanupById(id)
+    }
+  }
+
   def blockId = BroadcastBlockId(id)

   HttpBroadcast.synchronized {

@@ -54,7 +63,7 @@ private[spark] class HttpBroadcast[T]
     logInfo("Started reading broadcast variable " + id)
     val start = System.nanoTime
     value_ = HttpBroadcast.read[T](id)
-    SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+    SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
     val time = (System.nanoTime - start) / 1e9
     logInfo("Reading broadcast variable " + id + " took " + time + " s")
   }

@@ -69,8 +78,8 @@ private[spark] class HttpBroadcast[T]
 class HttpBroadcastFactory extends BroadcastFactory {
   def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) }

-  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
-    new HttpBroadcast[T](value_, isLocal, id)
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
+    new HttpBroadcast[T](value_, isLocal, id, registerBlocks)

   def stop() { HttpBroadcast.stop() }
 }

@@ -132,8 +141,10 @@ private object HttpBroadcast extends Logging {
     logInfo("Broadcast server started at " + serverUri)
   }

+  def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
+
   def write(id: Long, value: Any) {
-    val file = new File(broadcastDir, BroadcastBlockId(id).name)
+    val file = getFile(id)
     val out: OutputStream = {
       if (compress) {
         compressionCodec.compressedOutputStream(new FileOutputStream(file))

@@ -167,20 +178,30 @@ private object HttpBroadcast extends Logging {
     obj
   }

+  def deleteFile(fileName: String) {
+    try {
+      new File(fileName).delete()
+      logInfo("Deleted broadcast file '" + fileName + "'")
+    } catch {
+      case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e)
+    }
+  }
+
   def cleanup(cleanupTime: Long) {
     val iterator = files.internalMap.entrySet().iterator()
     while(iterator.hasNext) {
       val entry = iterator.next()
       val (file, time) = (entry.getKey, entry.getValue)
       if (time < cleanupTime) {
-        try {
-          iterator.remove()
-          new File(file.toString).delete()
-          logInfo("Deleted broadcast file '" + file + "'")
-        } catch {
-          case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
-        }
+        iterator.remove()
+        deleteFile(file)
       }
     }
   }
+
+  def cleanupById(id: Long) {
+    val file = getFile(id).getAbsolutePath
+    files.internalMap.remove(file)
+    deleteFile(file)
+  }
 }
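
A note on what removeSource means here: the HTTP source is a file under broadcastDir keyed by the block id's name, so cleanupById both forgets the path in the metadata-cleaner map and deletes the file. A sketch of the naming this relies on (the rendered "broadcast_7" string is an assumption about this era's BlockId scheme, not stated in the diff):

    import java.io.File
    import org.apache.spark.storage.BroadcastBlockId

    val name = BroadcastBlockId(7L).name     // presumably "broadcast_7"
    val file = new File(broadcastDir, name)  // broadcastDir is private to the HttpBroadcast object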

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 33 additions & 10 deletions
@@ -23,16 +23,36 @@ import scala.math
 import scala.util.Random

 import org.apache.spark._
-import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.storage.{BlockId, BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
 import org.apache.spark.util.Utils


-private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
 extends Broadcast[T](id) with Logging with Serializable {

   def value = value_

+  def unpersist(removeSource: Boolean) {
+    SparkEnv.get.blockManager.master.removeBlock(broadcastId)
+    SparkEnv.get.blockManager.removeBlock(broadcastId)
+
+    if (removeSource) {
+      for (pid <- pieceIds) {
+        SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid))
+      }
+      SparkEnv.get.blockManager.removeBlock(metaId)
+    } else {
+      for (pid <- pieceIds) {
+        SparkEnv.get.blockManager.dropFromMemory(pieceBlockId(pid))
+      }
+      SparkEnv.get.blockManager.dropFromMemory(metaId)
+    }
+  }
+
   def broadcastId = BroadcastBlockId(id)
+  private def metaId = BroadcastHelperBlockId(broadcastId, "meta")
+  private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+  private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList

   TorrentBroadcast.synchronized {
     SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)

@@ -55,7 +75,6 @@ extends Broadcast[T](id) with Logging with Serializable {
     hasBlocks = tInfo.totalBlocks

     // Store meta-info
-    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
     val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
     TorrentBroadcast.synchronized {
       SparkEnv.get.blockManager.putSingle(

@@ -64,7 +83,7 @@ extends Broadcast[T](id) with Logging with Serializable {

     // Store individual pieces
     for (i <- 0 until totalBlocks) {
-      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+      val pieceId = pieceBlockId(i)
       TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.putSingle(
           pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)

@@ -94,7 +113,7 @@ extends Broadcast[T](id) with Logging with Serializable {
         // This creates a tradeoff between memory usage and latency.
         // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
         SparkEnv.get.blockManager.putSingle(
-          broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+          broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)

         // Remove arrayOfBlocks from memory once value_ is on local cache
         resetWorkerVariables()

@@ -109,6 +128,11 @@ extends Broadcast[T](id) with Logging with Serializable {
   }

   private def resetWorkerVariables() {
+    if (arrayOfBlocks != null) {
+      for (pid <- pieceIds) {
+        SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid))
+      }
+    }
     arrayOfBlocks = null
     totalBytes = -1
     totalBlocks = -1

@@ -117,7 +141,6 @@ extends Broadcast[T](id) with Logging with Serializable {

   def receiveBroadcast(variableID: Long): Boolean = {
     // Receive meta-info
-    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
     var attemptId = 10
     while (attemptId > 0 && totalBlocks == -1) {
       TorrentBroadcast.synchronized {

@@ -140,9 +163,9 @@ extends Broadcast[T](id) with Logging with Serializable {
     }

     // Receive actual blocks
-    val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+    val recvOrder = new Random().shuffle(pieceIds)
     for (pid <- recvOrder) {
-      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+      val pieceId = pieceBlockId(pid)
       TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.getSingle(pieceId) match {
           case Some(x) =>

@@ -243,8 +266,8 @@ class TorrentBroadcastFactory extends BroadcastFactory {

   def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) }

-  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
-    new TorrentBroadcast[T](value_, isLocal, id)
+  def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
+    new TorrentBroadcast[T](value_, isLocal, id, registerBlocks)

   def stop() { TorrentBroadcast.stop() }
 }
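
The new private helpers make the block layout of a single torrent broadcast explicit; a sketch of the ids involved for a broadcast with three pieces (constructor usage mirrors the diff; any rendered name strings are assumptions):

    import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId}

    val broadcastId = BroadcastBlockId(0L)
    val metaId      = BroadcastHelperBlockId(broadcastId, "meta")
    val pieces      = (0 until 3).map(i => BroadcastHelperBlockId(broadcastId, "piece" + i))
    // unpersist(false) drops the value, meta, and piece blocks from memory (disk copies survive);
    // unpersist(true) removes them from the BlockManager outright.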

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 12 additions & 0 deletions
@@ -196,6 +196,11 @@ private[spark] class BlockManager(
     }
   }

+  /**
+   * For testing. Returns number of blocks BlockManager knows about that are in memory.
+   */
+  def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_))
+
   /**
    * Get storage level of local block. If no info exists for the block, then returns null.
    */

@@ -720,6 +725,13 @@ private[spark] class BlockManager(
   }

   /**
+   * Drop a block from memory, possibly putting it on disk if applicable.
+   */
+  def dropFromMemory(blockId: BlockId) {
+    memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId)
+  }
+
+  /**
    * Remove all blocks belonging to the given RDD.
    * @return The number of blocks removed.
    */

core/src/main/scala/org/apache/spark/storage/MemoryStore.scala

Lines changed: 19 additions & 12 deletions
@@ -182,6 +182,24 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     }
   }

+  /**
+   * Drop a block from memory, possibly putting it on disk if applicable.
+   */
+  def dropFromMemory(blockId: BlockId) {
+    val entry = entries.synchronized { entries.get(blockId) }
+    // This should never be null as only one thread should be dropping
+    // blocks and removing entries. However the check is still here for
+    // future safety.
+    if (entry != null) {
+      val data = if (entry.deserialized) {
+        Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
+      } else {
+        Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
+      }
+      blockManager.dropFromMemory(blockId, data)
+    }
+  }
+
   /**
    * Tries to free up a given amount of space to store a particular block, but can fail and return
    * false if either the block is bigger than our memory or it would require replacing another

@@ -227,18 +245,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     if (maxMemory - (currentMemory - selectedMemory) >= space) {
       logInfo(selectedBlocks.size + " blocks selected for dropping")
       for (blockId <- selectedBlocks) {
-        val entry = entries.synchronized { entries.get(blockId) }
-        // This should never be null as only one thread should be dropping
-        // blocks and removing entries. However the check is still here for
-        // future safety.
-        if (entry != null) {
-          val data = if (entry.deserialized) {
-            Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
-          } else {
-            Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
-          }
-          blockManager.dropFromMemory(blockId, data)
-        }
+        dropFromMemory(blockId)
       }
       return true
     } else {
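
Design note: after this refactor, eviction in ensureFreeSpace and the explicit unpersist path share one code path. A hedged sketch of the resulting call chain, in comment form:

    // TorrentBroadcast.unpersist(removeSource = false)
    //   -> BlockManager.dropFromMemory(blockId)
    //     -> MemoryStore.dropFromMemory(blockId)           // looks up the entry, duplicates buffers
    //       -> BlockManager.dropFromMemory(blockId, data)  // pre-existing path: spills to disk if
    //                                                      // the storage level allows, else drops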

core/src/test/scala/org/apache/spark/BroadcastSuite.scala

Lines changed: 48 additions & 0 deletions
@@ -18,6 +18,11 @@
 package org.apache.spark

 import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Millis, Span}
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+import org.scalatest.matchers.ShouldMatchers._

 class BroadcastSuite extends FunSuite with LocalSparkContext {

@@ -82,4 +87,47 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
   }

+  def blocksExist(sc: SparkContext, numSlaves: Int) = {
+    val rdd = sc.parallelize(1 to numSlaves, numSlaves)
+    val workerBlocks = rdd.mapPartitions(_ => {
+      val blocks = SparkEnv.get.blockManager.numberOfBlocksInMemory()
+      Seq(blocks).iterator
+    })
+    val totalKnown = workerBlocks.reduce(_ + _) + sc.env.blockManager.numberOfBlocksInMemory()
+
+    totalKnown > 0
+  }
+
+  def testUnpersist(bcFactory: String, removeSource: Boolean) {
+    test("Broadcast unpersist(" + removeSource + ") with " + bcFactory) {
+      val numSlaves = 2
+      System.setProperty("spark.broadcast.factory", bcFactory)
+      sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+      val list = List(1, 2, 3, 4)
+
+      assert(!blocksExist(sc, numSlaves))
+
+      val listBroadcast = sc.broadcast(list, true)
+      val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+      assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+
+      assert(blocksExist(sc, numSlaves))
+
+      listBroadcast.unpersist(removeSource)
+
+      eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+        blocksExist(sc, numSlaves) should be (false)
+      }
+
+      if (!removeSource) {
+        val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+        assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+      }
+    }
+  }
+
+  for (removeSource <- Seq(true, false)) {
+    testUnpersist("org.apache.spark.broadcast.HttpBroadcastFactory", removeSource)
+    testUnpersist("org.apache.spark.broadcast.TorrentBroadcastFactory", removeSource)
+  }
 }
