@@ -429,15 +429,20 @@ class Word2Vec extends Serializable with Logging {
429429 */
430430@Experimental
431431class Word2VecModel private[mllib] (
432-    private val model: Map[String, Array[Float]]) extends Serializable with Saveable {
433-
434- private val numDim = model.head._2.size
435- private val numWords = model.size
436- private val flatVec = model.toSeq.flatMap { case (w, v) =>
437- v.map(_.toDouble)}.toArray
438-    private val wordVecMat = new DenseMatrix(numWords, numDim, flatVec, isTransposed = true)
439-    private val wordVecNorms = model.map { case (word, vec) =>
440-      blas.snrm2(numDim, vec, 1)}.toArray
432+    model: Map[String, Array[Float]]) extends Serializable with Saveable {
433+
434+ val indexedModel = model.keys.zip(0 until model.size).toMap
435+
436+ private val (wordVectors, wordVecNorms) = {
437+ val numDim = model.head._2.size
438+ val numWords = indexedModel.size
439+ val flatVec = model.toSeq.flatMap { case (w, v) =>
440+ v.map(_.toDouble)}.toArray
441+      val wordVectors = new DenseMatrix(numWords, numDim, flatVec, isTransposed = true)
442+      val wordVecNorms = model.map { case (word, vec) =>
443+        blas.snrm2(numDim, vec, 1)}.toArray
444+ (wordVectors, wordVecNorms)
445+ }
441446
442447  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
443448    require(v1.length == v2.length, "Vectors should have the same length")
@@ -451,7 +456,7 @@ class Word2VecModel private[mllib] (
451456  override protected def formatVersion = "1.0"
452457
453458  def save(sc: SparkContext, path: String): Unit = {
454-      Word2VecModel.SaveLoadV1_0.save(sc, path, model)
459+    Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)
455460 }
456461
457462 /**
@@ -488,16 +493,15 @@ class Word2VecModel private[mllib] (
488493  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
489494    require(num > 0, "Number of similar words should > 0")
490495
491- val fVector = vector.toArray
492-
493-    val cosineVec = new DenseVector(Array.fill[Double](numWords)(0))
494-    BLAS.gemv(1.0, wordVecMat, vector.asInstanceOf[DenseVector], 0.0, cosineVec)
496+ val numWords = wordVectors.numRows
497+    val cosineVec = Vectors.zeros(numWords).asInstanceOf[DenseVector]
498+    BLAS.gemv(1.0, wordVectors, vector.asInstanceOf[DenseVector], 0.0, cosineVec)
495499
496500 // Need not divide with the norm of the given vector since it is constant.
497-    val updatedCosines = model.zipWithIndex.map { case (vec, ind) =>
501+    val updatedCosines = indexedModel.map { case (_, ind) =>
498502      cosineVec(ind) / wordVecNorms(ind) }
499503
500-    model.keys.zip(updatedCosines)
504+    indexedModel.keys.zip(updatedCosines)
501505 .toSeq
502506 .sortBy(- _._2)
503507      .take(num + 1)
@@ -509,7 +513,11 @@ class Word2VecModel private[mllib] (
509513 * Returns a map of words to their vector representations.
510514 */
511515  def getVectors: Map[String, Array[Float]] = {
512- model
516+ val numDim = wordVectors.numCols
517+ indexedModel.map { case (word, ind) =>
518+ val startInd = numDim * ind
519+ val endInd = startInd + numDim
520+ (word, wordVectors.values.slice(startInd, endInd).map(_.toFloat)) }
513521 }
514522}
515523
0 commit comments