@@ -26,6 +26,7 @@ import scala.util.Random
 
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.ByteBufferInputStream
 import org.apache.spark.util.io.ByteArrayChunkOutputStream
@@ -46,14 +47,12 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream
  * This prevents the driver from being the bottleneck in sending out multiple copies of the
  * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
  *
+ * When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
+ *
  * @param obj object to broadcast
- * @param isLocal whether Spark is running in local mode (single JVM process).
  * @param id A unique identifier for the broadcast variable.
  */
-private[spark] class TorrentBroadcast[T: ClassTag](
-    obj: T,
-    @transient private val isLocal: Boolean,
-    id: Long)
+private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
   /**
@@ -62,6 +61,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](
    * blocks from the driver and/or other executors.
    */
   @transient private var _value: T = obj
+  /** The compression codec to use, or None if compression is disabled */
+  @transient private var compressionCodec: Option[CompressionCodec] = _
+  /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
+  @transient private var blockSize: Int = _
+
+  private def setConf(conf: SparkConf) {
+    compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
+      Some(CompressionCodec.createCodec(conf))
+    } else {
+      None
+    }
+    blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
+  }
+  setConf(SparkEnv.get.conf)
 
   private val broadcastId = BroadcastBlockId(id)
 
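
For context, the two settings read by the new per-instance setConf are the standard broadcast knobs. A minimal sketch of how a user would tune them (the values below are illustrative, not the defaults):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      // Piece size in KB; the code above multiplies by 1024. Default: 4096 (4 MB).
      .set("spark.broadcast.blockSize", "8192")
      // Whether to compress pieces with the configured codec. Default: true.
      .set("spark.broadcast.compress", "false")
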
@@ -76,23 +89,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](
    * @return number of blocks this broadcast variable is divided into
    */
   private def writeBlocks(): Int = {
-    // For local mode, just put the object in the BlockManager so we can find it later.
-    SparkEnv.get.blockManager.putSingle(
-      broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
-    if (!isLocal) {
-      val blocks = TorrentBroadcast.blockifyObject(_value)
-      blocks.zipWithIndex.foreach { case (block, i) =>
-        SparkEnv.get.blockManager.putBytes(
-          BroadcastBlockId(id, "piece" + i),
-          block,
-          StorageLevel.MEMORY_AND_DISK_SER,
-          tellMaster = true)
-      }
-      blocks.length
-    } else {
-      0
+    // Store a copy of the broadcast variable in the driver so that tasks run on the driver
+    // do not create a duplicate copy of the broadcast variable's value.
+    SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK,
+      tellMaster = false)
+    val blocks =
+      TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec)
+    blocks.zipWithIndex.foreach { case (block, i) =>
+      SparkEnv.get.blockManager.putBytes(
+        BroadcastBlockId(id, "piece" + i),
+        block,
+        StorageLevel.MEMORY_AND_DISK_SER,
+        tellMaster = true)
     }
+    blocks.length
   }
 
   /** Fetch torrent blocks from the driver and/or other executors. */
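
For reference, the pieces written by writeBlocks above are keyed as BroadcastBlockId(id, "piece" + i). A small sketch of the resulting keys (the broadcast id 7 is made up, and the broadcast_<id>_<field> name format is assumed from BroadcastBlockId's usual scheme):

    import org.apache.spark.storage.BroadcastBlockId

    val pieceIds = (0 until 3).map(i => BroadcastBlockId(7L, "piece" + i))
    pieceIds.map(_.name)  // expected: broadcast_7_piece0, broadcast_7_piece1, broadcast_7_piece2
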
@@ -104,29 +114,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](
 
     for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
       val pieceId = BroadcastBlockId(id, "piece" + pid)
-
-      // First try getLocalBytes because there is a chance that previous attempts to fetch the
+      logDebug(s"Reading piece $pieceId of $broadcastId")
+      // First try getLocalBytes because there is a chance that previous attempts to fetch the
       // broadcast blocks have already fetched some of the blocks. In that case, some blocks
       // would be available locally (on this executor).
-      var blockOpt = bm.getLocalBytes(pieceId)
-      if (!blockOpt.isDefined) {
-        blockOpt = bm.getRemoteBytes(pieceId)
-        blockOpt match {
-          case Some(block) =>
-            // If we found the block from remote executors/driver's BlockManager, put the block
-            // in this executor's BlockManager.
-            SparkEnv.get.blockManager.putBytes(
-              pieceId,
-              block,
-              StorageLevel.MEMORY_AND_DISK_SER,
-              tellMaster = true)
-
-          case None =>
-            throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
-        }
+      def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
+      def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
+        // If we found the block from remote executors/driver's BlockManager, put the block
+        // in this executor's BlockManager.
+        SparkEnv.get.blockManager.putBytes(
+          pieceId,
+          block,
+          StorageLevel.MEMORY_AND_DISK_SER,
+          tellMaster = true)
+        block
       }
-      // If we get here, the option is defined.
-      blocks(pid) = blockOpt.get
+      val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
+        throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
+      blocks(pid) = block
     }
     blocks
   }
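
A note on the getLocal.orElse(getRemote) rewrite above: since both are defs and Option.orElse takes its argument by name, the remote fetch only runs when the local lookup misses, preserving the old control flow without a mutable var. A self-contained sketch of the same pattern (names hypothetical):

    def local: Option[String] = { println("local lookup"); None }           // simulate a miss
    def remote: Option[String] = { println("remote fetch"); Some("block") } // simulate a hit

    // Prints both lines only because local returned None; on a local hit, remote never runs.
    val block = local.orElse(remote).getOrElse(sys.error("failed to fetch piece"))
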
@@ -156,6 +161,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     TorrentBroadcast.synchronized {
+      setConf(SparkEnv.get.conf)
       SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
         case Some(x) =>
           _value = x.asInstanceOf[T]
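
The setConf call added to readObject matters because @transient fields come back null (or zero) after Java deserialization, so each deserialized TorrentBroadcast must re-derive compressionCodec and blockSize from the conf. A minimal standalone sketch of the pattern (class and field names hypothetical):

    import java.io.ObjectInputStream

    class Cached extends Serializable {
      @transient private var cache: String = "expensive"  // not serialized

      private def readObject(in: ObjectInputStream): Unit = {
        in.defaultReadObject()
        cache = "expensive"  // rebuild transient state, as readObject above does via setConf
      }
    }
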
@@ -167,7 +173,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](
           val time = (System.nanoTime() - start) / 1e9
           logInfo("Reading broadcast variable " + id + " took " + time + " s")
 
-          _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+          _value =
+            TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec)
           // Store the merged copy in BlockManager so other tasks on this executor don't
           // need to re-fetch it.
           SparkEnv.get.blockManager.putSingle(
@@ -179,43 +186,29 @@ private[spark] class TorrentBroadcast[T: ClassTag](
 
 
 private object TorrentBroadcast extends Logging {
-  /** Size of each block. Default value is 4MB. */
-  private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
-  private var initialized = false
-  private var conf: SparkConf = null
-  private var compress: Boolean = false
-  private var compressionCodec: CompressionCodec = null
-
-  def initialize(_isDriver: Boolean, conf: SparkConf) {
-    TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
-    synchronized {
-      if (!initialized) {
-        compress = conf.getBoolean("spark.broadcast.compress", true)
-        compressionCodec = CompressionCodec.createCodec(conf)
-        initialized = true
-      }
-    }
-  }
 
-  def stop() {
-    initialized = false
-  }
-
-  def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
-    val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE)
-    val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
-    val ser = SparkEnv.get.serializer.newInstance()
+  def blockifyObject[T: ClassTag](
+      obj: T,
+      blockSize: Int,
+      serializer: Serializer,
+      compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
+    val bos = new ByteArrayChunkOutputStream(blockSize)
+    val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
+    val ser = serializer.newInstance()
     val serOut = ser.serializeStream(out)
     serOut.writeObject[T](obj).close()
     bos.toArrays.map(ByteBuffer.wrap)
   }
 
-  def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
+  def unBlockifyObject[T: ClassTag](
+      blocks: Array[ByteBuffer],
+      serializer: Serializer,
+      compressionCodec: Option[CompressionCodec]): T = {
+    require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
     val is = new SequenceInputStream(
       asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
-    val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
-
-    val ser = SparkEnv.get.serializer.newInstance()
+    val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
+    val ser = serializer.newInstance()
     val serIn = ser.deserializeStream(in)
     val obj = serIn.readObject[T]()
     serIn.close()
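
With the serializer and codec passed in explicitly, blockifyObject and unBlockifyObject are now pure functions of their arguments, so the pair can be exercised as a simple round trip. A minimal sketch (assuming JavaSerializer and no compression; both methods are private[spark], so code like this would live inside Spark's own tests):

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.JavaSerializer

    val ser = new JavaSerializer(new SparkConf())
    val blocks = TorrentBroadcast.blockifyObject(Seq(1, 2, 3), 4 * 1024 * 1024, ser, None)
    val restored = TorrentBroadcast.unBlockifyObject[Seq[Int]](blocks, ser, None)
    assert(restored == Seq(1, 2, 3))
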
@@ -227,6 +220,7 @@ private object TorrentBroadcast extends Logging {
    * If removeFromDriver is true, also remove these persisted blocks on the driver.
    */
   def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = {
+    logDebug(s"Unpersisting TorrentBroadcast $id")
     SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
   }
 }