// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private[spark] object Accumulators extends Logging {
  /**
   * Global registry of the original accumulator objects created on the driver.
   * Only weak references are held here, so accumulators become eligible for
   * garbage collection once the RDDs and user code referencing them have been
   * cleaned up — important for long-running jobs.
   */
  val originals = Map[Long, WeakReference[Accumulable[_, _]]]()

  /**
   * Per-thread (i.e. per-task) working copies of accumulators. Used to gather
   * the set of accumulator updates that are sent back to the driver when a
   * task completes; the executor empties it afterwards via
   * `Accumulators.clear()` (see Executor.scala).
   */
  private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
    override protected def initialValue() = Map[Long, Accumulable[_, _]]()
  }

  // Monotonically increasing id counter backing newId(); access is guarded
  // by `synchronized` on this object.
  private var lastId: Long = 0
293302 def newId (): Long = synchronized {
294303 lastId += 1
@@ -297,16 +306,16 @@ private[spark] object Accumulators {
297306
298307 def register (a : Accumulable [_, _], original : Boolean ): Unit = synchronized {
299308 if (original) {
300- originals(a.id) = new WeakAcc (a)
309+ originals(a.id) = new WeakReference [ Accumulable [_, _]] (a)
301310 } else {
302- localAccums.get()(a.id) = new WeakAcc (a)
311+ localAccums.get()(a.id) = a
303312 }
304313 }
305314
306315 // Clear the local (non-original) accumulators for the current thread
307316 def clear () {
308317 synchronized {
309- localAccums.get.clear
318+ localAccums.get.clear()
310319 }
311320 }
312321
@@ -320,12 +329,7 @@ private[spark] object Accumulators {
320329 def values : Map [Long , Any ] = synchronized {
321330 val ret = Map [Long , Any ]()
322331 for ((id, accum) <- localAccums.get) {
323- // Since we are now storing weak references, we must check whether the underlying data
324- // is valid.
325- ret(id) = accum.get match {
326- case Some (values) => values.localValue
327- case None => None
328- }
332+ ret(id) = accum.localValue
329333 }
330334 return ret
331335 }
@@ -341,6 +345,8 @@ private[spark] object Accumulators {
341345 case None =>
342346 throw new IllegalAccessError (" Attempted to access garbage collected Accumulator." )
343347 }
348+ } else {
349+ logWarning(s " Ignoring accumulator update for unknown accumulator id $id" )
344350 }
345351 }
346352 }
0 commit comments