@@ -31,6 +31,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.HashPartitioner
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.mllib.rdd.RDDFunctions._
+
 /**
  * Entry in vocabulary
  */
@@ -61,18 +62,15 @@ private case class VocabWord(
  * Distributed Representations of Words and Phrases and their Compositionality.
  * @param size vector dimension
  * @param startingAlpha initial learning rate
- * @param window context words from [-window, window]
- * @param minCount minimum frequncy to consider a vocabulary word
- * @param parallelisum number of partitions to run Word2Vec
+ * @param parallelism number of partitions to run Word2Vec (use a small number for accuracy)
+ * @param numIterations number of iterations to run, should be less than or equal to parallelism
  */
 @Experimental
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
-    val window: Int,
-    val minCount: Int,
-    val parallelism: Int = 1,
-    val numIterations: Int = 1)
+    val parallelism: Int = 1,
+    val numIterations: Int = 1)
   extends Serializable with Logging {
 
   private val EXP_TABLE_SIZE = 1000
@@ -81,7 +79,13 @@ class Word2Vec(
   private val MAX_SENTENCE_LENGTH = 1000
   private val layer1Size = size
   private val modelPartitionNum = 100
-
+
+  /** context words from [-window, window] */
+  private val window = 5
+
+  /** minimum frequency to consider a vocabulary word */
+  private val minCount = 5
+
   private var trainWordsCount = 0
   private var vocabSize = 0
   private var vocab: Array[VocabWord] = null
@@ -99,7 +103,7 @@ class Word2Vec(
         0))
       .filter(_.cn >= minCount)
       .collect()
-      .sortWith((a, b)=> a.cn > b.cn)
+      .sortWith((a, b) => a.cn > b.cn)
 
     vocabSize = vocab.length
     var a = 0
@@ -111,16 +115,12 @@ class Word2Vec(
     logInfo("trainWordsCount = " + trainWordsCount)
   }
 
-  private def learnVocabPerPartition(words: RDD[String]) {
-
-  }
-
-  private def createExpTable(): Array[Double] = {
-    val expTable = new Array[Double](EXP_TABLE_SIZE)
+  private def createExpTable(): Array[Float] = {
+    val expTable = new Array[Float](EXP_TABLE_SIZE)
     var i = 0
     while (i < EXP_TABLE_SIZE) {
       val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
-      expTable(i) = tmp / (tmp + 1)
+      expTable(i) = (tmp / (tmp + 1.0)).toFloat
       i += 1
     }
     expTable
@@ -209,7 +209,7 @@ class Word2Vec(
    * @return a Word2VecModel
    */
 
-  def fit[S <: Iterable[String]](dataset:RDD[S]): Word2VecModel = {
+  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
 
     val words = dataset.flatMap(x => x)
 
@@ -223,39 +223,37 @@ class Word2Vec(
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)
 
-    val sentences: RDD[Array[Int]] = words.mapPartitions {
-      iter => { new Iterator[Array[Int]] {
-          def hasNext = iter.hasNext
-
-          def next = {
-            var sentence = new ArrayBuffer[Int]
-            var sentenceLength = 0
-            while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-              val word = bcVocabHash.value.get(iter.next)
-              word match {
-                case Some(w) => {
-                  sentence += w
-                  sentenceLength += 1
-                }
-                case None =>
-              }
+    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
+      new Iterator[Array[Int]] {
+        def hasNext: Boolean = iter.hasNext
+
+        def next(): Array[Int] = {
+          var sentence = new ArrayBuffer[Int]
+          var sentenceLength = 0
+          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
+            val word = bcVocabHash.value.get(iter.next())
+            word match {
+              case Some(w) =>
+                sentence += w
+                sentenceLength += 1
+              case None =>
             }
-            sentence.toArray
           }
+          sentence.toArray
         }
       }
     }
 
     val newSentences = sentences.repartition(parallelism).cache()
-    var syn0Global
-      = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    var syn1Global = new Array[Double](vocabSize * layer1Size)
+    var syn0Global =
+      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+    var syn1Global = new Array[Float](vocabSize * layer1Size)
 
     for (iter <- 1 to numIterations) {
       val (aggSyn0, aggSyn1, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
+        // TODO: broadcast temp instead of serializing it directly
         // or initialize the model in each executor
-        newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
          seqOp = (c, v) => (c, v) match {
            case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
              var lwc = lastWordCount
@@ -280,23 +278,23 @@ class Word2Vec(
                     if (c >= 0 && c < sentence.size) {
                       val lastWord = sentence(c)
                       val l1 = lastWord * layer1Size
-                      val neu1e = new Array[Double](layer1Size)
+                      val neu1e = new Array[Float](layer1Size)
                       // Hierarchical softmax
                       var d = 0
                       while (d < bcVocab.value(word).codeLen) {
                         val l2 = bcVocab.value(word).point(d) * layer1Size
                         // Propagate hidden -> output
-                        var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                        var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                         if (f > -MAX_EXP && f < MAX_EXP) {
                           val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                           f = expTable.value(ind)
-                          val g = (1 - bcVocab.value(word).code(d) - f) * alpha
-                          blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                          blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                          val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
+                          blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                          blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                         }
                         d += 1
                       }
-                      blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
+                      blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
                     }
                   }
                   a += 1
@@ -308,24 +306,24 @@ class Word2Vec(
          combOp = (c1, c2) => (c1, c2) match {
            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
              val n = syn0_1.length
-             val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
-             val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
-             blas.dscal(n, weight1, syn0_1, 1)
-             blas.dscal(n, weight1, syn1_1, 1)
-             blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-             blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+             val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+             val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+             blas.sscal(n, weight1, syn0_1, 1)
+             blas.sscal(n, weight1, syn1_1, 1)
+             blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+             blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
          })
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()
 
-    val wordMap = new Array[(String, Array[Double])](vocabSize)
+    val wordMap = new Array[(String, Array[Float])](vocabSize)
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
-      val vector = new Array[Double](layer1Size)
+      val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
       wordMap(i) = (word, vector)
       i += 1
@@ -341,15 +339,15 @@ class Word2Vec(
 /**
  * Word2Vec model
  */
-class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Serializable {
+class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
 
-  private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
+  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
     val n = v1.length
-    val norm1 = blas.dnrm2(n, v1, 1)
-    val norm2 = blas.dnrm2(n, v2, 1)
+    val norm1 = blas.snrm2(n, v1, 1)
+    val norm2 = blas.snrm2(n, v2, 1)
     if (norm1 == 0 || norm2 == 0) return 0.0
-    blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
+    blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
   }
 
   /**
@@ -362,7 +360,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
     if (result.isEmpty) {
       throw new IllegalStateException(s"${word} not in vocabulary")
     }
-    else Vectors.dense(result(0))
+    else Vectors.dense(result(0).map(_.toDouble))
   }
 
   /**
@@ -394,7 +392,7 @@ class Word2VecModel (private val model:RDD[(String, Array[Double])]) extends Ser
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
     val topK = model.map { case (w, vec) =>
-      (cosineSimilarity(vector.toArray, vec), w) }
+      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
       .sortByKey(ascending = false)
       .take(num + 1)
       .map(_.swap)
@@ -410,18 +408,16 @@ object Word2Vec{
    * @param input RDD of words
    * @param size vector dimension
    * @param startingAlpha initial learning rate
-   * @param window context words from [-window, window]
-   * @param minCount minimum frequncy to consider a vocabulary word
-   * @return Word2Vec model
-   */
+   * @param parallelism number of partitions to run Word2Vec (use a small number for accuracy)
+   * @param numIterations number of iterations, should be less than or equal to parallelism
+   * @return Word2Vec model
+   */
   def train[S <: Iterable[String]](
       input: RDD[S],
       size: Int,
       startingAlpha: Double,
-      window: Int,
-      minCount: Int,
       parallelism: Int = 1,
       numIterations: Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
+    new Word2Vec(size, startingAlpha, parallelism, numIterations).fit[S](input)
   }
 }
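
For context, here is a minimal usage sketch of the public API after this change. It is an illustration under assumptions, not part of the patch: the package name org.apache.spark.mllib.feature, the corpus path, the parameter values, and the word-vector lookup method transform (whose body appears in the Word2VecModel hunks above, though its name is not shown) are all assumed. With window and minCount now fixed inside the class, callers pass only the four remaining parameters.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.feature.Word2Vec

object Word2VecExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("Word2VecExample"))
    // One document per line, pre-tokenized into words.
    val input = sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)
    // Per the updated parameter docs: keep parallelism small for accuracy,
    // and numIterations less than or equal to parallelism.
    val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025,
      parallelism = 4, numIterations = 2)
    // Fetch a word's vector, then find the ten nearest words by cosine similarity.
    val synonyms = model.findSynonyms(model.transform("spark"), 10)
    synonyms.foreach { case (word, sim) => println(s"$word $sim") }
    sc.stop()
  }
}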