@@ -403,17 +403,8 @@ class Word2Vec extends Serializable with Logging {
403403    }
404404    newSentences.unpersist()
405405
406-     val  word2VecMap  =  mutable.HashMap .empty[String , Array [Float ]]
407-     var  i  =  0 
408-     while  (i <  vocabSize) {
409-       val  word  =  bcVocab.value(i).word
410-       val  vector  =  new  Array [Float ](vectorSize)
411-       Array .copy(syn0Global, i *  vectorSize, vector, 0 , vectorSize)
412-       word2VecMap +=  word ->  vector
413-       i +=  1 
414-     }
415- 
416-     new  Word2VecModel (word2VecMap.toMap)
406+     val  wordArray  =  vocab.map(_.word)
407+     new  Word2VecModel (wordArray.zipWithIndex.toMap, syn0Global)
417408  }
418409
419410  /**  
@@ -429,38 +420,42 @@ class Word2Vec extends Serializable with Logging {
429420/** 
430421 * :: Experimental :: 
431422 * Word2Vec model 
423+  * @param  wordIndex  maps each word to an index, which can retrieve the corresponding 
424+  *                  vector from wordVectors 
425+  * @param  wordVectors  array of length numWords * vectorSize, vector corresponding 
426+  *                    to the word mapped with index i can be retrieved by the slice 
427+  *                    (i * vectorSize, i * vectorSize + vectorSize) 
432428 */  
433429@ Experimental 
434- class  Word2VecModel  private [spark] (
435-     model : Map [String , Array [Float ]]) extends  Serializable  with  Saveable  {
436- 
437-   //  wordList: Ordered list of words obtained from model.
438-   private  val  wordList :  Array [String ] =  model.keys.toArray
439- 
440-   //  wordIndex: Maps each word to an index, which can retrieve the corresponding
441-   //             vector from wordVectors (see below).
442-   private  val  wordIndex :  Map [String , Int ] =  wordList.zip(0  until model.size).toMap
430+ class  Word2VecModel  private [mllib] (
431+     private  val  wordIndex :  Map [String , Int ],
432+     private  val  wordVectors :  Array [Float ]) extends  Serializable  with  Saveable  {
443433
444-   //  vectorSize: Dimension of each word's vector.
445-   private  val  vectorSize  =  model.head._2.size
446434  private  val  numWords  =  wordIndex.size
435+   //  vectorSize: Dimension of each word's vector.
436+   private  val  vectorSize  =  wordVectors.length /  numWords
437+ 
438+   //  wordList: Ordered list of words obtained from wordIndex.
439+   private  val  wordList :  Array [String ] =  {
440+     val  (wl, _) =  wordIndex.toSeq.sortBy(_._2).unzip
441+     wl.toArray
442+   }
447443
448-   //  wordVectors: Array of length numWords * vectorSize, vector corresponding to the word
449-   //               mapped with index i can be retrieved by the slice
450-   //               (ind * vectorSize, ind * vectorSize + vectorSize)
451444  //  wordVecNorms: Array of length numWords, each value being the Euclidean norm
452445  //                of the wordVector.
453-   private  val  (wordVectors : Array [Float ], wordVecNorms : Array [Double ]) =  {
454-     val  wordVectors  =  new  Array [Float ](vectorSize *  numWords)
446+   private  val  wordVecNorms :  Array [Double ] =  {
455447    val  wordVecNorms  =  new  Array [Double ](numWords)
456448    var  i  =  0 
457449    while  (i <  numWords) {
458-       val  vec  =  model.get(wordList(i)).get
459-       Array .copy(vec, 0 , wordVectors, i *  vectorSize, vectorSize)
450+       val  vec  =  wordVectors.slice(i *  vectorSize, i *  vectorSize +  vectorSize)
460451      wordVecNorms(i) =  blas.snrm2(vectorSize, vec, 1 )
461452      i +=  1 
462453    }
463-     (wordVectors, wordVecNorms)
454+     wordVecNorms
455+   }
456+ 
457+   def  this (model : Map [String , Array [Float ]]) =  {
458+     this (Word2VecModel .buildWordIndex(model), Word2VecModel .buildWordVectors(model))
464459  }
465460
466461  private  def  cosineSimilarity (v1 : Array [Float ], v2 : Array [Float ]):  Double  =  {
@@ -484,8 +479,9 @@ class Word2VecModel private[spark] (
484479   * @return  vector representation of word 
485480   */  
486481  def  transform (word : String ):  Vector  =  {
487-     model.get(word) match  {
488-       case  Some (vec) => 
482+     wordIndex.get(word) match  {
483+       case  Some (ind) => 
484+         val  vec  =  wordVectors.slice(ind *  vectorSize, ind *  vectorSize +  vectorSize)
489485        Vectors .dense(vec.map(_.toDouble))
490486      case  None  => 
491487        throw  new  IllegalStateException (s " $word not in vocabulary " )
@@ -511,7 +507,7 @@ class Word2VecModel private[spark] (
511507   */  
512508  def  findSynonyms (vector : Vector , num : Int ):  Array [(String , Double )] =  {
513509    require(num > 0, "Number of similar words should > 0")
514- 
510+      //  TODO: optimize top-k 
515511    val  fVector  =  vector.toArray.map(_.toFloat)
516512    val  cosineVec  =  Array .fill[Float ](numWords)(0 )
517513    val  alpha :  Float  =  1 
@@ -521,13 +517,13 @@ class Word2VecModel private[spark] (
521517      "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
522518
523519    //  Need not divide with the norm of the given vector since it is constant.
524-     val  updatedCosines  =  new   Array [ Double ](numWords )
520+     val  cosVec  =  cosineVec.map(_.toDouble )
525521    var  ind  =  0 
526522    while  (ind <  numWords) {
527-       updatedCosines (ind) =  cosineVec(ind)  /  wordVecNorms(ind)
523+       cosVec (ind) /=  wordVecNorms(ind)
528524      ind +=  1 
529525    }
530-     wordList.zip(updatedCosines )
526+     wordList.zip(cosVec )
531527      .toSeq
532528      .sortBy(-  _._2)
533529      .take(num +  1 )
@@ -548,6 +544,23 @@ class Word2VecModel private[spark] (
548544@ Experimental 
549545object  Word2VecModel  extends  Loader [Word2VecModel ] {
550546
547+   private  def  buildWordIndex (model : Map [String , Array [Float ]]):  Map [String , Int ] =  {
548+     model.keys.zipWithIndex.toMap
549+   }
550+ 
551+   private  def  buildWordVectors (model : Map [String , Array [Float ]]):  Array [Float ] =  {
552+     require(model.nonEmpty, "Word2VecMap should be non-empty")
553+     val  (vectorSize, numWords) =  (model.head._2.size, model.size)
554+     val  wordList  =  model.keys.toArray
555+     val  wordVectors  =  new  Array [Float ](vectorSize *  numWords)
556+     var  i  =  0 
557+     while  (i <  numWords) {
558+       Array .copy(model(wordList(i)), 0 , wordVectors, i *  vectorSize, vectorSize)
559+       i +=  1 
560+     }
561+     wordVectors
562+   }
563+ 
551564  private  object  SaveLoadV1_0  {
552565
553566    val  formatVersionV1_0  =  " 1.0" 
0 commit comments