@@ -24,6 +24,7 @@ import java.lang.ref.WeakReference
2424import  java .util .concurrent .ConcurrentHashMap 
2525
2626import  org .apache .spark .Logging 
27+ import  java .util .concurrent .atomic .AtomicInteger 
2728
2829private [util] case  class  TimeStampedWeakValue [T ](timestamp : Long , weakValue : WeakReference [T ]) {
2930  def  this (timestamp : Long , value : T ) =  this (timestamp, new  WeakReference [T ](value))
@@ -44,6 +45,12 @@ private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: Wea
4445private [spark] class  TimeStampedWeakValueHashMap [A , B ]()
4546  extends  WrappedJavaHashMap [A , B , A , TimeStampedWeakValue [B ]] with  Logging  {
4647
48+   /**  Number of inserts after which keys whose weak ref values are null will be cleaned */  
49+   private  val  CLEANUP_INTERVAL  =  1000 
50+ 
51+   /**  Counter for counting the number of inserts */  
52+   private  val  insertCounts  =  new  AtomicInteger (0 )
53+ 
4754  protected [util] val  internalJavaMap :  util.Map [A , TimeStampedWeakValue [B ]] =  {
4855    new  ConcurrentHashMap [A , TimeStampedWeakValue [B ]]()
4956  }
@@ -52,11 +59,21 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
5259    new  TimeStampedWeakValueHashMap [K1 , V1 ]()
5360  }
5461
62+   override  def  += (kv : (A , B )):  this .type  =  {
63+     //  Cleanup null value at certain intervals
64+     if  (insertCounts.incrementAndGet() %  CLEANUP_INTERVAL  ==  0 ) {
65+       cleanNullValues()
66+     }
67+     super .+= (kv)
68+   }
69+ 
5570  override  def  get (key : A ):  Option [B ] =  {
5671    Option (internalJavaMap.get(key)) match  {
5772      case  Some (weakValue) => 
5873        val  value  =  weakValue.weakValue.get
59-         if  (value ==  null ) cleanupKey(key)
74+         if  (value ==  null ) {
75+           internalJavaMap.remove(key)
76+         }
6077        Option (value)
6178      case  None  => 
6279        None 
@@ -72,16 +89,10 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
7289  }
7390
7491  override  def  iterator :  Iterator [(A , B )] =  {
75-     val  jIterator  =  internalJavaMap.entrySet().iterator()
76-     JavaConversions .asScalaIterator(jIterator).flatMap(kv =>  {
77-       val  key  =  kv.getKey
78-       val  value  =  kv.getValue.weakValue.get
79-       if  (value ==  null ) {
80-         cleanupKey(key)
81-         Seq .empty
82-       } else  {
83-         Seq ((key, value))
84-       }
92+     val  iterator  =  internalJavaMap.entrySet().iterator()
93+     JavaConversions .asScalaIterator(iterator).flatMap(kv =>  {
94+       val  (key, value) =  (kv.getKey, kv.getValue.weakValue.get)
95+       if  (value !=  null ) Seq ((key, value)) else  Seq .empty
8596    })
8697  }
8798
@@ -104,8 +115,18 @@ private[spark] class TimeStampedWeakValueHashMap[A, B]()
104115    }
105116  }
106117
107-   private  def  cleanupKey (key : A ) {
108-     //  TODO: Consider cleaning up keys to empty weak ref values automatically in future.
118+   /**  
119+    * Removes keys whose weak referenced values have become null. 
120+    */  
121+   private  def  cleanNullValues () {
122+     val  iterator  =  internalJavaMap.entrySet().iterator()
123+     while  (iterator.hasNext) {
124+       val  entry  =  iterator.next()
125+       if  (entry.getValue.weakValue.get ==  null ) {
126+         logDebug(" Removing key " +  entry.getKey)
127+         iterator.remove()
128+       }
129+     }
109130  }
110131
111132  private  def  currentTime  =  System .currentTimeMillis()
0 commit comments