Skip to content

Commit 64575b0

Browse files
committed
Save indexedmap and a wordvecmat instead of matrix
1 parent fbe0108 commit 64575b0

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -429,15 +429,20 @@ class Word2Vec extends Serializable with Logging {
429429
*/
430430
@Experimental
431431
class Word2VecModel private[mllib] (
432-
private val model: Map[String, Array[Float]]) extends Serializable with Saveable {
433-
434-
private val numDim = model.head._2.size
435-
private val numWords = model.size
436-
private val flatVec = model.toSeq.flatMap { case(w, v) =>
437-
v.map(_.toDouble)}.toArray
438-
private val wordVecMat = new DenseMatrix(numWords, numDim, flatVec, isTransposed=true)
439-
private val wordVecNorms = model.map { case (word, vec) =>
440-
blas.snrm2(numDim, vec, 1)}.toArray
432+
model: Map[String, Array[Float]]) extends Serializable with Saveable {
433+
434+
val indexedModel = model.keys.zip(0 until model.size).toMap
435+
436+
private val (wordVectors, wordVecNorms) = {
437+
val numDim = model.head._2.size
438+
val numWords = indexedModel.size
439+
val flatVec = model.toSeq.flatMap { case(w, v) =>
440+
v.map(_.toDouble)}.toArray
441+
val wordVectors = new DenseMatrix(numWords, numDim, flatVec, isTransposed=true)
442+
val wordVecNorms = model.map { case (word, vec) =>
443+
blas.snrm2(numDim, vec, 1)}.toArray
444+
(wordVectors, wordVecNorms)
445+
}
441446

442447
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
443448
require(v1.length == v2.length, "Vectors should have the same length")
@@ -451,7 +456,7 @@ class Word2VecModel private[mllib] (
451456
override protected def formatVersion = "1.0"
452457

453458
def save(sc: SparkContext, path: String): Unit = {
454-
Word2VecModel.SaveLoadV1_0.save(sc, path, model)
459+
Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors)
455460
}
456461

457462
/**
@@ -488,16 +493,15 @@ class Word2VecModel private[mllib] (
488493
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
489494
require(num > 0, "Number of similar words should > 0")
490495

491-
val fVector = vector.toArray
492-
493-
val cosineVec = new DenseVector(Array.fill[Double](numWords)(0))
494-
BLAS.gemv(1.0, wordVecMat, vector.asInstanceOf[DenseVector], 0.0, cosineVec)
496+
val numWords = wordVectors.numRows
497+
val cosineVec = Vectors.zeros(numWords).asInstanceOf[DenseVector]
498+
BLAS.gemv(1.0, wordVectors, vector.asInstanceOf[DenseVector], 0.0, cosineVec)
495499

496500
// No need to divide by the norm of the given vector: it is the same constant for every word, so it does not change the ranking.
497-
val updatedCosines = model.zipWithIndex.map { case (vec, ind) =>
501+
val updatedCosines = indexedModel.map { case (_, ind) =>
498502
cosineVec(ind) / wordVecNorms(ind) }
499503

500-
model.keys.zip(updatedCosines)
504+
indexedModel.keys.zip(updatedCosines)
501505
.toSeq
502506
.sortBy(- _._2)
503507
.take(num + 1)
@@ -509,7 +513,11 @@ class Word2VecModel private[mllib] (
509513
* Returns a map of words to their vector representations.
510514
*/
511515
def getVectors: Map[String, Array[Float]] = {
512-
model
516+
val numDim = wordVectors.numCols
517+
indexedModel.map { case (word, ind) =>
518+
val startInd = numDim * ind
519+
val endInd = startInd + numDim
520+
(word, wordVectors.values.slice(startInd, endInd).map(_.toFloat)) }
513521
}
514522
}
515523

0 commit comments

Comments
 (0)