
Commit a400ab5

MechCoder authored and jkbradley committed

[SPARK-7045] [MLLIB] Avoid intermediate representation when creating model

Word2Vec used to convert from an Array[Float] representation to a
Map[String, Array[Float]] and then back to an Array[Float] through
Word2VecModel. This change avoids that conversion while still supporting the
older method of supplying a Map.

Author: MechCoder <[email protected]>

Closes #5748 from MechCoder/spark-7045 and squashes the following commits:

e308913 [MechCoder] move docs
5703116 [MechCoder] minor
fa04313 [MechCoder] style fixes
b1d61c4 [MechCoder] better errors and tests
3b32c8c [MechCoder] [SPARK-7045] Avoid intermediate representation when creating model

1 parent 64135cb commit a400ab5
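
For readers skimming the change: the model now stores one flat Array[Float] of all vectors plus a word-to-index map, instead of a per-word Map of arrays. A minimal, self-contained sketch of that layout, with toy data that is not part of this commit:

    // Sketch only: illustrates the flat-array layout the commit switches to.
    object FlatLayoutSketch {
      def main(args: Array[String]): Unit = {
        val vectorSize = 3
        val words = Array("a", "b", "c")
        // Vector for word i occupies indices [i * vectorSize, (i + 1) * vectorSize).
        val wordVectors = Array[Float](0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f)
        val wordIndex: Map[String, Int] = words.zipWithIndex.toMap

        // Lookup without materializing a Map[String, Array[Float]].
        val i = wordIndex("b")
        val vecB = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
        println(vecB.mkString(", "))  // 0.4, 0.5, 0.6
      }
    }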

File tree: 2 files changed (+55, -36 lines)


mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 49 additions & 36 deletions
@@ -403,17 +403,8 @@ class Word2Vec extends Serializable with Logging {
     }
     newSentences.unpersist()

-    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
-    var i = 0
-    while (i < vocabSize) {
-      val word = bcVocab.value(i).word
-      val vector = new Array[Float](vectorSize)
-      Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
-      word2VecMap += word -> vector
-      i += 1
-    }
-
-    new Word2VecModel(word2VecMap.toMap)
+    val wordArray = vocab.map(_.word)
+    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
   }

   /**
@@ -429,38 +420,42 @@ class Word2Vec extends Serializable with Logging {
 /**
  * :: Experimental ::
  * Word2Vec model
+ * @param wordIndex maps each word to an index, which can retrieve the corresponding
+ *                  vector from wordVectors
+ * @param wordVectors array of length numWords * vectorSize, vector corresponding
+ *                    to the word mapped with index i can be retrieved by the slice
+ *                    (i * vectorSize, i * vectorSize + vectorSize)
  */
 @Experimental
-class Word2VecModel private[spark] (
-    model: Map[String, Array[Float]]) extends Serializable with Saveable {
-
-  // wordList: Ordered list of words obtained from model.
-  private val wordList: Array[String] = model.keys.toArray
-
-  // wordIndex: Maps each word to an index, which can retrieve the corresponding
-  // vector from wordVectors (see below).
-  private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
+class Word2VecModel private[mllib] (
+    private val wordIndex: Map[String, Int],
+    private val wordVectors: Array[Float]) extends Serializable with Saveable {

-  // vectorSize: Dimension of each word's vector.
-  private val vectorSize = model.head._2.size
   private val numWords = wordIndex.size
+  // vectorSize: Dimension of each word's vector.
+  private val vectorSize = wordVectors.length / numWords
+
+  // wordList: Ordered list of words obtained from wordIndex.
+  private val wordList: Array[String] = {
+    val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
+    wl.toArray
+  }

-  // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word
-  // mapped with index i can be retrieved by the slice
-  // (ind * vectorSize, ind * vectorSize + vectorSize)
   // wordVecNorms: Array of length numWords, each value being the Euclidean norm
   //               of the wordVector.
-  private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
-    val wordVectors = new Array[Float](vectorSize * numWords)
+  private val wordVecNorms: Array[Double] = {
     val wordVecNorms = new Array[Double](numWords)
     var i = 0
     while (i < numWords) {
-      val vec = model.get(wordList(i)).get
-      Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize)
+      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
       wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
       i += 1
     }
-    (wordVectors, wordVecNorms)
+    wordVecNorms
+  }
+
+  def this(model: Map[String, Array[Float]]) = {
+    this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
   }

   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
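
One detail in the hunk above: wordList is rebuilt by sorting the (word, index) pairs on their index before unzipping, so the array order matches the layout of wordVectors. A tiny standalone illustration of that idiom, with made-up data rather than the commit's own values:

    // Sketch: recover an index-ordered word list from a word -> index map.
    val wordIndex = Map("cat" -> 2, "dog" -> 0, "fish" -> 1)
    val (wordList, _) = wordIndex.toSeq.sortBy(_._2).unzip
    println(wordList.mkString(", "))  // dog, fish, cat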
@@ -484,8 +479,9 @@ class Word2VecModel private[spark] (
    * @return vector representation of word
    */
   def transform(word: String): Vector = {
-    model.get(word) match {
-      case Some(vec) =>
+    wordIndex.get(word) match {
+      case Some(ind) =>
+        val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
         Vectors.dense(vec.map(_.toDouble))
       case None =>
         throw new IllegalStateException(s"$word not in vocabulary")
@@ -511,7 +507,7 @@ class Word2VecModel private[spark] (
    */
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
-
+    // TODO: optimize top-k
     val fVector = vector.toArray.map(_.toFloat)
     val cosineVec = Array.fill[Float](numWords)(0)
     val alpha: Float = 1
@@ -521,13 +517,13 @@ class Word2VecModel private[spark] (
       "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)

     // Need not divide with the norm of the given vector since it is constant.
-    val updatedCosines = new Array[Double](numWords)
+    val cosVec = cosineVec.map(_.toDouble)
     var ind = 0
     while (ind < numWords) {
-      updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
+      cosVec(ind) /= wordVecNorms(ind)
       ind += 1
     }
-    wordList.zip(updatedCosines)
+    wordList.zip(cosVec)
       .toSeq
       .sortBy(- _._2)
       .take(num + 1)
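
For context on this hunk: the earlier sgemv call treats wordVectors as a vectorSize x numWords matrix, so cosineVec(i) ends up holding the dot product of fVector with word i's vector, and the loop then divides each entry by that word's stored norm. The query vector's own norm is the same for every word, which is why it is never divided out. A plain-Scala sketch of the same arithmetic without BLAS, using toy numbers for illustration only:

    // Sketch only: what the gemv plus norm-division computes, written as explicit loops.
    val vectorSize = 2
    val numWords = 2
    val wordVectors = Array[Float](1f, 0f, 0f, 2f)  // word 0 -> (1, 0), word 1 -> (0, 2)
    val wordVecNorms = Array(1.0, 2.0)
    val fVector = Array(3f, 4f)

    val cosVec = Array.tabulate(numWords) { i =>
      var dot = 0.0
      var j = 0
      while (j < vectorSize) {
        dot += wordVectors(i * vectorSize + j) * fVector(j)
        j += 1
      }
      dot / wordVecNorms(i)  // query norm omitted, as in the code above
    }
    println(cosVec.mkString(", "))  // 3.0, 4.0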
@@ -548,6 +544,23 @@ class Word2VecModel private[spark] (
 @Experimental
 object Word2VecModel extends Loader[Word2VecModel] {

+  private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
+    model.keys.zipWithIndex.toMap
+  }
+
+  private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
+    require(model.nonEmpty, "Word2VecMap should be non-empty")
+    val (vectorSize, numWords) = (model.head._2.size, model.size)
+    val wordList = model.keys.toArray
+    val wordVectors = new Array[Float](vectorSize * numWords)
+    var i = 0
+    while (i < numWords) {
+      Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)
+      i += 1
+    }
+    wordVectors
+  }
+
   private object SaveLoadV1_0 {

     val formatVersionV1_0 = "1.0"

mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

Lines changed: 6 additions & 0 deletions
@@ -37,6 +37,12 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(syms.length == 2)
     assert(syms(0)._1 == "b")
     assert(syms(1)._1 == "c")
+
+    // Test that a model built using Word2Vec, i.e. wordVectors and wordIndex,
+    // and a Word2VecMap give the same values.
+    val word2VecMap = model.getVectors
+    val newModel = new Word2VecModel(word2VecMap)
+    assert(newModel.getVectors.mapValues(_.toSeq) === word2VecMap.mapValues(_.toSeq))
   }

   test("Word2VecModel") {
